Coverage for openxps/regression.py: 99%

135 statements  


1""" 

2.. module:: openxps.rbf 

3 :platform: Linux, MacOS, Windows 

4 :synopsis: Radial basis function regressor for free energy surface fitting. 

5 

6.. classauthor:: Charlles Abreu <craabreu@gmail.com> 

7 

8""" 

9 

10import os 

11import tempfile 

12import typing as t 

13 

14import numpy as np 

15import pandas as pd 

16import torch 

17from lightning import pytorch as pl 

18from torch import nn 

19from torch.utils.data import DataLoader, TensorDataset 

20 

21from .bounds import PeriodicBounds 

22from .dynamical_variable import DynamicalVariable 

23 

24if t.TYPE_CHECKING: 

25 from numpy.typing import ArrayLike 

26else: 

27 

28 class ArrayLike(t.Protocol): ... 



# Kernel bandwidths are initialized so that INIT_NUM_SIGMAS standard deviations
# span the length of each dynamical variable (see RBFPotential.__init__).
INIT_NUM_SIGMAS = 4.0



class RBFPotential(nn.Module):
    """Radial basis function potential.

    Parameters
    ----------
    dynamical_variables
        A sequence of dynamical variables defining the dimensions and bounds
        for the potential. The length of this sequence determines the input
        dimension.
    M
        The number of radial basis functions.
    learn_centers
        Whether to learn the centers of the radial basis functions.
    """

    def __init__(
        self,
        dynamical_variables: t.Sequence[DynamicalVariable],
        M: int = 256,
        learn_centers: bool = True,
    ) -> None:
        super().__init__()
        lengths, periodic, random_points = [], [], []
        for dv in dynamical_variables:
            lengths.append(dv.bounds.length)
            periodic.append(isinstance(dv.bounds, PeriodicBounds))
            # Draw initial centers uniformly within each variable's bounds
            # (torch.rand, not torch.randn, which would scatter most centers
            # outside the bounds of non-periodic variables).
            random_points.append(torch.rand(M) * dv.bounds.length + dv.bounds.lower)

        self.c = nn.Parameter(
            torch.stack(random_points, dim=-1), requires_grad=learn_centers
        )
        log_sigmas = np.log(lengths) - np.log(INIT_NUM_SIGMAS)
        self.logsig = nn.Parameter(
            torch.tensor(np.repeat(log_sigmas[None, :], M, axis=0), dtype=torch.float32)
        )
        self.w = nn.Parameter(torch.zeros(M))
        self._length_over_pi = nn.Parameter(
            torch.tensor(np.array(lengths) / np.pi, dtype=torch.float32)[None, None, :],
            requires_grad=False,
        )
        self._periodic = nn.Parameter(
            torch.tensor(periodic, dtype=torch.bool)[None, None, :], requires_grad=False
        )

    def _delta2_fn(self, disp: torch.Tensor) -> torch.Tensor:
        return torch.where(
            self._periodic,
            self._length_over_pi * torch.sin(disp / self._length_over_pi),
            disp,
        ).square()

    def _delta2_grad(self, disp: torch.Tensor) -> torch.Tensor:
        return torch.where(
            self._periodic,
            self._length_over_pi * torch.sin(2 * disp / self._length_over_pi),
            2 * disp,
        )
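
    # A consistency note (added): each branch of _delta2_grad is the exact
    # derivative of the matching _delta2_fn branch. For the periodic case,
    # with delta(x) = (L/pi) sin(pi x / L):
    #     d/dx [delta(x)]^2 = 2 (L/pi) sin(pi x / L) cos(pi x / L)
    #                       = (L/pi) sin(2 pi x / L),
    # which is the expression used above; the non-periodic case is d/dx x^2 = 2x.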


    def _phi(self, disp: torch.Tensor) -> torch.Tensor:
        # disp has shape (batch, M, d); the returned kernel matrix is (batch, M).
        delta2 = self._delta2_fn(disp)
        sigma2 = torch.exp(2 * self.logsig)
        return torch.exp(-0.5 * (delta2 / sigma2[None, :, :]).sum(-1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        disp = x[:, None, :] - self.c[None, :, :]
        return self._phi(disp) @ self.w

    def grad(self, x: torch.Tensor) -> torch.Tensor:
        # Analytical gradient of the potential with respect to x, shape (batch, d).
        disp = x[:, None, :] - self.c[None, :, :]
        Phi = self._phi(disp)
        sigma2 = torch.exp(2 * self.logsig)
        fac = (-0.5 * self.w[:, None] / sigma2)[None, :, :]
        return (fac * Phi[:, :, None] * self._delta2_grad(disp)).sum(1)
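
# A minimal usage sketch for RBFPotential (hypothetical: assumes `dvs` is a
# sequence of DynamicalVariable objects in MD units):
#
#     potential = RBFPotential(dvs, M=64)
#     x = torch.zeros(8, len(dvs))      # batch of 8 points
#     u = potential(x)                  # potential values, shape (8,)
#     g = potential.grad(x)             # analytical gradients, shape (8, len(dvs))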



class GradMatch(pl.LightningModule):
    """Gradient matching regressor.

    Parameters
    ----------
    dynamical_variables
        A sequence of dynamical variables defining the dimensions and bounds
        for the potential. The length of this sequence determines the input
        dimension.

    Keyword Arguments
    -----------------
    M
        The number of radial basis functions.
    lr
        The learning rate.
    wd
        The weight decay.
    """

    def __init__(
        self,
        dynamical_variables: t.Sequence[DynamicalVariable],
        M: int = 256,
        lr: float = 2e-3,
        wd: float = 1e-4,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.f = RBFPotential(dynamical_variables, M)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(
            self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.wd
        )

    def _loss(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        x, G = batch
        return ((self.f.grad(x) - G) ** 2).mean()

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        loss = self._loss(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        loss = self._loss(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
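
# GradMatch is a plain LightningModule, so it can also be trained directly with
# a Trainer (a hedged sketch; `train_dl` is assumed to yield (positions, -forces)
# batches, matching the gradient-matching convention used by _loss above):
#
#     model = GradMatch(dvs, M=128, lr=1e-3, wd=1e-4)
#     trainer = pl.Trainer(max_epochs=20, accelerator="auto")
#     trainer.fit(model, train_dl)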



class ForceMatchingRegressor:
    r"""Potential regressor from sampled position/force pairs.

    A potential function in a :math:`d`-dimensional variable space is approximated as:

    .. math::
        U({\bf s}) = \sum_{m=1}^n w_m \exp\left(
            -\frac{1}{2} \sum_{k=1}^d \frac{\delta_k^2(s_k - c_{m,k})}{\sigma_{m,k}^2}
        \right)

    where the :math:`n` weights :math:`w_m`, kernel bandwidths :math:`\sigma_m`, and
    kernel centers :math:`{\bf c}_m` are adjustable parameters. The displacement
    function :math:`\delta_k(x)` depends on the periodicity of the variable :math:`s_k`
    and is defined as:

    .. math::
        \delta_k(x) = \left\{\begin{array}{ll}
            x & \text{if }s_k\text{ is non-periodic} \\
            \frac{L_k}{\pi}\sin(\frac{\pi x}{L_k}) & \text{if }s_k\text{ has period }L_k
        \end{array}\right.

    Given :math:`N` position-force pairs :math:`({\bf s}_i, {\bf F}_i)`, the parameters
    are adjusted by minimizing the mean squared error loss:

    .. math::
        L = \frac{1}{N} \sum_{i=1}^N \left\| {\bf f}({\bf s}_i) - {\bf F}_i \right\|^2,

    where :math:`{\bf f}({\bf s}) = -\nabla_{\bf s} U({\bf s})` is the predicted force.

    Parameters
    ----------
    dynamical_variables
        The dynamical variables to be used in the potential.
    num_kernels
        The number of kernels to be used in the potential.

    Keyword Arguments
    -----------------
    validation_fraction
        The fraction of the data to be used for validation.
    batch_size
        The batch size to be used for training.
    num_epochs
        The number of epochs to be used for training.
    patience
        The patience for early stopping.
    learning_rate
        The learning rate to be used for training.
    weight_decay
        The weight decay to be used for training.
    accelerator
        The accelerator to be used for training. Valid options can be found
        `here <https://lightning.ai/docs/pytorch/LTS/common/trainer.html#accelerator>`_.
    num_workers
        The number of data loading workers to be used for training.
    """


    def __init__(  # noqa: PLR0913
        self,
        dynamical_variables: t.Sequence[DynamicalVariable],
        num_kernels: int,
        *,
        validation_fraction: float = 0.1,
        batch_size: int = 256,
        num_epochs: int = 50,
        patience: int = 10,
        learning_rate: float = 2e-3,
        weight_decay: float = 1e-4,
        accelerator: str = "auto",
        num_workers: int = 0,
    ):
        self._dynamical_variables = tuple(
            dv.in_md_units() for dv in dynamical_variables
        )
        self._val_frac = validation_fraction
        self._batch_size = batch_size
        self._num_epochs = num_epochs
        self._patience = patience
        self._lr = learning_rate
        self._wd = weight_decay
        self._accelerator = accelerator
        self._num_workers = num_workers
        self._M = num_kernels
        self._model = None
        self._metrics = None


    @staticmethod
    def _as_numpy(x: torch.Tensor) -> np.ndarray:
        return x.detach().cpu().numpy()

    def _create_dataloaders(
        self, X: torch.Tensor, G: torch.Tensor
    ) -> tuple[DataLoader, DataLoader | None]:
        # Shuffle the data and hold out a validation fraction of the samples.
        N = len(X)
        idx = torch.randperm(N)
        nval = int(round(self._val_frac * N))
        val_idx, tr_idx = idx[:nval], idx[nval:]
        train_dl = DataLoader(
            TensorDataset(X[tr_idx], G[tr_idx]),
            batch_size=self._batch_size,
            shuffle=True,
            num_workers=self._num_workers,
            persistent_workers=self._num_workers > 0,
        )
        if len(val_idx) == 0:
            return train_dl, None
        val_dl = DataLoader(
            TensorDataset(X[val_idx], G[val_idx]),
            batch_size=self._batch_size,
            num_workers=self._num_workers,
            persistent_workers=self._num_workers > 0,
        )
        return train_dl, val_dl


    def _create_callbacks(
        self, directory: str
    ) -> tuple[pl.callbacks.ModelCheckpoint, pl.callbacks.EarlyStopping]:
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor="val_loss",
            save_last=False,
            filename="best-{epoch:02d}-{val_loss:.6f}",
            dirpath=directory,
        )
        early_stopping = pl.callbacks.EarlyStopping(
            monitor="val_loss", patience=self._patience
        )
        return checkpoint, early_stopping


    def fit(
        self, positions: ArrayLike, forces: ArrayLike, *, seed: t.Optional[int] = None
    ) -> None:
        """Fit the potential regressor to the given positions and forces.

        Parameters
        ----------
        positions
            The positions to be used for training.
        forces
            The forces to be used for training.

        Keyword Arguments
        -----------------
        seed
            The seed to be used for training. If None, a random seed is generated.
        """
        np.random.seed(seed)
        torch.manual_seed(np.random.randint(0, 1000000))

        X = torch.as_tensor(positions, dtype=torch.float32)
        # The regression target is the potential gradient, i.e. the negated force.
        G = -torch.as_tensor(forces, dtype=torch.float32)
        train_dl, val_dl = self._create_dataloaders(X, G)

        self._model = GradMatch(self._dynamical_variables, self._M, self._lr, self._wd)

        with tempfile.TemporaryDirectory() as tmp_dir:
            logger = pl.loggers.CSVLogger(save_dir=tmp_dir, version=0)

            trainer_kwargs = {
                "max_epochs": self._num_epochs,
                "accelerator": self._accelerator,
                "devices": "auto",
                "logger": logger,
                "enable_progress_bar": True,
                "enable_model_summary": True,
            }

            if val_dl is None:
                # Without a validation set, checkpointing and early stopping have
                # nothing to monitor, so train for the full epoch budget.
                trainer_kwargs["enable_checkpointing"] = False
                trainer = pl.Trainer(**trainer_kwargs)
                trainer.fit(self._model, train_dl)
            else:
                checkpoint, early_stopping = self._create_callbacks(tmp_dir)
                trainer_kwargs["callbacks"] = [checkpoint, early_stopping]
                trainer = pl.Trainer(**trainer_kwargs)
                trainer.fit(self._model, train_dl, val_dl)
                if checkpoint.best_model_path:
                    # Restore the weights with the lowest validation loss.
                    self._model = GradMatch.load_from_checkpoint(
                        checkpoint.best_model_path
                    )

            path = os.path.join(logger.log_dir, "metrics.csv")
            if os.path.exists(path):
                df = pd.read_csv(path)
                self._metrics = df[df.train_loss.notna()].copy()
                if val_dl is not None:
                    self._metrics["val_loss"] = df.val_loss[df.val_loss.notna()].values


    def predict(self, positions: ArrayLike) -> np.ndarray:
        """Predict the potential at the given positions.

        Parameters
        ----------
        positions
            An array of shape (N, d) containing the positions at which to predict
            the potential.

        Returns
        -------
        np.ndarray
            An array of shape (N,) containing the potential at the given positions.
        """
        if self._model is None:
            raise RuntimeError("The model has not been fitted yet.")
        device = next(self._model.parameters()).device
        X = torch.as_tensor(positions, dtype=torch.float32).to(device)
        with torch.no_grad():
            return self._as_numpy(self._model.f(X))


    def get_parameters(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Get the parameters of the potential regressor.

        Returns
        -------
        centers
            The centers of the radial basis functions, shape (M, d).
        sigmas
            The bandwidths of the radial basis functions per dimension,
            shape (M, d).
        weights
            The weights of the radial basis functions, shape (M,).
        """
        if self._model is None:
            raise RuntimeError("The model has not been fitted yet.")
        centers = self._as_numpy(self._model.f.c)
        # Wrap the learned centers back into the bounds of each periodic variable.
        for dv, column in zip(self._dynamical_variables, centers.T):
            column[:], _ = np.vectorize(dv.bounds.wrap)(column, 0)
        return (
            centers,
            self._as_numpy(self._model.f.logsig.exp()),
            self._as_numpy(self._model.f.w),
        )
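
    # The returned arrays are enough to re-evaluate U(s) outside torch. A hedged
    # numpy sketch for purely non-periodic variables (`x` has shape (n, d)):
    #
    #     diff = x[:, None, :] - centers[None, :, :]            # (n, M, d)
    #     expo = -0.5 * ((diff / sigmas[None, :, :]) ** 2).sum(axis=-1)
    #     u = np.exp(expo) @ weights                            # (n,)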


    def get_learning_curve(self) -> pd.DataFrame:
        """Get the learning curve of the potential regressor.

        Returns
        -------
        pd.DataFrame
            A data frame with per-epoch training (and, if available, validation)
            losses, or None if the regressor has not been fitted yet.
        """
        return self._metrics
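

# A minimal end-to-end sketch (hypothetical data: `dvs` is a sequence of
# DynamicalVariable objects and `positions`/`forces` are arrays of shape (N, d)):
#
#     regressor = ForceMatchingRegressor(dvs, num_kernels=128, num_epochs=20)
#     regressor.fit(positions, forces, seed=42)
#     fes = regressor.predict(grid)                 # potential at (n, d) grid points
#     centers, sigmas, weights = regressor.get_parameters()
#     curve = regressor.get_learning_curve()        # per-epoch train/val losses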