Coverage for openxps/integrators/utils.py: 91%

78 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-13 22:08 +0000

1""" 

2.. module:: openxps.integrators.utils 

3 :platform: Linux, MacOS, Windows 

4 :synopsis: A mixin for integrators that provides extra functionality. 

5 

6.. classauthor:: Charlles Abreu <craabreu@gmail.com> 

7 

8""" 

9 

10import textwrap 

11import typing as t 

12 

13import openmm as mm 

14from cvpack.units import Quantity 

15from openmm import unit as mmunit 

16 

17from openxps.utils import preprocess_args 

18 

# Computation-step type codes (as returned by ``getComputationStep``) that open a
# nested block: 6 is an "if" block and 7 is a "while" block.
_IF_BLOCK_START = 6
_WHILE_BLOCK_START = 7
BLOCK_START = (_IF_BLOCK_START, _WHILE_BLOCK_START)

# Computation-step type code that closes the innermost open block ("end").
BLOCK_END = 8

# Generic type variable used by class decorators that return the class they receive.
T = t.TypeVar("T")

23 

24 

class IntegratorMixin:
    """A mixin for integrators that provides extra functionality.

    The host class is expected to expose the OpenMM custom-integrator
    introspection API used by ``__repr__`` (``getNumPerDofVariables``,
    ``getPerDofVariableName``, ``getNumGlobalVariables``,
    ``getGlobalVariableName``, ``getGlobalVariable``, ``getNumComputations``
    and ``getComputationStep``).
    """

    # Whether the integrator follows a force-first scheme. Reported by
    # isForceFirst(); subclasses may override this class attribute.
    _forceFirst: bool = False

    def __repr__(self) -> str:
        """Return a human-readable version of each integrator step.

        The output lists the per-dof variable names and the global variables
        with their current values, each section only when it is non-empty.
        It then lists the numbered computation steps, indented according to
        ``if``/``while`` block nesting.
        """
        readable_lines = []

        num_per_dof = self.getNumPerDofVariables()
        if num_per_dof > 0:
            readable_lines.append("Per-dof variables:")
            per_dof = [
                self.getPerDofVariableName(index) for index in range(num_per_dof)
            ]
            # Only emitted when there is something to list; previously an
            # empty, whitespace-only line was appended for zero variables.
            readable_lines.append(" " + ", ".join(per_dof))

        num_global = self.getNumGlobalVariables()
        if num_global > 0:
            readable_lines.append("Global variables:")
            for index in range(num_global):
                name = self.getGlobalVariableName(index)
                value = self.getGlobalVariable(index)
                readable_lines.append(f" {name} = {value}")

        readable_lines.append("Computation steps:")

        # Templates indexed by the step type returned by getComputationStep.
        # Codes 6/7 open a block and 8 closes it (see BLOCK_START/BLOCK_END).
        step_type_str = [
            "{target} <- {expr}",
            "{target} <- {expr}",
            "{target} <- sum({expr})",
            "constrain positions",
            "constrain velocities",
            "allow forces to update the context state",
            "if ({expr}):",
            "while ({expr}):",
            "end",
        ]
        indent_level = 0
        for step in range(self.getNumComputations()):
            step_type, target, expr = self.getComputationStep(step)
            command = step_type_str[step_type].format(target=target, expr=expr)
            # "end" belongs to the enclosing level, so dedent before printing it.
            if step_type == BLOCK_END:
                indent_level -= 1
            readable_lines.append(f"{step:4d}: " + " " * indent_level + command)
            # "if"/"while" open a block: indent the steps that follow.
            if step_type in BLOCK_START:
                indent_level += 1
        return "\n".join(readable_lines)

    def registerWithSystem(self, system: mm.System, isExtension: bool) -> None:
        """Register the integrator with the system.

        This implementation does nothing; it is a hook that classes using the
        mixin may override.

        Parameters
        ----------
        system
            The :OpenMM:`System` to register with.
        isExtension
            Flag supplied by the caller. Its meaning is defined by overriding
            implementations (presumably whether ``system`` is an extension
            system; confirm against callers).
        """
        pass

    @staticmethod
    def _countDegreesOfFreedom(system: mm.System) -> int:
        """Count the degrees of freedom in a system.

        Every particle with a positive mass contributes three degrees of
        freedom. Every constraint involving at least one such particle removes
        one. The presence of a :OpenMM:`CMMotionRemover` removes three more.

        Parameters
        ----------
        system
            The :OpenMM:`System` to count the degrees of freedom of.

        Returns
        -------
        int
            The number of degrees of freedom in the system.
        """
        zero_mass = 0 * mmunit.dalton
        # Query each mass once and reuse it for the constraint checks below.
        is_massive = [
            system.getParticleMass(index) > zero_mass
            for index in range(system.getNumParticles())
        ]
        dof = 3 * sum(is_massive)
        for index in range(system.getNumConstraints()):
            p1, p2, _ = system.getConstraintParameters(index)
            if is_massive[p1] or is_massive[p2]:
                dof -= 1
        if any(
            isinstance(system.getForce(index), mm.CMMotionRemover)
            for index in range(system.getNumForces())
        ):
            dof -= 3
        return dof

    def isForceFirst(self) -> bool:
        """Check if the integrator follows a force-first scheme.

        Returns
        -------
        bool
            True if the integrator is force-first, False otherwise.
        """
        return self._forceFirst

116 

117 

def add_property(property: str, unit: mmunit.Unit) -> t.Callable[[type[T]], type[T]]:
    """A decorator to add a property to the integrator.

    The property is stored on the integrator in an attribute named
    `_{snake_case_name}`, which is initialized to ``None`` at class level.
    Three methods are added to the class:

    * `_init_{snake_case_name}(value)`: Initializes the property.
    * `set{PascalCaseName}(value)`: Sets the property and updates the global variables
      if the integrator has a `_update_global_variables` method.
    * `get{PascalCaseName}()`: Gets the property.

    Parameters
    ----------
    property
        The name of the property. Whitespace-separated words are joined to
        build the snake_case and PascalCase names.
    unit
        The unit of the property. Bare numbers passed to the setters are
        interpreted in this unit.

    Returns
    -------
    Callable
        A class decorator that installs the attribute and methods on the class
        and returns the same class.

    Raises
    ------
    ValueError
        If ``unit`` is not expressed in the MD unit system.
    """
    if unit.in_unit_system(mmunit.md_unit_system) != unit:
        raise ValueError(f"The unit of {property} must be in the MD unit system.")
    PascalCaseName = "".join(s.capitalize() for s in property.split())
    snake_case_name = "_".join(s.lower() for s in property.split())
    attribute = f"_{snake_case_name}"

    def set_value(self, value: t.Union[mmunit.Quantity, float]) -> None:
        # A bare number is paired with ``unit``; a quantity is passed to
        # ``Quantity`` as-is.
        args = [value] if mmunit.is_quantity(value) else [value, unit]
        setattr(self, attribute, Quantity(*args))

    @preprocess_args
    def set_and_update(self, value: t.Union[mmunit.Quantity, float]) -> None:
        # ``set_value`` is a plain nested function rather than a bound method,
        # so the instance must be passed explicitly.
        set_value(self, value)
        if hasattr(self, "_update_global_variables"):
            self._update_global_variables()

    set_and_update.__doc__ = textwrap.dedent(
        f"""\
        Set the {property}.

        Parameters
        ----------
        value
            The {property}.
        """
    )

    def get_value(self) -> mmunit.Quantity:
        return getattr(self, attribute)

    get_value.__doc__ = textwrap.dedent(
        f"""\
        Get the {property}.

        Returns
        -------
        openmm.unit.Quantity
            The {property}.
        """
    )

    def decorator(cls: type[T]) -> type[T]:
        setattr(cls, attribute, None)
        setattr(cls, f"_init_{snake_case_name}", set_value)
        setattr(cls, f"set{PascalCaseName}", set_and_update)
        setattr(cls, f"get{PascalCaseName}", get_value)
        return cls

    return decorator