Coverage for openxps/integrators/utils.py: 91%
78 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 22:08 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 22:08 +0000
1"""
2.. module:: openxps.integrators.utils
3 :platform: Linux, MacOS, Windows
4 :synopsis: A mixin for integrators that provides extra functionality.
6.. classauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
10import textwrap
11import typing as t
13import openmm as mm
14from cvpack.units import Quantity
15from openmm import unit as mmunit
17from openxps.utils import preprocess_args
# Type variable for the classes decorated by ``add_property``.
T = t.TypeVar("T")

# Computation types (as returned by ``getComputationStep``) that open an
# indented block when ``IntegratorMixin.__repr__`` renders the steps: the
# "if (...)" and "while (...)" steps.
BLOCK_START = (6, 7)
# Computation type of the "end" step that closes the innermost open block.
BLOCK_END = 8
class IntegratorMixin:
    """A mixin for integrators that provides extra functionality.

    The methods below call the OpenMM custom-integrator introspection API
    (``getNumPerDofVariables``, ``getGlobalVariable``, ``getComputationStep``,
    ...) on ``self``. NOTE(review): this assumes every class using the mixin
    also derives from ``openmm.CustomIntegrator`` — confirm.
    """

    # Whether the integrator follows a force-first scheme (see ``isForceFirst``).
    _forceFirst: bool = False

    def __repr__(self) -> str:
        """Return a human-readable version of each integrator step.

        The text lists the per-dof variable names (if any), the global
        variables with their current values (if any), and every computation
        step, numbered and indented inside ``if``/``while`` blocks.

        Returns
        -------
        str
            A multi-line description of the integrator.
        """
        readable_lines = []

        num_per_dof = self.getNumPerDofVariables()
        if num_per_dof > 0:
            # The names line is emitted only together with its header, so
            # integrators without per-dof variables get no blank line here.
            per_dof = [self.getPerDofVariableName(index) for index in range(num_per_dof)]
            readable_lines.append("Per-dof variables:")
            readable_lines.append(" " + ", ".join(per_dof))

        num_global = self.getNumGlobalVariables()
        if num_global > 0:
            readable_lines.append("Global variables:")
            for index in range(num_global):
                name = self.getGlobalVariableName(index)
                value = self.getGlobalVariable(index)
                readable_lines.append(f" {name} = {value}")

        readable_lines.append("Computation steps:")

        # Templates indexed by the computation type returned from
        # ``getComputationStep``; types 6 and 7 open if/while blocks
        # (BLOCK_START) and type 8 closes them (BLOCK_END).
        step_type_str = [
            "{target} <- {expr}",
            "{target} <- {expr}",
            "{target} <- sum({expr})",
            "constrain positions",
            "constrain velocities",
            "allow forces to update the context state",
            "if ({expr}):",
            "while ({expr}):",
            "end",
        ]
        indent_level = 0
        for step in range(self.getNumComputations()):
            step_type, target, expr = self.getComputationStep(step)
            command = step_type_str[step_type].format(target=target, expr=expr)
            # An "end" step is printed at the same depth as its opening line.
            if step_type == BLOCK_END:
                indent_level -= 1
            readable_lines.append(f"{step:4d}: " + " " * indent_level + command)
            # Steps after an if/while header are nested one level deeper.
            if step_type in BLOCK_START:
                indent_level += 1
        return "\n".join(readable_lines)

    def registerWithSystem(self, system: mm.System, isExtension: bool) -> None:
        """Register the integrator with the system.

        The base implementation does nothing. NOTE(review): presumably a hook
        for subclasses that need to inspect or modify ``system`` — confirm.

        Parameters
        ----------
        system
            The :OpenMM:`System` the integrator will be used with.
        isExtension
            Flag forwarded by the caller; unused by this base implementation.
        """
        pass

    @staticmethod
    def _countDegreesOfFreedom(system: mm.System) -> int:
        """Count the degrees of freedom in a system.

        Every particle with positive mass contributes three degrees of
        freedom. Each constraint involving at least one massive particle
        removes one, and the presence of a :OpenMM:`CMMotionRemover` force
        removes three more.

        Parameters
        ----------
        system
            The :OpenMM:`System` to count the degrees of freedom of.

        Returns
        -------
        int
            The number of degrees of freedom in the system.
        """
        zero_mass = 0 * mmunit.dalton
        dof = 0
        for i in range(system.getNumParticles()):
            if system.getParticleMass(i) > zero_mass:
                dof += 3
        for i in range(system.getNumConstraints()):
            p1, p2, _ = system.getConstraintParameters(i)
            if (system.getParticleMass(p1) > zero_mass) or (
                system.getParticleMass(p2) > zero_mass
            ):
                dof -= 1
        if any(
            isinstance(system.getForce(i), mm.CMMotionRemover)
            for i in range(system.getNumForces())
        ):
            dof -= 3
        return dof

    def isForceFirst(self) -> bool:
        """Check if the integrator follows a force-first scheme.

        Returns
        -------
        bool
            True if the integrator is force-first, False otherwise.
        """
        return self._forceFirst
def add_property(property: str, unit: mmunit.Unit) -> t.Callable[[type[T]], type[T]]:
    """A decorator to add a property to the integrator.

    The property value is stored in an attribute named ``_{snake_case_name}``,
    which is set to ``None`` at class level. ``snake_case_name`` is ``property``
    split on whitespace, lower-cased and joined with underscores, and
    ``PascalCaseName`` is the same words capitalized and concatenated (e.g.
    ``"friction coefficient"`` gives ``_friction_coefficient`` and
    ``setFrictionCoefficient``). Three methods are added to the class:

    * ``_init_{snake_case_name}(value)``: Initializes the property.
    * ``set{PascalCaseName}(value)``: Sets the property and updates the global
      variables if the integrator has a ``_update_global_variables`` method.
    * ``get{PascalCaseName}()``: Gets the property.

    Parameters
    ----------
    property
        The name of the property, as whitespace-separated words.
    unit
        The unit of the property. Plain numbers passed to the initializer or
        setter are tagged with this unit.

    Returns
    -------
    Callable
        A class decorator that installs the attribute and methods above and
        returns the same class.

    Raises
    ------
    ValueError
        If ``unit`` is not expressed in the MD unit system.
    """
    if unit.in_unit_system(mmunit.md_unit_system) != unit:
        raise ValueError(f"The unit of {property} must be in the MD unit system.")
    words = property.split()
    PascalCaseName = "".join(word.capitalize() for word in words)
    snake_case_name = "_".join(word.lower() for word in words)
    attribute = f"_{snake_case_name}"

    def set_value(self, value: t.Union[mmunit.Quantity, float]) -> None:
        """Store ``value``, tagging plain numbers with the property's unit."""
        # Quantities are passed to ``Quantity`` unchanged; bare numbers get ``unit``.
        args = [value] if mmunit.is_quantity(value) else [value, unit]
        setattr(self, attribute, Quantity(*args))

    @preprocess_args
    def set_and_update(self, value: t.Union[mmunit.Quantity, float]) -> None:
        # ``set_value`` is a plain closure, not a bound method, so ``self``
        # must be passed explicitly.
        set_value(self, value)
        if hasattr(self, "_update_global_variables"):
            self._update_global_variables()

    set_and_update.__doc__ = textwrap.dedent(
        f"""\
        Set the {property}.

        Parameters
        ----------
        value
            The {property}.
        """
    )

    def get_value(self) -> mmunit.Quantity:
        return getattr(self, attribute)

    get_value.__doc__ = textwrap.dedent(
        f"""\
        Get the {property}.

        Returns
        -------
        openmm.unit.Quantity
            The {property}.
        """
    )

    def decorator(cls: type[T]) -> type[T]:
        setattr(cls, attribute, None)
        setattr(cls, f"_init_{snake_case_name}", set_value)
        setattr(cls, f"set{PascalCaseName}", set_and_update)
        setattr(cls, f"get{PascalCaseName}", get_value)
        return cls

    return decorator