Coverage for openxps/utils.py: 96%

72 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 22:08 +0000

1""" 

2.. module:: openxps.utils 

3 :platform: Linux, MacOS, Windows 

4 :synopsis: Utility functions for OpenXPS. 

5 

6.. classauthor:: Charlles Abreu <craabreu@gmail.com> 

7 

8""" 

9 

10import functools 

11import inspect 

12import re 

13import typing as t 

14 

15import cvpack 

16from cvpack.serialization import Serializable 

17from cvpack.units import Quantity, Unit 

18from openmm import unit as mmunit 

19 

#: Separator used to split serialized text (e.g. XML strings) into their physical
#: and extension parts.
STRING_SEPARATOR = "\f\f"

#: Bytes counterpart of :data:`STRING_SEPARATOR`.
#: NOTE(review): presumably used for binary checkpoints -- confirm against callers.
BINARY_SEPARATOR = b"::SdXN3dO::"

#: Names that are treated as built-in functions, rather than as variables or
#: parameters, when an expression is scanned for its dependencies.
LEPTON_FUNCTIONS = frozenset(
    (
        # roots, exponentials, and logarithms
        "sqrt", "exp", "log",
        # trigonometric functions and their inverses
        "sin", "cos", "sec", "csc", "tan", "cot",
        "asin", "acos", "atan", "atan2",
        # hyperbolic functions
        "sinh", "cosh", "tanh",
        # error functions
        "erf", "erfc",
        # comparison, rounding, and selection helpers
        "min", "max", "abs", "floor", "ceil",
        "step", "delta", "select",
    )
)

55 

56 

def preprocess_args(func: t.Callable) -> t.Callable:
    """
    A decorator that converts instances of unserializable classes to their
    serializable counterparts.

    Every argument bound to the decorated function is converted recursively:
    :class:`openmm.unit.Quantity` becomes :class:`cvpack.units.Quantity`,
    :class:`openmm.unit.Unit` becomes :class:`cvpack.units.Unit`, and the items
    of sequences and the values of dictionaries are converted in turn while the
    container type is preserved. Strings, bytes, and ranges are passed through
    unchanged. Default values of omitted arguments are not converted, because
    only the arguments actually passed by the caller are bound.

    Parameters
    ----------
    func
        The function to be decorated.

    Returns
    -------
        The decorated function.

    Example
    -------
    >>> from openxps.utils import preprocess_args
    >>> from cvpack import units
    >>> from openmm import unit as mmunit
    >>> @preprocess_args
    ... def function(data):
    ...     return data
    >>> assert isinstance(function(mmunit.angstrom), units.Unit)
    >>> assert isinstance(function(5 * mmunit.angstrom), units.Quantity)
    >>> seq = [mmunit.angstrom, mmunit.nanometer]
    >>> assert isinstance(function(seq), list)
    >>> assert all(isinstance(item, units.Unit) for item in function(seq))
    >>> dct = {"length": 3 * mmunit.angstrom, "time": 2 * mmunit.picosecond}
    >>> assert isinstance(function(dct), dict)
    >>> assert all(isinstance(item, units.Quantity) for item in function(dct).values())
    """
    # Imported locally so that this function does not rely on a module-level
    # import that the rest of the file does not otherwise need.
    import copy

    signature = inspect.signature(func)

    def convert(data: t.Any) -> t.Any:
        if isinstance(data, mmunit.Quantity):
            return Quantity(data)
        if isinstance(data, mmunit.Unit):
            return Unit(data)
        # These sequences cannot hold units or quantities, and ``range`` cannot
        # be rebuilt from an iterable of its items, so return them untouched.
        if isinstance(data, (str, bytes, range)):
            return data
        if isinstance(data, t.Sequence):
            items = [convert(item) for item in data]
            # Named tuples take their fields as positional arguments, so they
            # must be rebuilt through ``_make`` rather than the constructor.
            if isinstance(data, tuple) and hasattr(data, "_make"):
                return type(data)._make(items)
            return type(data)(items)
        if isinstance(data, dict):
            # Work on a shallow copy so that dict subclasses whose constructor
            # does not accept key/value pairs (e.g. ``defaultdict``) keep their
            # type and configuration, and the caller's mapping is not mutated.
            clone = copy.copy(data)
            for key, value in data.items():
                clone[key] = convert(value)
            return clone
        return data

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        # Only existing keys are reassigned, so mutating while iterating is safe.
        for name, data in bound.arguments.items():
            bound.arguments[name] = convert(data)
        return func(*bound.args, **bound.kwargs)

    return wrapper

109 

110 

class Function(Serializable):
    """A function of dynamical variables and global parameters.

    Every identifier found in ``expression`` that is neither one of the names in
    ``LEPTON_FUNCTIONS`` nor a key of ``given_parameters`` is treated as a
    variable. Given parameters that do not appear in the expression are
    discarded.

    Parameters
    ----------
    name
        The name of the function.
    expression
        The expression of the function.
    **given_parameters
        The given parameters of the function.

    Examples
    --------
    >>> from copy import copy
    >>> from openxps.utils import Function
    >>> f1 = Function("f", "a*x^2", a=1.0)
    >>> f1
    Function("f(x)=a*x^2", a=1.0)
    >>> f2 = Function("g", "exp(-x)*cos(2*pi*y)", pi=3.14159)
    >>> f2
    Function("g(x, y)=exp(-x)*cos(2*pi*y)", pi=3.14159)
    >>> copy(f1)
    Function("f(x)=a*x^2", a=1.0)
    >>> Function("h", "x+y")
    Function("h(x, y)=x+y")
    """

    @preprocess_args
    def __init__(
        self, name: str, expression: str, **given_parameters: mmunit.Quantity
    ) -> None:
        self._name = name
        self._expression = expression
        variables, parameters = self._parseDependencies(given_parameters)
        self._variables = variables
        # Keep only the parameters that the expression actually references.
        self._parameters = {key: given_parameters[key] for key in parameters}

    def __repr__(self) -> str:
        variables = ", ".join(sorted(self._variables))
        pieces = [f'"{self._name}({variables})={self._expression}"']
        pieces.extend(f"{k}={v}" for k, v in sorted(self._parameters.items()))
        # Joining the pieces avoids a dangling ", " when there are no parameters.
        return f"Function({', '.join(pieces)})"

    def __copy__(self) -> "Function":
        # Shallow copy: the new object shares the variable set and the
        # parameter dictionary with the original.
        new = Function.__new__(Function)
        new.__setstate__(self.__getstate__())
        return new

    def __getstate__(self) -> dict[str, t.Any]:
        return {
            "name": self._name,
            "expression": self._expression,
            "variables": self._variables,
            "parameters": self._parameters,
        }

    def __setstate__(self, state: dict[str, t.Any]) -> None:
        self._name = state["name"]
        self._expression = state["expression"]
        self._variables = state["variables"]
        self._parameters = state["parameters"]

    def _parseDependencies(
        self,
        given_parameters: dict[str, mmunit.Quantity],
    ) -> tuple[set[str], set[str]]:
        """Split the identifiers of the expression into variables and parameters.

        Parameters
        ----------
        given_parameters
            The parameters passed to the constructor, keyed by name.

        Returns
        -------
        tuple[set[str], set[str]]
            The names of the variables (identifiers that are not given
            parameters) and the names of the given parameters that appear in
            the expression.
        """
        given_names = set(given_parameters)
        dependencies = (
            set(re.findall(r"\b[A-Za-z_][A-Za-z0-9_]*\b", self._expression))
            - LEPTON_FUNCTIONS
        )
        # The two sets partition ``dependencies``, so no identifier can be left
        # unclassified at this point.
        return dependencies - given_names, dependencies & given_names

    def getName(self) -> str:
        """Return the name of the function."""
        return self._name

    def getExpression(self) -> str:
        """Return the expression string of the function."""
        return self._expression

    def getVariables(self) -> set[str]:
        """Return the set of variable names in the function."""
        return self._variables

    def getParameters(self) -> dict[str, mmunit.Quantity]:
        """Return the dictionary of parameter names and their values."""
        return self._parameters

    def createCollectiveVariable(
        self,
        all_variables: list[str],
    ) -> cvpack.AtomicFunction:
        """
        Create a collective variable from the function.

        The variables of the function are taken in sorted order, so the
        generated expression and particle group are the same on every run.

        Parameters
        ----------
        all_variables
            The list of all variables in the system.

        Returns
        -------
        cvpack.AtomicFunction
            The collective variable object.

        Raises
        ------
        ValueError
            If a variable of the function is not present in ``all_variables``.

        Examples
        --------
        >>> from openxps.utils import Function
        >>> import openmm as mm
        >>> from math import exp, cos, pi
        >>> fn = Function("g", "exp(-x)*cos(2*pi*y)", pi=3.14159)
        >>> cv = fn.createCollectiveVariable(["x", "y"])
        >>> system = mm.System()
        >>> for _ in range(2):
        ...     _ = system.addParticle(1.0)
        >>> cv.addToSystem(system)
        >>> context = mm.Context(system, mm.VerletIntegrator(0.0))
        >>> x, y = 2, 1
        >>> exp(-x)*cos(2*pi*y)
        0.13533528...
        >>> context.setPositions([mm.Vec3(x, 0.0, 0.0), mm.Vec3(y, 0.0, 0.0)])
        >>> context.getState(getEnergy=True).getPotentialEnergy()
        0.13533528... kJ/mol
        """
        # A set has no stable iteration order across interpreter runs, so fix
        # the order once and use it for both the definitions and the group.
        variables = sorted(self._variables)
        if missing := [name for name in variables if name not in all_variables]:
            raise ValueError(
                f"Function {self._name} has unknown dependencies: {missing}"
            )
        # Each variable is defined as the x coordinate of its particle in the group.
        definitions = [
            f"{variable}=x{i}" for i, variable in enumerate(variables, start=1)
        ]
        return cvpack.AtomicFunction(
            function=";".join([self._expression, *definitions]),
            unit=mmunit.kilojoule_per_mole,
            groups=[[all_variables.index(variable) for variable in variables]],
            name=self._name,
            **self._parameters,
        )


Function.registerTag("!openxps.utils.Function")