Coverage for openxps/utils.py: 96%
72 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-13 22:08 +0000
1"""
2.. module:: openxps.utils
3 :platform: Linux, MacOS, Windows
4 :synopsis: Utility functions for OpenXPS.
6.. classauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
import copy
import functools
import inspect
import re
import typing as t

import cvpack
from cvpack.serialization import Serializable
from cvpack.units import Quantity, Unit
from openmm import unit as mmunit
#: The separator used to split checkpoint and XML strings into physical and extension
#: parts.
STRING_SEPARATOR = "\f\f"

#: Byte-string separator; presumably the binary counterpart of STRING_SEPARATOR
#: (e.g. for binary checkpoints) — confirm against its callers.
BINARY_SEPARATOR = b"::SdXN3dO::"

#: Names of the built-in functions recognized by the Lepton expression parser.
#: Identifiers in an expression that match one of these names are treated as
#: function calls rather than as variables or parameters.
LEPTON_FUNCTIONS = frozenset(
    {
        "sqrt",
        "exp",
        "log",
        "sin",
        "cos",
        "sec",
        "csc",
        "tan",
        "cot",
        "asin",
        "acos",
        "atan",
        "atan2",
        "sinh",
        "cosh",
        "tanh",
        "erf",
        "erfc",
        "min",
        "max",
        "abs",
        "floor",
        "ceil",
        "step",
        "delta",
        "select",
        # Also accepted by Lepton's parser, although not listed in OpenMM's
        # custom-force documentation.
        "square",
        "cube",
        "recip",
    }
)
def preprocess_args(func: t.Callable) -> t.Callable:
    """
    A decorator that converts instances of unserializable classes to their
    serializable counterparts.

    Every argument passed to the decorated function is processed recursively:
    OpenMM quantities and units are replaced by their :mod:`cvpack.units`
    counterparts, sequences and dictionaries are rebuilt with their items
    converted, and any other object is passed through unchanged.

    Parameters
    ----------
    func
        The function to be decorated.

    Returns
    -------
    The decorated function.

    Example
    -------
    >>> from openxps.utils import preprocess_args
    >>> from cvpack import units
    >>> from openmm import unit as mmunit
    >>> @preprocess_args
    ... def function(data):
    ...     return data
    >>> assert isinstance(function(mmunit.angstrom), units.Unit)
    >>> assert isinstance(function(5 * mmunit.angstrom), units.Quantity)
    >>> seq = [mmunit.angstrom, mmunit.nanometer]
    >>> assert isinstance(function(seq), list)
    >>> assert all(isinstance(item, units.Unit) for item in function(seq))
    >>> dct = {"length": 3 * mmunit.angstrom, "time": 2 * mmunit.picosecond}
    >>> assert isinstance(function(dct), dict)
    >>> assert all(isinstance(item, units.Quantity) for item in function(dct).values())
    >>> assert function(range(3)) == range(3)
    """
    signature = inspect.signature(func)

    def convert(data: t.Any) -> t.Any:
        if isinstance(data, mmunit.Quantity):
            return Quantity(data)
        if isinstance(data, mmunit.Unit):
            return Unit(data)
        # Text, binary buffers, and ranges only hold characters or integers, and
        # some of them (range, memoryview) cannot be rebuilt from an iterable.
        if isinstance(data, (str, bytes, bytearray, memoryview, range)):
            return data
        if isinstance(data, t.Sequence):
            items = [convert(item) for item in data]
            if isinstance(data, tuple) and hasattr(data, "_fields"):
                # Named tuples take their fields as separate positional arguments.
                return type(data)(*items)
            return type(data)(items)
        if isinstance(data, dict):
            # A copy keeps the mapping's class and extra state (for instance, the
            # default_factory of a defaultdict), which rebuilding it from
            # key-value pairs would lose or reject.
            converted = copy.copy(data)
            converted.update((key, convert(value)) for key, value in data.items())
            return converted
        return data

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        # Overwriting existing keys does not resize the dict, so it is safe to
        # do while iterating over it.
        for name, data in bound.arguments.items():
            bound.arguments[name] = convert(data)
        return func(*bound.args, **bound.kwargs)

    return wrapper
class Function(Serializable):
    """A function of dynamical variables and global parameters.

    Every identifier in the expression that is not a built-in Lepton function is
    a dependency of the function. Dependencies whose values are given as keyword
    arguments are parameters, and all other dependencies are dynamical variables.
    Given values whose names do not occur in the expression are discarded.

    Parameters
    ----------
    name
        The name of the function.
    expression
        The expression of the function.
    **given_parameters
        The values of the parameters of the function, either as plain numbers or
        as quantities with units.

    Examples
    --------
    >>> from copy import copy
    >>> from openxps.utils import Function
    >>> f1 = Function("f", "a*x^2", a=1.0)
    >>> f1
    Function("f(x)=a*x^2", a=1.0)
    >>> f2 = Function("g", "exp(-x)*cos(2*pi*y)", pi=3.14159)
    >>> f2
    Function("g(x, y)=exp(-x)*cos(2*pi*y)", pi=3.14159)
    >>> copy(f1)
    Function("f(x)=a*x^2", a=1.0)
    >>> Function("h", "x*y")
    Function("h(x, y)=x*y")
    """

    @preprocess_args
    def __init__(
        self, name: str, expression: str, **given_parameters: mmunit.Quantity
    ) -> None:
        self._name = name
        self._expression = expression
        variables, parameters = self._parseDependencies(given_parameters)
        self._variables = variables
        # Keep only the given values that the expression actually uses.
        self._parameters = {key: given_parameters[key] for key in parameters}

    def __repr__(self) -> str:
        variables = ", ".join(sorted(self._variables))
        arguments = [f'"{self._name}({variables})={self._expression}"']
        arguments.extend(f"{k}={v}" for k, v in sorted(self._parameters.items()))
        # Joining all arguments together avoids a dangling ", " when the
        # function has no parameters.
        return f"{type(self).__name__}({', '.join(arguments)})"

    def __copy__(self) -> "Function":
        # Restore the state directly instead of calling __init__, so that the
        # expression is not parsed again. As with any shallow copy, the variable
        # set and the parameter dictionary are shared with the original.
        cls = type(self)
        new = cls.__new__(cls)
        new.__setstate__(self.__getstate__())
        return new

    def __getstate__(self) -> dict[str, t.Any]:
        """Return the attributes needed to reconstruct the function."""
        return {
            "name": self._name,
            "expression": self._expression,
            "variables": self._variables,
            "parameters": self._parameters,
        }

    def __setstate__(self, state: dict[str, t.Any]) -> None:
        """Restore the attributes from a dictionary made by ``__getstate__``."""
        self._name = state["name"]
        self._expression = state["expression"]
        self._variables = state["variables"]
        self._parameters = state["parameters"]

    def _parseDependencies(
        self,
        given_parameters: dict[str, mmunit.Quantity],
    ) -> tuple[set[str], set[str]]:
        """Split the dependencies of the expression into variables and parameters.

        Parameters
        ----------
        given_parameters
            The given parameter values, keyed by name. Only the names are used.

        Returns
        -------
        tuple[set[str], set[str]]
            The names of the variables (dependencies without a given value) and
            the names of the parameters (dependencies with a given value).
        """
        # The leading word boundary keeps identifiers glued to a preceding digit,
        # such as the exponent marker in "1e-3", from being matched.
        identifiers = set(re.findall(r"\b[A-Za-z_][A-Za-z0-9_]*\b", self._expression))
        dependencies = identifiers - LEPTON_FUNCTIONS
        given_names = set(given_parameters)
        # Each dependency either has a given value or not, so this split covers
        # all dependencies and no further validation is needed.
        return dependencies - given_names, dependencies & given_names

    def getName(self) -> str:
        """Return the name of the function."""
        return self._name

    def getExpression(self) -> str:
        """Return the expression string of the function."""
        return self._expression

    def getVariables(self) -> set[str]:
        """Return the set of variable names in the function."""
        return self._variables

    def getParameters(self) -> dict[str, mmunit.Quantity]:
        """Return the dictionary of parameter names and their values."""
        return self._parameters

    def createCollectiveVariable(
        self,
        all_variables: list[str],
    ) -> cvpack.AtomicFunction:
        """
        Create a collective variable from the function.

        The variables of the function are taken in alphabetical order, so the
        generated expression is the same in every run. The i-th of them
        (counting from 1) is defined as ``x<i>`` in the expression. The i-th
        particle of the single group is the position of that variable in
        ``all_variables``. The parameters of the function are passed on to the
        collective variable, whose unit is kJ/mol.

        Parameters
        ----------
        all_variables
            The list of all variables in the system.

        Returns
        -------
        cvpack.AtomicFunction
            The collective variable object.

        Raises
        ------
        ValueError
            If a variable of the function is not in ``all_variables``.

        Examples
        --------
        >>> from openxps.utils import Function
        >>> import openmm as mm
        >>> from math import exp, cos, pi
        >>> fn = Function("g", "exp(-x)*cos(2*pi*y)", pi=3.14159)
        >>> cv = fn.createCollectiveVariable(["x", "y"])
        >>> system = mm.System()
        >>> for _ in range(2):
        ...     _ = system.addParticle(1.0)
        >>> cv.addToSystem(system)
        >>> context = mm.Context(system, mm.VerletIntegrator(0.0))
        >>> x, y = 2, 1
        >>> exp(-x)*cos(2*pi*y)
        0.13533528...
        >>> context.setPositions([mm.Vec3(x, 0.0, 0.0), mm.Vec3(y, 0.0, 0.0)])
        >>> context.getState(getEnergy=True).getPotentialEnergy()
        0.13533528... kJ/mol
        """
        # Set iteration order depends on string hashing, which is randomized per
        # process; a sorted list gives the same ordering in every run.
        ordered = sorted(self._variables)
        definitions = [
            f"{variable}=x{index}" for index, variable in enumerate(ordered, start=1)
        ]
        return cvpack.AtomicFunction(
            function=";".join([self._expression, *definitions]),
            unit=mmunit.kilojoule_per_mole,
            groups=[[all_variables.index(variable) for variable in ordered]],
            name=self._name,
            **self._parameters,
        )


# Register the tag that identifies Function objects in serialized data.
Function.registerTag("!openxps.utils.Function")