Coverage for openxps/context.py: 99%
81 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 22:08 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 22:08 +0000
1"""
2.. module:: openxps.context
3 :platform: Linux, MacOS, Windows
4 :synopsis: Context for extended phase-space simulations with OpenMM.
6.. classauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
10import typing as t
12import openmm as mm
13from openmm import _openmm as mmswig
14from openmm import unit as mmunit
16from .integrator import ExtendedSpaceIntegrator
17from .system import ExtendedSpaceSystem
18from .utils import BINARY_SEPARATOR
class ExtendedSpaceContext(mm.Context):
    """An :OpenMM:`Context` object that includes extra dynamical variables (DVs) and
    allows for extended phase-space (XPS) simulations.

    Parameters
    ----------
    system
        The :class:`ExtendedSpaceSystem` to be used in the XPS simulation.
    integrator
        An :class:`ExtendedSpaceIntegrator` object to be used for advancing the XPS
        simulation. Available implementations include :class:`LockstepIntegrator` for
        systems where both integrators use the same step size, and
        :class:`SplitIntegrator` for systems with different step sizes related by an
        even integer ratio.
    platform
        The :OpenMM:`Platform` to use for the physical system. If ``None``, no
        platform is passed to OpenMM, which then chooses one itself. The extension
        system always runs on the ``Reference`` platform.
    properties
        A dictionary of values for platform-specific properties. It can only be
        used together with an explicit ``platform``.

    Raises
    ------
    TypeError
        If ``system`` is not an :class:`ExtendedSpaceSystem`, if ``integrator`` is
        not an :class:`ExtendedSpaceIntegrator`, or if ``properties`` is given
        without a ``platform``.

    Example
    -------
    >>> import openxps as xps
    >>> import openmm
    >>> import cvpack
    >>> from openmm import unit
    >>> from openmmtools import testsystems
    >>> model = testsystems.AlanineDipeptideVacuum()
    >>> mass = 3 * unit.dalton*(unit.nanometer/unit.radian)**2
    >>> phi0 = xps.DynamicalVariable("phi0", unit.radian, mass, xps.CircularBounds())
    >>> phi = cvpack.Torsion(6, 8, 14, 16, name="phi")
    >>> kappa = 1000 * unit.kilojoule_per_mole / unit.radian**2
    >>> harmonic_force = xps.HarmonicCoupling(phi, phi0, kappa)
    >>> temp = 300 * unit.kelvin
    >>> integrator = openmm.LangevinMiddleIntegrator(
    ...     temp, 1 / unit.picosecond, 4 * unit.femtosecond
    ... )
    >>> integrator.setRandomNumberSeed(1234)
    >>> platform = openmm.Platform.getPlatformByName("Reference")
    >>> context = xps.ExtendedSpaceContext(
    ...     xps.ExtendedSpaceSystem(model.system, harmonic_force),
    ...     xps.LockstepIntegrator(integrator),
    ...     platform,
    ... )
    >>> context.setPositions(model.positions)
    >>> context.setVelocitiesToTemperature(temp, 1234)
    >>> context.setDynamicalVariableValues([180 * unit.degree])
    >>> context.setDynamicalVariableVelocitiesToTemperature(temp, 1234)
    >>> context.getIntegrator().step(100)
    >>> context.getDynamicalVariableValues()
    (... rad,)
    >>> state = context.getExtensionContext().getState(getEnergy=True)
    >>> state.getPotentialEnergy(), state.getKineticEnergy()
    (... kJ/mol, ... kJ/mol)
    """

    def __init__(
        self,
        system: ExtendedSpaceSystem,
        integrator: ExtendedSpaceIntegrator,
        platform: t.Optional[mm.Platform] = None,
        properties: t.Optional[dict] = None,
    ) -> None:
        self._validate(system, integrator)
        if properties is not None and platform is None:
            # mm.Context only accepts properties after an explicit platform; without
            # this check the dict would be passed in the platform position.
            raise TypeError(
                "A platform must be specified when platform properties are given."
            )
        args = [system, integrator.getPhysicalIntegrator()]
        if platform is not None:
            args.append(platform)
            if properties is not None:
                args.append(properties)
        super().__init__(*args)

        # The DVs live in a separate, always-Reference context.
        extension_context = mm.Context(
            system.getExtensionSystem(),
            integrator.getExtensionIntegrator(),
            mm.Platform.getPlatformByName("Reference"),
        )
        coupling = system.getCoupling()

        # Store state before configuring the integrator, so that any callback
        # into this context finds its attributes already set.
        self._system = system
        # A tuple guarantees len() and repeated iteration in the DV helpers below.
        self._dvs = tuple(system.getDynamicalVariables())
        self._coupling = coupling
        self._integrator = integrator
        self._extension_context = extension_context

        integrator.configure(
            physical_context=self,
            extension_context=extension_context,
            coupling=coupling,
        )

    def _validate(
        self,
        system: ExtendedSpaceSystem,
        integrator: ExtendedSpaceIntegrator,
    ) -> None:
        """Check the types of the constructor's ``system`` and ``integrator``.

        Raises
        ------
        TypeError
            If either argument has the wrong type.
        """
        if not isinstance(system, ExtendedSpaceSystem):
            raise TypeError("The system must be an instance of ExtendedSpaceSystem.")
        if not isinstance(integrator, ExtendedSpaceIntegrator):
            raise TypeError(
                "The integrator must be an instance of ExtendedSpaceIntegrator."
            )

    def _asExtensionVectors(
        self,
        items: t.Iterable[t.Any],
        unit_of: t.Callable[[t.Any], t.Any],
        label: str,
    ) -> t.List[mm.Vec3]:
        """Convert one scalar per DV into the per-particle ``Vec3`` list used by
        the extension context (value in ``x``, zeros in ``y`` and ``z``).

        Parameters
        ----------
        items
            One entry per dynamical variable, in DV order. Quantities are converted
            to ``unit_of(dv)``; plain numbers are used as-is.
        unit_of
            Maps a dynamical variable to the unit its entry must be expressed in.
        label
            Word used in the error message (e.g. ``"values"``).

        Raises
        ------
        ValueError
            If the number of entries differs from the number of dynamical variables.
        """
        items = list(items)
        if len(items) != len(self._dvs):
            raise ValueError(
                f"Expected {len(self._dvs)} dynamical variable {label}, "
                f"but got {len(items)}."
            )
        vectors = []
        for dv, item in zip(self._dvs, items):
            if mmunit.is_quantity(item):
                item = item.value_in_unit(unit_of(dv))
            vectors.append(mm.Vec3(item, 0, 0))
        return vectors

    def getSystem(self) -> ExtendedSpaceSystem:
        """
        Get the system included in the extended phase-space context.

        Returns
        -------
        ExtendedSpaceSystem
            The system.
        """
        return self._system

    def getIntegrator(self) -> ExtendedSpaceIntegrator:
        """
        Get the integrator included in the extended phase-space context.

        Returns
        -------
        ExtendedSpaceIntegrator
            The integrator.
        """
        return self._integrator

    def setParameter(self, name: str, value: mmunit.Quantity) -> None:
        """
        Set the value of a global parameter defined by a Force object in the System.

        Notes
        -----
        Parameters protected by the coupling cannot be set manually. These are the
        ones set automatically via :meth:`setDynamicalVariableValues`.

        Parameters
        ----------
        name
            The name of the parameter to set.
        value
            The value of the parameter.

        Raises
        ------
        ValueError
            If ``name`` is one of the coupling's protected parameters.
        """
        if name in self._coupling.getProtectedParameters():
            raise ValueError(
                f'Cannot manually set the parameter "{name}". This parameter is '
                "set automatically via setDynamicalVariableValues."
            )
        super().setParameter(name, value)

    def setPositions(self, positions: mmunit.Quantity) -> None:
        """
        Sets the positions of all particles in the physical system.

        After the positions are set, the coupling updates the extension context
        from the physical context.

        Parameters
        ----------
        positions
            The positions for each particle in the system.
        """
        super().setPositions(positions)
        self._coupling.updateExtensionContext(self, self._extension_context)

    def setDynamicalVariableValues(self, values: t.Iterable[mmunit.Quantity]) -> None:
        """
        Set the values of the dynamical variables.

        Each value is stored as the ``x`` coordinate of the matching particle in the
        extension context. The coupling then updates the physical context.

        Parameters
        ----------
        values
            One entry per dynamical variable, in DV order. Quantities are converted
            to the DV's unit; plain numbers are taken to be in that unit already.

        Raises
        ------
        ValueError
            If the number of values differs from the number of dynamical variables.
        """
        positions = self._asExtensionVectors(values, lambda dv: dv.unit, "values")
        self._extension_context.setPositions(positions)
        self._coupling.updatePhysicalContext(self, self._extension_context)

    def getDynamicalVariableValues(self) -> tuple[mmunit.Quantity, ...]:
        """
        Get the values of the dynamical variables.

        The values are read from the physical context's global parameters named
        after each DV.

        Returns
        -------
        tuple[mmunit.Quantity, ...]
            A tuple containing the values of the dynamical variables, in DV order.
        """
        return tuple(self.getParameter(dv.name) * dv.unit for dv in self._dvs)

    def setDynamicalVariableVelocities(
        self, velocities: t.Iterable[mmunit.Quantity]
    ) -> None:
        """
        Set the velocities of the dynamical variables.

        Parameters
        ----------
        velocities
            One entry per dynamical variable, in DV order. Quantities are converted
            to the DV's unit per picosecond; plain numbers are taken to be in that
            unit already.

        Raises
        ------
        ValueError
            If the number of velocities differs from the number of dynamical
            variables.
        """
        vectors = self._asExtensionVectors(
            velocities, lambda dv: dv.unit / mmunit.picosecond, "velocities"
        )
        self._extension_context.setVelocities(vectors)

    def setDynamicalVariableVelocitiesToTemperature(
        self, temperature: mmunit.Quantity, seed: t.Optional[int] = None
    ) -> None:
        """
        Set random velocities for the dynamical variables at a given temperature.

        The extension context's ``setVelocitiesToTemperature`` generates the
        velocities. Their ``y`` and ``z`` components are then zeroed, because each
        DV uses only the ``x`` coordinate of its particle.

        Parameters
        ----------
        temperature
            The temperature to set the velocities to.
        seed
            Random number seed passed to OpenMM. If ``None``, no seed is passed.
        """
        args = (temperature,) if seed is None else (temperature, seed)
        self._extension_context.setVelocitiesToTemperature(*args)
        state = mmswig.Context_getState(self._extension_context, mm.State.Velocities)
        velocities = mmswig.State__getVectorAsVec3(state, mm.State.Velocities)
        self._extension_context.setVelocities([mm.Vec3(v.x, 0, 0) for v in velocities])

    def getDynamicalVariableVelocities(self) -> tuple[mmunit.Quantity, ...]:
        """
        Get the velocities of the dynamical variables.

        Each raw velocity is passed through the DV's ``bounds.wrap`` together with
        its current coordinate, and the rate returned by ``wrap`` is reported. The
        bounds may adjust the velocity (e.g. its sign); confirm against the bounds
        implementation.

        Returns
        -------
        tuple[mmunit.Quantity, ...]
            A tuple containing the velocities of the dynamical variables, in DV
            order, in units of the DV's unit per picosecond.
        """
        state = mmswig.Context_getState(
            self._extension_context, mm.State.Positions | mm.State.Velocities
        )
        positions = mmswig.State__getVectorAsVec3(state, mm.State.Positions)
        velocities = mmswig.State__getVectorAsVec3(state, mm.State.Velocities)
        dv_velocities = []
        for dv, position, velocity in zip(self._dvs, positions, velocities):
            _, rate = dv.bounds.wrap(position.x, velocity.x)
            dv_velocities.append(rate * dv.unit / mmunit.picosecond)
        return tuple(dv_velocities)

    def getExtensionContext(self) -> mm.Context:
        """
        Get a reference to the OpenMM context containing the extension system.

        Returns
        -------
        mm.Context
            The context containing the extension system.
        """
        return self._extension_context

    def createCheckpoint(self) -> str:
        r"""Create a checkpoint recording the current state of the Context.

        This should be treated as an opaque block of binary data. See
        :meth:`loadCheckpoint` for more details.

        Returns
        -------
        str
            A string containing the checkpoint data

        """
        # Format: <physical checkpoint> BINARY_SEPARATOR <extension checkpoint>.
        # NOTE(review): this assumes BINARY_SEPARATOR never occurs inside either
        # checkpoint. Also, OpenMM's Python binding may return bytes here;
        # confirm the annotation.
        return (
            mmswig.Context_createCheckpoint(self)
            + BINARY_SEPARATOR
            + mmswig.Context_createCheckpoint(self._extension_context)
        )

    def loadCheckpoint(self, checkpoint) -> None:
        r"""Load a checkpoint that was written by :meth:`createCheckpoint`.

        See :OpenMM:`Context` for more details.

        Parameters
        ----------
        checkpoint
            The checkpoint data to load.

        Raises
        ------
        ValueError
            If the data does not consist of exactly two parts joined by the
            separator used in :meth:`createCheckpoint`.
        """
        parts = checkpoint.split(BINARY_SEPARATOR)
        if len(parts) != 2:
            raise ValueError(
                "Invalid checkpoint: expected exactly one separator between the "
                f"physical and extension checkpoints, found {len(parts) - 1}."
            )
        physical_checkpoint, extension_checkpoint = parts
        mmswig.Context_loadCheckpoint(self, physical_checkpoint)
        mmswig.Context_loadCheckpoint(self._extension_context, extension_checkpoint)