Coverage for openxps/regression.py: 99%
135 statements
1"""
2.. module:: openxps.rbf
3 :platform: Linux, MacOS, Windows
4 :synopsis: Radial basis function regressor for free energy surface fitting.
6.. classauthor:: Charlles Abreu <craabreu@gmail.com>
8"""
import os
import tempfile
import typing as t

import numpy as np
import pandas as pd
import torch
from lightning import pytorch as pl
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from .bounds import PeriodicBounds
from .dynamical_variable import DynamicalVariable

if t.TYPE_CHECKING:
    from numpy.typing import ArrayLike
else:

    class ArrayLike(t.Protocol): ...


INIT_NUM_SIGMAS = 4.0
class RBFPotential(nn.Module):
    """Radial basis function potential.

    Parameters
    ----------
    dynamical_variables
        A sequence of dynamical variables defining the dimensions and bounds
        for the potential. The length of this sequence determines the input
        dimension.
    M
        The number of radial basis functions.
    learn_centers
        Whether to learn the centers of the radial basis functions.
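
    Examples
    --------
    A minimal sketch, assuming ``dvs`` is a sequence of already-constructed
    :class:`DynamicalVariable` objects (their construction is omitted here):

    .. code-block:: python

        potential = RBFPotential(dvs, M=32)
        potential.w.data.normal_()  # random weights, for illustration only
        x = torch.zeros(5, len(dvs), requires_grad=True)
        u = potential(x)  # potential values, shape (5,)
        g = potential.grad(x)  # analytic gradient, shape (5, len(dvs))

        # The analytic gradient agrees with the one obtained via autograd:
        (g_auto,) = torch.autograd.grad(u.sum(), x)
        assert torch.allclose(g, g_auto)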
47 """
    def __init__(
        self,
        dynamical_variables: t.Sequence[DynamicalVariable],
        M: int = 256,
        learn_centers: bool = True,
    ) -> None:
        super().__init__()
        lengths, periodic, random_points = [], [], []
        for dv in dynamical_variables:
            lengths.append(dv.bounds.length)
            periodic.append(isinstance(dv.bounds, PeriodicBounds))
            # Draw initial centers uniformly within the bounds of each variable
            random_points.append(torch.rand(M) * dv.bounds.length + dv.bounds.lower)

        self.c = nn.Parameter(
            torch.stack(random_points, dim=-1), requires_grad=learn_centers
        )
        log_sigmas = np.log(lengths) - np.log(INIT_NUM_SIGMAS)
        self.logsig = nn.Parameter(
            torch.tensor(np.repeat(log_sigmas[None, :], M, axis=0), dtype=torch.float32)
        )
        self.w = nn.Parameter(torch.zeros(M))
        self._length_over_pi = nn.Parameter(
            torch.tensor(np.array(lengths) / np.pi, dtype=torch.float32)[None, None, :],
            requires_grad=False,
        )
        self._periodic = nn.Parameter(
            torch.tensor(periodic, dtype=torch.bool)[None, None, :], requires_grad=False
        )
    def _delta2_fn(self, disp: torch.Tensor) -> torch.Tensor:
        # Squared displacement delta(x)^2, with delta(x) = (L/pi) sin(pi x / L)
        # for periodic variables and delta(x) = x otherwise
        return torch.where(
            self._periodic,
            self._length_over_pi * torch.sin(disp / self._length_over_pi),
            disp,
        ).square()

    def _delta2_grad(self, disp: torch.Tensor) -> torch.Tensor:
        # Derivative of delta(x)^2: (L/pi) sin(2 pi x / L) for periodic
        # variables and 2x otherwise
        return torch.where(
            self._periodic,
            self._length_over_pi * torch.sin(2 * disp / self._length_over_pi),
            2 * disp,
        )
    def _phi(self, disp: torch.Tensor) -> torch.Tensor:
        # Gaussian kernel values Phi[i, m] for each sample i and center m
        delta2 = self._delta2_fn(disp)
        sigma2 = torch.exp(2 * self.logsig)
        return torch.exp(-0.5 * (delta2 / sigma2[None, :, :]).sum(-1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        disp = x[:, None, :] - self.c[None, :, :]
        return self._phi(disp) @ self.w

    def grad(self, x: torch.Tensor) -> torch.Tensor:
        # Analytic gradient of the potential with respect to the input
        disp = x[:, None, :] - self.c[None, :, :]
        Phi = self._phi(disp)
        sigma2 = torch.exp(2 * self.logsig)
        fac = (-0.5 * self.w[:, None] / sigma2)[None, :, :]
        return (fac * Phi[:, :, None] * self._delta2_grad(disp)).sum(1)
class GradMatch(pl.LightningModule):
    """Gradient matching regressor.

    Fits an :class:`RBFPotential` by minimizing the mean squared error between
    its analytic gradient and target gradients.

    Parameters
    ----------
    dynamical_variables
        A sequence of dynamical variables defining the dimensions and bounds
        for the potential. The length of this sequence determines the input
        dimension.

    Keyword Arguments
    -----------------
    M
        The number of radial basis functions.
    lr
        The learning rate.
    wd
        The weight decay.
    """
    def __init__(
        self,
        dynamical_variables: t.Sequence[DynamicalVariable],
        M: int = 256,
        lr: float = 2e-3,
        wd: float = 1e-4,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.f = RBFPotential(dynamical_variables, M)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(
            self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.wd
        )

    def _loss(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        x, G = batch
        return ((self.f.grad(x) - G) ** 2).mean()

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        loss = self._loss(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        loss = self._loss(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
class ForceMatchingRegressor:
    r"""Potential regressor from sampled position/force pairs.

    A potential function in a :math:`d`-dimensional variable space is approximated as:

    .. math::
        U({\bf s}) = \sum_{m=1}^n w_m \exp\left(
            -\frac{1}{2} \sum_{k=1}^d \frac{\delta_k^2(s_k - c_{m,k})}{\sigma_{m,k}^2}
        \right)

    where the :math:`n` weights :math:`w_m`, kernel bandwidths :math:`\sigma_m`, and
    kernel centers :math:`{\bf c}_m` are adjustable parameters. The displacement
    function :math:`\delta_k(x)` depends on the periodicity of the variable
    :math:`s_k` and is defined as:

    .. math::
        \delta_k(x) = \left\{\begin{array}{ll}
            x & \text{if }s_k\text{ is non-periodic} \\
            \frac{L_k}{\pi}\sin\left(\frac{\pi x}{L_k}\right)
                & \text{if }s_k\text{ has period }L_k
        \end{array}\right.
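
    For instance, a variable with period :math:`L_k = 2\pi` has
    :math:`\delta_k(x) = 2\sin(x/2)`, which behaves like :math:`x` for small
    displacements while keeping the squared displacement :math:`\delta_k^2(x)`
    periodic with period :math:`L_k`.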
    Given :math:`N` position-force pairs :math:`({\bf s}_i, {\bf F}_i)`, the
    parameters are adjusted by minimizing the mean squared error loss:

    .. math::
        L = \frac{1}{N} \sum_{i=1}^N \left\| {\bf f}({\bf s}_i) - {\bf F}_i \right\|^2,

    where :math:`{\bf f}({\bf s}) = -\nabla_{\bf s} U({\bf s})` is the predicted
    force.

    Parameters
    ----------
    dynamical_variables
        The dynamical variables to be used in the potential.
    num_kernels
        The number of radial basis functions (kernels) in the potential.

    Keyword Arguments
    -----------------
    validation_fraction
        The fraction of the data to be held out for validation.
    batch_size
        The batch size to be used for training.
    num_epochs
        The maximum number of training epochs.
    patience
        The number of epochs without improvement in the validation loss before
        training stops early.
    learning_rate
        The learning rate to be used for training.
    weight_decay
        The weight decay to be used for training.
    accelerator
        The accelerator to be used for training. Valid options can be found
        `here <https://lightning.ai/docs/pytorch/LTS/common/trainer.html#accelerator>`_.
    num_workers
        The number of data loading workers to be used for training.
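
    Examples
    --------
    A minimal usage sketch, assuming ``dvs`` is a sequence of already-constructed
    :class:`DynamicalVariable` objects and that ``positions`` and ``forces`` are
    arrays of shape ``(N, d)`` sampled from a simulation (all placeholders, not
    part of this module):

    .. code-block:: python

        regressor = ForceMatchingRegressor(dvs, num_kernels=128, num_epochs=20)
        regressor.fit(positions, forces, seed=42)
        potential = regressor.predict(positions)  # shape (N,)
        centers, sigmas, weights = regressor.get_parameters()
        history = regressor.get_learning_curve()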
219 """
    def __init__(  # noqa: PLR0913
        self,
        dynamical_variables: t.Sequence[DynamicalVariable],
        num_kernels: int,
        *,
        validation_fraction: float = 0.1,
        batch_size: int = 256,
        num_epochs: int = 50,
        patience: int = 10,
        learning_rate: float = 2e-3,
        weight_decay: float = 1e-4,
        accelerator: str = "auto",
        num_workers: int = 0,
    ):
        self._dynamical_variables = tuple(
            dv.in_md_units() for dv in dynamical_variables
        )
        self._val_frac = validation_fraction
        self._batch_size = batch_size
        self._num_epochs = num_epochs
        self._patience = patience
        self._lr = learning_rate
        self._wd = weight_decay
        self._accelerator = accelerator
        self._num_workers = num_workers
        self._M = num_kernels
        self._model = None
        self._metrics = None
    @staticmethod
    def _as_numpy(x: torch.Tensor) -> np.ndarray:
        return x.detach().cpu().numpy()

    def _create_dataloaders(
        self, X: torch.Tensor, G: torch.Tensor
    ) -> tuple[DataLoader, DataLoader | None]:
        N = len(X)
        idx = torch.randperm(N)
        nval = int(round(self._val_frac * N))
        val_idx, tr_idx = idx[:nval], idx[nval:]
        train_dl = DataLoader(
            TensorDataset(X[tr_idx], G[tr_idx]),
            batch_size=self._batch_size,
            shuffle=True,
            num_workers=self._num_workers,
            persistent_workers=self._num_workers > 0,
        )
        if len(val_idx) == 0:
            return train_dl, None
        val_dl = DataLoader(
            TensorDataset(X[val_idx], G[val_idx]),
            batch_size=self._batch_size,
            num_workers=self._num_workers,
            persistent_workers=self._num_workers > 0,
        )
        return train_dl, val_dl
    def _create_callbacks(
        self, directory: str
    ) -> tuple[pl.callbacks.ModelCheckpoint, pl.callbacks.EarlyStopping]:
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor="val_loss",
            save_last=False,
            filename="best-{epoch:02d}-{val_loss:.6f}",
            dirpath=directory,
        )
        early_stopping = pl.callbacks.EarlyStopping(
            monitor="val_loss", patience=self._patience
        )
        return checkpoint, early_stopping
    def fit(
        self, positions: ArrayLike, forces: ArrayLike, *, seed: t.Optional[int] = None
    ) -> None:
        """Fit the potential regressor to the given positions and forces.

        Parameters
        ----------
        positions
            An array of shape (N, d) containing the sampled positions.
        forces
            An array of shape (N, d) containing the forces sampled at those
            positions.

        Keyword Arguments
        -----------------
        seed
            The seed to be used for training. If None, a random seed is
            generated.
        """
        np.random.seed(seed)
        torch.manual_seed(np.random.randint(0, 1000000))

        X = torch.as_tensor(positions, dtype=torch.float32)
        # The target gradients are the negatives of the sampled forces
        G = -torch.as_tensor(forces, dtype=torch.float32)
        train_dl, val_dl = self._create_dataloaders(X, G)

        self._model = GradMatch(self._dynamical_variables, self._M, self._lr, self._wd)

        with tempfile.TemporaryDirectory() as tmp_dir:
            logger = pl.loggers.CSVLogger(save_dir=tmp_dir, version=0)

            trainer_kwargs = {
                "max_epochs": self._num_epochs,
                "accelerator": self._accelerator,
                "devices": "auto",
                "logger": logger,
                "enable_progress_bar": True,
                "enable_model_summary": True,
            }

            if val_dl is None:
                trainer_kwargs["enable_checkpointing"] = False
                trainer = pl.Trainer(**trainer_kwargs)
                trainer.fit(self._model, train_dl)
            else:
                checkpoint, early_stopping = self._create_callbacks(tmp_dir)
                trainer_kwargs["callbacks"] = [checkpoint, early_stopping]
                trainer = pl.Trainer(**trainer_kwargs)
                trainer.fit(self._model, train_dl, val_dl)
                if checkpoint.best_model_path:
                    self._model = GradMatch.load_from_checkpoint(
                        checkpoint.best_model_path
                    )

            path = os.path.join(logger.log_dir, "metrics.csv")
            if os.path.exists(path):
                df = pd.read_csv(path)
                self._metrics = df[df.train_loss.notna()].copy()
                if val_dl is not None:
                    self._metrics["val_loss"] = df.val_loss[df.val_loss.notna()].values
    def predict(self, positions: ArrayLike) -> np.ndarray:
        """Predict the potential at the given positions.

        Parameters
        ----------
        positions
            An array of shape (N, d) containing the positions at which to
            predict the potential.

        Returns
        -------
        np.ndarray
            An array of shape (N,) containing the potential at the given
            positions.
        """
        if self._model is None:
            raise RuntimeError("The model has not been fitted yet.")
        device = next(self._model.parameters()).device
        X = torch.as_tensor(positions, dtype=torch.float32).to(device)
        with torch.no_grad():
            return self._as_numpy(self._model.f(X))
    def get_parameters(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Get the parameters of the potential regressor.

        Returns
        -------
        centers
            The centers of the radial basis functions, shape (M, d).
        sigmas
            The bandwidths of the radial basis functions per dimension,
            shape (M, d).
        weights
            The weights of the radial basis functions, shape (M,).
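
        Examples
        --------
        For non-periodic variables, the fitted potential can be reconstructed
        from these parameters with plain NumPy (a sketch; ``regressor`` is a
        fitted instance and ``points`` a placeholder array of shape (N, d)):

        .. code-block:: python

            centers, sigmas, weights = regressor.get_parameters()
            disp = points[:, None, :] - centers[None, :, :]  # shape (N, M, d)
            phi = np.exp(-0.5 * (disp**2 / sigmas[None, :, :] ** 2).sum(-1))
            potential = phi @ weights  # equivalent to regressor.predict(points)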
384 """
        centers = self._as_numpy(self._model.f.c)
        # Wrap the learned centers back into the bounds of each variable
        for dv, column in zip(self._dynamical_variables, centers.T):
            column[:], _ = np.vectorize(dv.bounds.wrap)(column, 0)
        return (
            centers,
            self._as_numpy(self._model.f.logsig.exp()),
            self._as_numpy(self._model.f.w),
        )
    def get_learning_curve(self) -> pd.DataFrame:
        """Get the learning curve of the potential regressor.

        Returns
        -------
        pd.DataFrame
            A data frame with one row per epoch containing the training loss
            and, when a validation split was used, the validation loss.
        """
        return self._metrics