Coverage for openxps/couplings/base.py: 97%

141 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 22:08 +0000

1""" 

2Base class for couplings. 

3 

4.. module:: openxps.couplings.base 

5 :platform: Linux, MacOS, Windows 

6 :synopsis: Base class for couplings between physical and extended phase-space systems 

7 

8.. classauthor:: Charlles Abreu <craabreu@gmail.com> 

9 

10""" 

11 

12import typing as t 

13 

14import cvpack 

15import openmm as mm 

16from cvpack.serialization import Serializable 

17from openmm import XmlSerializer 

18from openmm import _openmm as mmswig 

19from openmm import unit as mmunit 

20 

21from ..dynamical_variable import DynamicalVariable 

22 

23 

class Coupling(Serializable):
    """Abstract base class for couplings between physical and extension systems.

    A coupling connects the physical system's coordinates to the extension system's
    dynamical variables, enabling enhanced sampling simulations.

    The base class adds its forces to the physical system
    (:meth:`addToPhysicalSystem`) and a "flipped" force to the extension system
    (:meth:`addToExtensionSystem`). Subclasses must implement
    :meth:`_createFlippedForce`, which builds that flipped force as an
    :OpenMM:`CustomCVForce`, and :meth:`updateExtensionContext`.

    Parameters
    ----------
    forces
        A sequence of :OpenMM:`Force` objects.
    dynamical_variables
        A sequence of :class:`DynamicalVariable` objects.

    Raises
    ------
    ValueError
        If two dynamical variables share the same name, or if the forces define
        the same global parameter with different default values.
    """

    def __init__(
        self,
        forces: t.Iterable[mm.Force],
        dynamical_variables: t.Sequence[DynamicalVariable],
    ) -> None:
        self._forces = list(forces)
        # Every dynamical variable is normalized to MD units.
        self._dynamical_variables = [dv.in_md_units() for dv in dynamical_variables]
        unique_names = {dv.name for dv in self._dynamical_variables}
        if len(unique_names) != len(self._dynamical_variables):
            raise ValueError("The dynamical variables must have unique names.")
        # Position of each dynamical variable, keyed by name. This map may later be
        # rewritten in place by _updateDynamicalVariableIndices (e.g. by a
        # CouplingSum that merges the variables of several couplings).
        self._dv_indices = {
            dv.name: index for index, dv in enumerate(self._dynamical_variables)
        }
        # Set by addToExtensionSystem; None until then.
        self._flipped_force = None
        self._checkGlobalParameters()

    def __add__(self, other: "Coupling") -> "CouplingSum":
        return CouplingSum([self, other])

    def __copy__(self) -> "Coupling":
        new = self.__class__.__new__(self.__class__)
        new.__setstate__(self.__getstate__())
        return new

    def __getstate__(self) -> dict[str, t.Any]:
        return {
            "forces": self._forces,
            "dynamical_variables": self._dynamical_variables,
            "dv_indices": self._dv_indices,
        }

    def __setstate__(self, state: dict[str, t.Any]) -> None:
        # Copy the containers so that an object restored from another object's
        # state (as done by __copy__) does not share mutable lists/dicts with it.
        # In particular, _dv_indices is updated in place by
        # _updateDynamicalVariableIndices and must not leak between copies.
        self._forces = list(state["forces"])
        self._dynamical_variables = list(state["dynamical_variables"])
        self._dv_indices = dict(state["dv_indices"])
        self._flipped_force = None

    def _checkGlobalParameters(self) -> dict[str, float | mmunit.Quantity]:
        """Collect the global parameters of all forces and check their consistency.

        Returns
        -------
        dict
            The default value of each global parameter, keyed by parameter name,
            as reported by the forces.

        Raises
        ------
        ValueError
            If the same global parameter has different default values in
            different forces.
        """
        parameters = {}
        for force in self._forces:
            force_parameters = {}
            for index in range(force.getNumGlobalParameters()):
                name = force.getGlobalParameterName(index)
                force_parameters[name] = force.getGlobalParameterDefaultValue(index)
            for key, value in force_parameters.items():
                if key in parameters and parameters[key] != value:
                    raise ValueError(
                        f"Parameter {key} has conflicting default values in "
                        f"coupling: {parameters[key]} != {value}"
                    )
                parameters[key] = value
        return parameters

    def _createFlippedForce(self) -> mm.Force | None:
        """Create the force to be added to the extension system.

        The base implementation returns ``None``; subclasses are expected to
        override it. The returned force is later used as an
        :OpenMM:`CustomCVForce` by :meth:`getProtectedParameters` and
        :meth:`updatePhysicalContext`.
        """
        return None

    def _requireFlippedForce(self) -> mm.CustomCVForce:
        """Return the flipped force, or raise if it has not been created yet.

        Raises
        ------
        ValueError
            If this coupling has not been added to an extension system.
        """
        if self._flipped_force is None:
            raise ValueError(
                "This coupling has not been added to an extension system."
            )
        return self._flipped_force

    @staticmethod
    def _addForceToSystem(force: mm.Force, system: mm.System) -> None:
        """Add a force to a system, using cvpack's own method for its CVs."""
        if isinstance(force, cvpack.CollectiveVariable):
            force.addToSystem(system)
        else:
            system.addForce(force)

    def _updateDynamicalVariableIndices(
        self, dynamical_variables: t.Sequence[DynamicalVariable]
    ) -> None:
        """Update the indices of the dynamical variables associated with this coupling.

        Parameters
        ----------
        dynamical_variables
            All the dynamical variables in the system, regardless of whether they are
            associated with this coupling or not.
        """
        for index, dv in enumerate(dynamical_variables):
            if dv.name in self._dv_indices:
                self._dv_indices[dv.name] = index

    def getForces(self) -> list[mm.Force]:
        """Get the list of OpenMM Force objects associated with this coupling.

        Returns
        -------
        list[openmm.Force]
            A list of Force objects contained within this coupling.
        """
        return self._forces

    def getDynamicalVariables(self) -> t.Sequence[DynamicalVariable]:
        """Get the dynamical variables associated with this coupling.

        Returns
        -------
        list[DynamicalVariable]
            A list of DynamicalVariable objects contained within this coupling.
        """
        return self._dynamical_variables

    def getForce(self, index: int) -> mm.Force:
        """Retrieve a single OpenMM Force object from this coupling.

        Parameters
        ----------
        index
            The index of the Force object to retrieve.

        Returns
        -------
        openmm.Force
            The Force object at the specified index.

        Raises
        ------
        IndexError
            If the index is out of range.
        """
        return self._forces[index]

    def getDynamicalVariable(self, index: int) -> DynamicalVariable:
        """Retrieve a single dynamical variable from this coupling.

        Parameters
        ----------
        index
            The index of the DynamicalVariable object to retrieve.

        Returns
        -------
        DynamicalVariable
            The DynamicalVariable object at the specified index.

        Raises
        ------
        IndexError
            If the index is out of range.
        """
        return self._dynamical_variables[index]

    def getProtectedParameters(self) -> set[str]:
        """Get parameters of the physical context that should not be manually modified.

        These are the names of the collective variables of the flipped force,
        which :meth:`updatePhysicalContext` writes into the physical context.

        Returns
        -------
        set[str]
            The protected parameters.

        Raises
        ------
        ValueError
            If this coupling has not been added to an extension system.
        """
        flipped_force = self._requireFlippedForce()
        return {
            flipped_force.getCollectiveVariableName(index)
            for index in range(flipped_force.getNumCollectiveVariables())
        }

    def addToPhysicalSystem(self, system: mm.System) -> None:
        """Add this coupling to an OpenMM system.

        Every force of this coupling is added to the system.

        Parameters
        ----------
        system
            The system to which the coupling should be added.
        """
        for force in self._forces:
            self._addForceToSystem(force, system)

    def addToExtensionSystem(self, system: mm.System) -> None:
        """
        Add the flipped version of this coupling to the extension system.

        The flipped force replaces dynamical variable parameters with collective
        variables that represent the dynamical variables as particles. Physical
        collective variables become parameters set to zero.

        Parameters
        ----------
        system
            The extension system to which the flipped coupling should be added.

        Examples
        --------
        >>> import cvpack
        >>> import openxps as xps
        >>> from openmm import unit
        >>> from math import pi
        >>> import openmm as mm
        >>> phi = cvpack.Torsion(6, 8, 14, 16, name="phi")
        >>> phi0 = xps.DynamicalVariable(
        ...     "phi0",
        ...     unit.radian,
        ...     3 * unit.dalton * (unit.nanometer / unit.radian)**2,
        ...     xps.CircularBounds()
        ... )
        >>> coupling = xps.CollectiveVariableCoupling(
        ...     "0.5*kappa*min(delta,{2*pi}-delta)^2; delta=abs(phi-phi0)",
        ...     [phi],
        ...     [phi0],
        ...     kappa=1000 * unit.kilojoules_per_mole / unit.radian**2,
        ... )
        >>> extension_system = mm.System()
        >>> extension_system.addParticle(phi0.mass / phi0.mass.unit)
        0
        >>> coupling.addToExtensionSystem(extension_system)
        >>> extension_system.getNumForces()
        1
        """
        # A fresh flipped force is created on every call and kept for later use
        # by getProtectedParameters and updatePhysicalContext.
        self._flipped_force = self._createFlippedForce()
        self._addForceToSystem(self._flipped_force, system)

    def updatePhysicalContext(
        self,
        physical_context: mm.Context,
        extension_context: mm.Context,
    ) -> None:
        """Update the physical context with the current extension parameters.

        Each collective variable of the flipped force is evaluated in the
        extension context and its value is set as the parameter with the same
        name in the physical context.

        Parameters
        ----------
        physical_context
            The physical context to update with the extension parameters.
        extension_context
            The extension context to get the extension parameters from.

        Raises
        ------
        ValueError
            If this coupling has not been added to an extension system.
        """
        flipped_force = self._requireFlippedForce()
        # The low-level SWIG functions are called directly, presumably to avoid
        # the overhead of the Python wrappers on this frequently called path.
        collective_variables = mmswig.CustomCVForce_getCollectiveVariableValues(
            flipped_force, extension_context
        )
        for index, value in enumerate(collective_variables):
            mmswig.Context_setParameter(
                physical_context,
                flipped_force.getCollectiveVariableName(index),
                value,
            )

    def updateExtensionContext(
        self,
        physical_context: mm.Context,
        extension_context: mm.Context,
    ) -> None:
        """Update the extension context with the current physical parameters.

        Parameters
        ----------
        physical_context
            The physical context to get the physical parameters from.
        extension_context
            The extension context to update with the physical parameters.

        Raises
        ------
        NotImplementedError
            Always; subclasses must implement this method.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def getCollectiveVariableValues(
        self, physical_context: mm.Context
    ) -> dict[str, float]:
        """Get the values of the collective variables.

        Only forces that are :OpenMM:`CustomCVForce` instances contribute; the
        values of their collective variables are evaluated in the physical
        context.

        Parameters
        ----------
        physical_context
            The physical context to get the collective variable values from.

        Returns
        -------
        dict[str, float]
            The value of each collective variable, keyed by its name.
        """
        collective_variables = {}
        for force in self._forces:
            if isinstance(force, mm.CustomCVForce):
                cv_values = force.getCollectiveVariableValues(physical_context)
                for index, value in enumerate(cv_values):
                    collective_variables[
                        mmswig.CustomCVForce_getCollectiveVariableName(force, index)
                    ] = value
        return collective_variables

314 

315 

class CouplingSum(Coupling):
    """A sum of couplings.

    Nested sums are flattened, so the sum always holds the individual couplings.
    The dynamical variables of all couplings are merged and sorted by name, and
    the dynamical variable indices of each coupling are updated to refer to the
    merged list.

    Parameters
    ----------
    couplings
        The couplings to be added.

    Raises
    ------
    ValueError
        If a dynamical variable or a collective variable has conflicting
        definitions in different couplings.

    Examples
    --------
    >>> from copy import copy
    >>> import cvpack
    >>> import openxps as xps
    >>> from openmm import unit
    >>> phi = cvpack.Torsion(6, 8, 14, 16, name="phi")
    >>> psi = cvpack.Torsion(4, 6, 8, 14, name="psi")
    >>> dv_mass = 3 * unit.dalton * (unit.nanometer / unit.radian)**2
    >>> phi_s = xps.DynamicalVariable(
    ...     "phi_s", unit.radian, dv_mass, xps.CircularBounds()
    ... )
    >>> psi_s = xps.DynamicalVariable(
    ...     "psi_s", unit.radian, dv_mass, xps.CircularBounds()
    ... )
    >>> coupling = xps.HarmonicCoupling(
    ...     phi, phi_s, 1000 * unit.kilojoule_per_mole / unit.radian**2
    ... ) + xps.HarmonicCoupling(
    ...     psi, psi_s, 500 * unit.kilojoule_per_mole / unit.radian**2
    ... )
    """

    def __init__(self, couplings: t.Iterable[Coupling]) -> None:
        self._couplings = []
        forces = []
        dv_dict = {}
        for coupling in couplings:
            # Flatten nested sums so that self._couplings holds no CouplingSum.
            if isinstance(coupling, CouplingSum):
                self._couplings.extend(coupling.getCouplings())
            else:
                self._couplings.append(coupling)
            forces.extend(coupling.getForces())
            # A dynamical variable may be shared by several couplings, but only
            # if all of them define it identically.
            for dv in coupling.getDynamicalVariables():
                if dv.name not in dv_dict:
                    dv_dict[dv.name] = dv
                elif dv_dict[dv.name] != dv:
                    raise ValueError(
                        f'The dynamical variable "{dv.name}" has '
                        "conflicting definitions in the couplings."
                    )
        super().__init__(forces, sorted(dv_dict.values(), key=lambda dv: dv.name))
        self._broadcastDynamicalVariableIndices()
        self._checkCollectiveVariables()

    def __repr__(self) -> str:
        return "+".join(f"({repr(coupling)})" for coupling in self._couplings)

    def __copy__(self) -> "CouplingSum":
        # The individual couplings are copied too. Each coupling stores the
        # flipped force created by its addToExtensionSystem, so sharing them
        # would let the copy overwrite the state of the original sum.
        from copy import copy as shallow_copy

        cls = type(self)
        new = cls.__new__(cls)
        new.__setstate__(
            {"couplings": [shallow_copy(coupling) for coupling in self._couplings]}
        )
        return new

    def __getstate__(self) -> dict[str, t.Any]:
        return {"couplings": self._couplings}

    def __setstate__(self, state: dict[str, t.Any]) -> None:
        self.__init__(state["couplings"])

    def _broadcastDynamicalVariableIndices(self) -> None:
        """Make each coupling's indices refer to the merged dynamical variables."""
        for coupling in self._couplings:
            coupling._updateDynamicalVariableIndices(self._dynamical_variables)

    def _checkCollectiveVariables(self) -> None:
        """Check that same-named collective variables have identical definitions.

        Collective variables of :OpenMM:`CustomCVForce` forces are compared
        through their XML serialization.

        Raises
        ------
        ValueError
            If two collective variables share a name but differ in definition.
        """
        cvs = {}
        for coupling in self._couplings:
            for force in coupling.getForces():
                if not isinstance(force, mm.CustomCVForce):
                    continue
                for index in range(force.getNumCollectiveVariables()):
                    name = force.getCollectiveVariableName(index)
                    xml_string = XmlSerializer.serialize(
                        force.getCollectiveVariable(index)
                    )
                    if name in cvs and cvs[name] != xml_string:
                        raise ValueError(
                            f'The collective variable "{name}" has conflicting '
                            "definitions in the couplings."
                        )
                    cvs[name] = xml_string

    def getCouplings(self) -> t.Sequence[Coupling]:
        """Get the couplings included in the summed coupling."""
        return self._couplings

    def getProtectedParameters(self) -> set[str]:
        """Get the union of the protected parameters of all couplings.

        Returns
        -------
        set[str]
            The protected parameters (empty if the sum has no couplings).
        """
        # set().union() also works with no arguments, whereas set.union() without
        # a set argument raises TypeError for a sum with no couplings.
        return set().union(
            *(coupling.getProtectedParameters() for coupling in self._couplings)
        )

    def addToExtensionSystem(self, system: mm.System) -> None:
        """Add the flipped version of every coupling to the extension system."""
        for coupling in self._couplings:
            coupling.addToExtensionSystem(system)

    def updatePhysicalContext(
        self,
        physical_context: mm.Context,
        extension_context: mm.Context,
    ) -> None:
        """Update the physical context by delegating to every coupling."""
        for coupling in self._couplings:
            coupling.updatePhysicalContext(physical_context, extension_context)

    def updateExtensionContext(
        self,
        physical_context: mm.Context,
        extension_context: mm.Context,
    ) -> None:
        """Update the extension context by delegating to every coupling."""
        for coupling in self._couplings:
            coupling.updateExtensionContext(physical_context, extension_context)

431 

432 

# Register the serialization tag for CouplingSum objects. registerTag is not
# defined in this module, so it presumably comes from cvpack's Serializable base.
CouplingSum.registerTag("!openxps.CouplingSum")