Coverage for openxps/couplings/base.py: 97%
141 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 22:08 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 22:08 +0000
1"""
2Base class for couplings.
4.. module:: openxps.couplings.base
5 :platform: Linux, MacOS, Windows
6 :synopsis: Base class for couplings between physical and extended phase-space systems
8.. classauthor:: Charlles Abreu <craabreu@gmail.com>
10"""
12import typing as t
14import cvpack
15import openmm as mm
16from cvpack.serialization import Serializable
17from openmm import XmlSerializer
18from openmm import _openmm as mmswig
19from openmm import unit as mmunit
21from ..dynamical_variable import DynamicalVariable
class Coupling(Serializable):
    """Abstract base class for couplings between physical and extension systems.

    A coupling connects the physical system's coordinates to the extension system's
    dynamical variables, enabling enhanced sampling simulations.

    Subclasses are expected to override :meth:`_createFlippedForce` (whose base
    implementation returns ``None``) and :meth:`updateExtensionContext` (whose base
    implementation raises :class:`NotImplementedError`).

    Parameters
    ----------
    forces
        A sequence of :OpenMM:`Force` objects.
    dynamical_variables
        A sequence of :class:`DynamicalVariable` objects.

    Raises
    ------
    ValueError
        If two dynamical variables share the same name, or if two forces declare
        the same global parameter with different default values.
    """

    def __init__(
        self,
        forces: t.Iterable[mm.Force],
        dynamical_variables: t.Sequence[DynamicalVariable],
    ) -> None:
        self._forces = list(forces)
        self._dynamical_variables = [dv.in_md_units() for dv in dynamical_variables]
        unique_names = {dv.name for dv in self._dynamical_variables}
        if len(unique_names) != len(self._dynamical_variables):
            raise ValueError("The dynamical variables must have unique names.")
        # Initially, each variable is indexed by its position within this coupling.
        # The indices may later be remapped by _updateDynamicalVariableIndices.
        self._dv_indices = {
            dv.name: index for index, dv in enumerate(self._dynamical_variables)
        }
        # Populated by addToExtensionSystem.
        self._flipped_force = None
        self._checkGlobalParameters()

    def __add__(self, other: "Coupling") -> "CouplingSum":
        """Combine this coupling with another one into a :class:`CouplingSum`."""
        if not isinstance(other, Coupling):
            # Let Python try the reflected operation or raise a TypeError.
            return NotImplemented
        return CouplingSum([self, other])

    def __copy__(self) -> "Coupling":
        """Create a new instance of the same class from this coupling's state."""
        # Bypass __init__ so that subclasses with other signatures can be copied.
        new = self.__class__.__new__(self.__class__)
        new.__setstate__(self.__getstate__())
        return new

    def __getstate__(self) -> dict[str, t.Any]:
        """Return the state from which this coupling can be rebuilt."""
        return {
            "forces": self._forces,
            "dynamical_variables": self._dynamical_variables,
            "dv_indices": self._dv_indices,
        }

    def __setstate__(self, state: dict[str, t.Any]) -> None:
        """Restore this coupling from a state produced by :meth:`__getstate__`."""
        self._forces = state["forces"]
        self._dynamical_variables = state["dynamical_variables"]
        # Copy the index map: it is mutated in place by
        # _updateDynamicalVariableIndices, so sharing it would let a copy silently
        # change the indices of the coupling it was copied from.
        self._dv_indices = dict(state["dv_indices"])
        # The flipped force is not part of the state; it is rebuilt on demand.
        self._flipped_force = None

    def _checkGlobalParameters(self) -> dict[str, t.Any]:
        """Check that the forces agree on the defaults of shared global parameters.

        Returns
        -------
        dict
            The default value of every global parameter, keyed by parameter name,
            as reported by the forces of this coupling.

        Raises
        ------
        ValueError
            If the same parameter has different default values in two forces.
        """
        parameters = {}
        for force in self._forces:
            for index in range(force.getNumGlobalParameters()):
                name = force.getGlobalParameterName(index)
                value = force.getGlobalParameterDefaultValue(index)
                if name in parameters and parameters[name] != value:
                    raise ValueError(
                        f"Parameter {name} has conflicting default values in "
                        f"coupling: {parameters[name]} != {value}"
                    )
                parameters[name] = value
        return parameters

    def _createFlippedForce(self) -> mm.Force | None:
        """Create the force to be added to the extension system.

        The base implementation returns ``None``.
        """
        return None

    @staticmethod
    def _addForceToSystem(force: mm.Force, system: mm.System) -> None:
        """Add a force to a system.

        Instances of :class:`cvpack.CollectiveVariable` are added through their own
        ``addToSystem`` method; any other force is added via ``system.addForce``.
        """
        if isinstance(force, cvpack.CollectiveVariable):
            force.addToSystem(system)
        else:
            system.addForce(force)

    def _updateDynamicalVariableIndices(
        self, dynamical_variables: t.Sequence[DynamicalVariable]
    ) -> None:
        """Update the indices of the dynamical variables associated with this coupling.

        Parameters
        ----------
        dynamical_variables
            All the dynamical variables in the system, regardless of whether they are
            associated with this coupling or not.
        """
        for index, dv in enumerate(dynamical_variables):
            # Variables that do not belong to this coupling are ignored.
            if dv.name in self._dv_indices:
                self._dv_indices[dv.name] = index

    def getForces(self) -> list[mm.Force]:
        """Get the list of OpenMM Force objects associated with this coupling.

        Returns
        -------
        list[openmm.Force]
            A list of Force objects contained within this coupling.
        """
        return self._forces

    def getDynamicalVariables(self) -> t.Sequence[DynamicalVariable]:
        """Get the dynamical variables associated with this coupling.

        Returns
        -------
        list[DynamicalVariable]
            A list of DynamicalVariable objects contained within this coupling.
        """
        return self._dynamical_variables

    def getForce(self, index: int) -> mm.Force:
        """Retrieve a single OpenMM Force object from this coupling.

        Parameters
        ----------
        index
            The index of the Force object to retrieve.

        Returns
        -------
        openmm.Force
            The Force object at the specified index.

        Raises
        ------
        IndexError
            If the index is out of range.
        """
        return self._forces[index]

    def getDynamicalVariable(self, index: int) -> DynamicalVariable:
        """Retrieve a single dynamical variable from this coupling.

        Parameters
        ----------
        index
            The index of the DynamicalVariable object to retrieve.

        Returns
        -------
        DynamicalVariable
            The DynamicalVariable object at the specified index.

        Raises
        ------
        IndexError
            If the index is out of range.
        """
        return self._dynamical_variables[index]

    def getProtectedParameters(self) -> set[str]:
        """Get parameters of the physical context that should not be manually modified.

        These are the names of the collective variables of the flipped force.

        Returns
        -------
        set[str]
            The protected parameters.

        Raises
        ------
        ValueError
            If this coupling has not been added to an extension system yet.
        """
        if self._flipped_force is None:
            raise ValueError("This coupling has not been added to an extension system.")
        return {
            self._flipped_force.getCollectiveVariableName(index)
            for index in range(self._flipped_force.getNumCollectiveVariables())
        }

    def addToPhysicalSystem(self, system: mm.System) -> None:
        """Add the forces of this coupling to an OpenMM system.

        Parameters
        ----------
        system
            The system to which the coupling should be added.
        """
        for force in self._forces:
            self._addForceToSystem(force, system)

    def addToExtensionSystem(self, system: mm.System) -> None:
        """
        Add the flipped version of this coupling to the extension system.

        The flipped force replaces dynamical variable parameters with collective
        variables that represent the dynamical variables as particles. Physical
        collective variables become parameters set to zero.

        Parameters
        ----------
        system
            The extension system to which the flipped coupling should be added.

        Examples
        --------
        >>> import cvpack
        >>> import openxps as xps
        >>> from openmm import unit
        >>> from math import pi
        >>> import openmm as mm
        >>> phi = cvpack.Torsion(6, 8, 14, 16, name="phi")
        >>> phi0 = xps.DynamicalVariable(
        ...     "phi0",
        ...     unit.radian,
        ...     3 * unit.dalton * (unit.nanometer / unit.radian)**2,
        ...     xps.CircularBounds()
        ... )
        >>> coupling = xps.CollectiveVariableCoupling(
        ...     f"0.5*kappa*min(delta,{2*pi}-delta)^2; delta=abs(phi-phi0)",
        ...     [phi],
        ...     [phi0],
        ...     kappa=1000 * unit.kilojoules_per_mole / unit.radian**2,
        ... )
        >>> extension_system = mm.System()
        >>> extension_system.addParticle(phi0.mass / phi0.mass.unit)
        0
        >>> coupling.addToExtensionSystem(extension_system)
        >>> extension_system.getNumForces()
        1
        """
        # Keep a reference to the flipped force: it is needed later by
        # getProtectedParameters and updatePhysicalContext.
        self._flipped_force = self._createFlippedForce()
        self._addForceToSystem(self._flipped_force, system)

    def updatePhysicalContext(
        self,
        physical_context: mm.Context,
        extension_context: mm.Context,
    ) -> None:
        """Update the physical context with the current extension parameters.

        Each collective variable of the flipped force is evaluated in the extension
        context and its value is assigned to the parameter of the same name in the
        physical context.

        Parameters
        ----------
        physical_context
            The physical context to update with the extension parameters.
        extension_context
            The extension context to get the extension parameters from.
        """
        # The low-level SWIG functions are called directly here, as in the
        # original implementation.
        collective_variables = mmswig.CustomCVForce_getCollectiveVariableValues(
            self._flipped_force, extension_context
        )
        for index, value in enumerate(collective_variables):
            mmswig.Context_setParameter(
                physical_context,
                self._flipped_force.getCollectiveVariableName(index),
                value,
            )

    def updateExtensionContext(
        self,
        physical_context: mm.Context,
        extension_context: mm.Context,
    ) -> None:
        """Update the extension context with the current physical parameters.

        Parameters
        ----------
        physical_context
            The physical context to get the physical parameters from.
        extension_context
            The extension context to update with the physical parameters.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def getCollectiveVariableValues(
        self, physical_context: mm.Context
    ) -> dict[str, float]:
        """Get the values of the collective variables.

        Only forces that are instances of :OpenMM:`CustomCVForce` are considered.

        Parameters
        ----------
        physical_context
            The physical context to get the collective variable values from.

        Returns
        -------
        dict[str, float]
            The collective variable values, keyed by collective variable name.
        """
        collective_variables = {}
        for force in self._forces:
            if isinstance(force, mm.CustomCVForce):
                cv_values = force.getCollectiveVariableValues(physical_context)
                for index, value in enumerate(cv_values):
                    collective_variables[
                        mmswig.CustomCVForce_getCollectiveVariableName(force, index)
                    ] = value
        return collective_variables
class CouplingSum(Coupling):
    """A sum of couplings.

    Nested sums are flattened: when one of the given couplings is itself a
    :class:`CouplingSum`, its individual couplings are included instead.

    Parameters
    ----------
    couplings
        The couplings to be added.

    Raises
    ------
    ValueError
        If a dynamical variable or a collective variable has conflicting
        definitions in the couplings.

    Examples
    --------
    >>> from copy import copy
    >>> import cvpack
    >>> import openxps as xps
    >>> from openmm import unit
    >>> phi = cvpack.Torsion(6, 8, 14, 16, name="phi")
    >>> psi = cvpack.Torsion(4, 6, 8, 14, name="psi")
    >>> dv_mass = 3 * unit.dalton * (unit.nanometer / unit.radian)**2
    >>> phi_s = xps.DynamicalVariable(
    ...     "phi_s", unit.radian, dv_mass, xps.CircularBounds()
    ... )
    >>> psi_s = xps.DynamicalVariable(
    ...     "psi_s", unit.radian, dv_mass, xps.CircularBounds()
    ... )
    >>> coupling = xps.HarmonicCoupling(
    ...     phi, phi_s, 1000 * unit.kilojoule_per_mole / unit.radian**2
    ... ) + xps.HarmonicCoupling(
    ...     psi, psi_s, 500 * unit.kilojoule_per_mole / unit.radian**2
    ... )
    """

    def __init__(self, couplings: t.Iterable[Coupling]) -> None:
        self._couplings = []
        forces = []
        dv_dict = {}
        for coupling in couplings:
            if isinstance(coupling, CouplingSum):
                # Flatten nested sums into their individual couplings.
                self._couplings.extend(coupling.getCouplings())
            else:
                self._couplings.append(coupling)
            forces.extend(coupling.getForces())
            for dv in coupling.getDynamicalVariables():
                if dv.name not in dv_dict:
                    dv_dict[dv.name] = dv
                elif dv_dict[dv.name] != dv:
                    raise ValueError(
                        f'The dynamical variable "{dv.name}" has '
                        "conflicting definitions in the couplings."
                    )
        # The distinct dynamical variables are stored sorted by name.
        super().__init__(forces, sorted(dv_dict.values(), key=lambda dv: dv.name))
        self._broadcastDynamicalVariableIndices()
        self._checkCollectiveVariables()

    def __repr__(self) -> str:
        return "+".join(f"({coupling!r})" for coupling in self._couplings)

    def __copy__(self) -> "CouplingSum":
        """Create a new sum made of the same couplings."""
        new = CouplingSum.__new__(CouplingSum)
        new.__setstate__(self.__getstate__())
        return new

    def __getstate__(self) -> dict[str, t.Any]:
        """Return the state from which this sum can be rebuilt."""
        return {"couplings": self._couplings}

    def __setstate__(self, state: dict[str, t.Any]) -> None:
        """Rebuild this sum from a state produced by :meth:`__getstate__`."""
        # Re-running __init__ recomputes the forces, the dynamical variables and
        # their indices, and repeats the consistency checks.
        self.__init__(state["couplings"])

    def _broadcastDynamicalVariableIndices(self) -> None:
        """Make every coupling index its variables as in this sum's variable list."""
        for coupling in self._couplings:
            coupling._updateDynamicalVariableIndices(self._dynamical_variables)

    def _checkCollectiveVariables(self) -> None:
        """Check that collective variables sharing a name are identically defined.

        Definitions are compared through their XML serialization.

        Raises
        ------
        ValueError
            If a collective variable name has conflicting definitions.
        """
        cvs = {}
        for coupling in self._couplings:
            for force in coupling.getForces():
                if isinstance(force, mm.CustomCVForce):
                    for index in range(force.getNumCollectiveVariables()):
                        name = force.getCollectiveVariableName(index)
                        xml_string = XmlSerializer.serialize(
                            force.getCollectiveVariable(index)
                        )
                        if name in cvs and cvs[name] != xml_string:
                            raise ValueError(
                                f'The collective variable "{name}" has conflicting '
                                "definitions in the couplings."
                            )
                        cvs[name] = xml_string

    def getCouplings(self) -> t.Sequence[Coupling]:
        """Get the couplings included in the summed coupling."""
        return self._couplings

    def getProtectedParameters(self) -> set[str]:
        """Get the union of the protected parameters of all included couplings.

        Returns
        -------
        set[str]
            The protected parameters. It is empty if this sum has no couplings.
        """
        # Starting from an empty set keeps this working for a sum without
        # couplings; ``set.union(*[])`` would raise a TypeError.
        return set().union(
            *(coupling.getProtectedParameters() for coupling in self._couplings)
        )

    def addToExtensionSystem(self, system: mm.System) -> None:
        """Add the flipped version of every included coupling to the system.

        Parameters
        ----------
        system
            The extension system to which the flipped couplings should be added.
        """
        for coupling in self._couplings:
            coupling.addToExtensionSystem(system)

    def updatePhysicalContext(
        self,
        physical_context: mm.Context,
        extension_context: mm.Context,
    ) -> None:
        """Update the physical context through every included coupling.

        Parameters
        ----------
        physical_context
            The physical context to update with the extension parameters.
        extension_context
            The extension context to get the extension parameters from.
        """
        for coupling in self._couplings:
            coupling.updatePhysicalContext(physical_context, extension_context)

    def updateExtensionContext(
        self,
        physical_context: mm.Context,
        extension_context: mm.Context,
    ) -> None:
        """Update the extension context through every included coupling.

        Parameters
        ----------
        physical_context
            The physical context to get the physical parameters from.
        extension_context
            The extension context to update with the physical parameters.
        """
        for coupling in self._couplings:
            coupling.updateExtensionContext(physical_context, extension_context)


# Register the serialization tag of CouplingSum via the ``registerTag`` method
# that the class inherits through Coupling.
CouplingSum.registerTag("!openxps.CouplingSum")