Coverage for cosmolayer / cosmolightning.py: 79%
156 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
1"""
2.. module:: cosmolayer.cosmolightning
3 :synopsis: PyTorch Lightning module for batched CosmoLayer training.
4"""
6from __future__ import annotations
8from collections.abc import Sequence
10import numpy as np
11import torch
12from lightning import pytorch as pl
13from numpy.typing import NDArray
14from torch import distributed as td
15from torch.nn import functional as F
16from torchmetrics import MeanAbsoluteError, MeanSquaredError, R2Score
18from .cosmodata import InputsType
19from .cosmolayer import CosmoLayer
20from .utils import is_loss_function
22EPSILON = 1e-8
25class LogGammaLightningModule(pl.LightningModule):
26 """PyTorch Lightning module for batched training of a learnable
27 :class:`~cosmolayer.CosmoLayer`.
29 This class is the canonical high-level training interface for CosmoLayer.
30 It constructs an internal :class:`~cosmolayer.CosmoLayer` with learnable
31 interaction matrices and defines the optimization, training, validation,
32 test, and prediction logic.
34 The targets are the log-activity coefficients of the components. In order to
35 handle other tasks, the user must subclass :class:`LogGammaLightningModule` and
36 override the :meth:`~LogGammaLightningModule.predict_from_log_gamma` method. For
37 instance:
39 .. code-block:: python
41 from scipy.constants import R
43 class ExcessGibbsLightningModule(LogGammaLightningModule):
44 def predict_from_log_gamma(self, T, x, log_gamma):
45 return (R * T * (x * log_gamma).sum(dim=-1)).unsqueeze(-1)
47 The module is batch-first throughout. All inputs must represent a minibatch
48 of ``b`` datapoints, and the returned predictions must have leading
49 dimension ``b``. Targets must have the same shape as the predictions.
51 Parameters
52 ----------
53 num_segment_types : int
54 Number of COSMO segment types.
55 temperature_exponents : tuple[int, ...]
56 Exponents defining the temperature dependence of the interaction
57 matrices.
58 area_per_segment : float
59 Area associated with one segment.
60 reference_temperature : float, optional
61 Reference temperature used by :class:`CosmoLayer`.
62 Default is ``298.15``.
63 max_iter : int, optional
64 Maximum number of internal fixed-point or iterative solver steps used
65 by :class:`CosmoLayer`. Default is ``100``.
66 learning_rate : float, optional
67 Learning rate for the Adam optimizer. Default is ``1e-3``.
68 weight_decay : float, optional
69 Weight decay for the Adam optimizer. Default is ``0.0``.
70 loss_function : str, optional
71 Loss function used in training, validation, and test steps. Must be a
72 valid loss function from :mod:`torch.nn.functional`.
73 Default is ``"mse_loss"``.
74 initialization : Sequence[NDArray[np.float64]] | int, optional
75 Initialization for the learnable interaction matrices.
77 - If an ``int`` is provided, it is interpreted as the random seed used
78 to sample one matrix per temperature exponent from a standard normal
79 distribution.
80 - If a sequence of NumPy arrays is provided, it must contain exactly
81 one array per temperature exponent, and each array must have shape
82 ``(num_segment_types, num_segment_types)``.
84 Default is ``42``.
86 Examples
87 --------
88 >>> import torch
89 >>> from importlib.resources import files
90 >>> import cosmolayer as cl
91 >>> from cosmolayer import cosmosac
92 >>> model = cosmosac.CosmoSac2010Model
93 >>> module = LogGammaLightningModule(
94 ... num_segment_types=model.num_segment_types,
95 ... temperature_exponents=model.temperature_exponents,
96 ... area_per_segment=model.area_per_segment,
97 ... )
98 >>> solute_path = files("cosmolayer.data") / "NCCO.cosmo"
99 >>> solvent_path = files("cosmolayer.data") / "O.cosmo"
100 >>> datapoint = cosmosac.CosmoSacMixtureDatapoint(
101 ... cosmo_files=[solute_path, solvent_path],
102 ... mole_fractions=[0.2, 0.8],
103 ... temperature=298.15,
104 ... targets=[-0.2, 0.02],
105 ... model=model,
106 ... )
107 >>> single_inputs = datapoint.get_inputs()
108 >>> batched_inputs = tuple(x.unsqueeze(0) for x in single_inputs)
109 >>> preds = module(batched_inputs)
110 >>> preds.shape
111 torch.Size([1, 2])
112 """
114 def __init__( # noqa: PLR0913
115 self,
116 num_segment_types: int,
117 temperature_exponents: Sequence[int],
118 area_per_segment: float,
119 reference_temperature: float = 298.15,
120 max_iter: int = 100,
121 learning_rate: float = 1e-3,
122 weight_decay: float = 0.0,
123 normalize_targets: bool = False,
124 loss_function: str = "mse_loss",
125 initialization: Sequence[NDArray[np.float64]] | int = 42,
126 ) -> None:
127 super().__init__()
129 if num_segment_types <= 0:
130 raise ValueError("num_segment_types must be a positive integer")
131 if len(temperature_exponents) == 0:
132 raise ValueError("temperature_exponents must not be empty")
133 if area_per_segment <= 0.0:
134 raise ValueError("area_per_segment must be positive")
135 if reference_temperature <= 0.0:
136 raise ValueError("reference_temperature must be positive")
137 if max_iter <= 0:
138 raise ValueError("max_iter must be a positive integer")
139 if learning_rate <= 0.0:
140 raise ValueError("learning_rate must be positive")
141 if weight_decay < 0.0:
142 raise ValueError("weight_decay must be non-negative")
143 loss_callable = getattr(F, loss_function, None)
144 if not is_loss_function(loss_callable):
145 raise ValueError(f"Unsupported loss_function '{loss_function}'.")
147 self.save_hyperparameters(ignore=["initialization"])
148 self.normalize_targets = normalize_targets
149 self.learning_rate = learning_rate
150 self.weight_decay = weight_decay
151 self.loss_function = loss_callable
153 initial_matrices = self._build_initial_matrices(
154 initialization=initialization,
155 num_segment_types=num_segment_types,
156 num_matrices=len(temperature_exponents),
157 )
159 self.cosmo_layer = CosmoLayer(
160 interaction_matrices=initial_matrices,
161 exponents=temperature_exponents,
162 area_per_segment=area_per_segment,
163 reference_temperature=reference_temperature,
164 max_iter=max_iter,
165 learn_matrices=True,
166 )
168 self.test_mae = MeanAbsoluteError()
169 self.test_rmse = MeanSquaredError(squared=False)
170 self.test_r2 = R2Score()
172 self.register_buffer("target_mean", torch.tensor(0.0))
173 self.register_buffer("target_std", torch.tensor(1.0))
175 @staticmethod
176 def _build_initial_matrices(
177 initialization: Sequence[NDArray[np.float64]] | int,
178 num_segment_types: int,
179 num_matrices: int,
180 ) -> list[NDArray[np.float64]]:
181 """Create and validate the initial interaction matrices."""
182 if isinstance(initialization, int):
183 rng = np.random.default_rng(initialization)
184 return [
185 rng.normal(size=(num_segment_types, num_segment_types))
186 for _ in range(num_matrices)
187 ]
189 matrices = [np.asarray(matrix, dtype=np.float64) for matrix in initialization]
191 if len(matrices) != num_matrices:
192 raise ValueError(
193 "initialization must contain exactly one matrix per temperature "
194 f"exponent: expected {num_matrices}, got {len(matrices)}"
195 )
197 expected_shape = (num_segment_types, num_segment_types)
198 for index, matrix in enumerate(matrices):
199 if matrix.shape != expected_shape:
200 raise ValueError(
201 "Each initialization matrix must have shape "
202 f"{expected_shape}; matrix {index} has shape {matrix.shape}"
203 )
204 if not np.isfinite(matrix).all():
205 raise ValueError(
206 f"Initialization matrix {index} contains non-finite values"
207 )
209 return matrices
211 @staticmethod
212 def _infer_batch_size(predictions: torch.Tensor, targets: torch.Tensor) -> int:
213 """Infer the minibatch size from prediction and target tensors."""
214 if predictions.ndim == 0 or targets.ndim == 0:
215 raise ValueError(
216 "Predictions and targets must be batched tensors with a leading "
217 "batch dimension"
218 )
219 if predictions.shape != targets.shape:
220 raise ValueError(
221 "Predictions and targets must have the same shape; "
222 f"got {predictions.shape} and {targets.shape}"
223 )
224 return int(targets.shape[0])
226 @torch.no_grad()
227 def _compute_target_statistics(self) -> None:
228 trainer = self.trainer
229 if hasattr(trainer, "datamodule") and trainer.datamodule is not None:
230 dataloader = trainer.datamodule.train_dataloader()
231 else:
232 dataloader = trainer.train_dataloader
234 if dataloader is None:
235 raise ValueError(
236 "Training dataloader is unavailable; cannot normalize targets"
237 )
239 count = torch.tensor(0.0)
240 target_sum: torch.Tensor | None = None
241 target_sumsq: torch.Tensor | None = None
243 for batch in dataloader:
244 _, targets = batch
245 targets = targets.detach()
247 batch_count = torch.tensor(float(targets.shape[0]), device=targets.device)
248 batch_sum = targets.sum(dim=0)
249 batch_sumsq = (targets**2).sum(dim=0)
251 if target_sum is None:
252 count = count.to(targets.device)
253 target_sum = torch.zeros_like(batch_sum)
254 target_sumsq = torch.zeros_like(batch_sumsq)
256 count = count + batch_count
257 target_sum = target_sum + batch_sum
258 target_sumsq = target_sumsq + batch_sumsq
260 if target_sum is None or count.item() == 0:
261 raise ValueError("Training dataloader is empty; cannot normalize targets")
263 if td.is_available() and td.is_initialized():
264 td.all_reduce(count, op=td.ReduceOp.SUM)
265 td.all_reduce(target_sum, op=td.ReduceOp.SUM)
266 td.all_reduce(target_sumsq, op=td.ReduceOp.SUM)
268 if target_sum is None or target_sumsq is None:
269 raise ValueError("Training dataloader is empty; cannot normalize targets")
271 mean = target_sum / count
272 variance = torch.clamp(target_sumsq / count - mean**2, min=0.0)
273 std = torch.sqrt(variance + EPSILON)
275 self.target_mean = mean.to(self.device)
276 self.target_std = std.to(self.device)
278 def forward(self, inputs: InputsType) -> torch.Tensor:
279 """Compute predictions for a minibatch of datapoints.
281 Parameters
282 ----------
283 inputs : InputsType
284 Batched input tuple ``(temperature, mole_fractions, areas, volumes,
285 probabilities)``. All tensors must be batch-first and represent the
286 same minibatch of size ``b``.
288 Returns
289 -------
290 torch.Tensor
291 Batched predictions with leading dimension ``b``.
292 """
293 temperature, mole_fractions, areas, volumes, probabilities = inputs
294 log_gamma: torch.Tensor = self.cosmo_layer(
295 temperature, mole_fractions, areas, volumes, probabilities
296 )
297 return self.predict_from_log_gamma(temperature, mole_fractions, log_gamma)
299 def predict_from_log_gamma(
300 self,
301 T: torch.Tensor,
302 x: torch.Tensor,
303 log_gamma: torch.Tensor,
304 ) -> torch.Tensor:
305 """Convert log-activity coefficients to final predictions.
307 Parameters
308 ----------
309 T : torch.Tensor
310 Temperature in the same units as the reference temperature.
311 Shape: (...,).
312 x : torch.Tensor
313 Mole fractions of the components. Must sum to 1.
314 Shape: (..., num_components).
315 log_gamma : torch.Tensor
316 Logarithms of the activity coefficients.
317 Shape: (..., num_components).
319 Returns
320 -------
321 torch.Tensor
322 Final predictions.
323 """
324 return log_gamma
326 def configure_optimizers(self) -> torch.optim.Optimizer:
327 """Configure the optimizer used during training.
329 Returns
330 -------
331 torch.optim.Optimizer
332 Adam optimizer over all module parameters.
333 """
334 return torch.optim.Adam(
335 self.parameters(),
336 lr=self.learning_rate,
337 weight_decay=self.weight_decay,
338 )
340 def on_fit_start(self) -> None:
341 if self.normalize_targets:
342 self._compute_target_statistics()
344 def training_step(
345 self, batch: tuple[InputsType, torch.Tensor], batch_idx: int
346 ) -> torch.Tensor:
347 """Run one training step on a minibatch.
349 Parameters
350 ----------
351 batch : tuple[InputsType, torch.Tensor]
352 Batched inputs and batched ground-truth targets. Targets must have
353 the same shape as the model predictions, with leading dimension
354 equal to the minibatch size.
355 batch_idx : int
356 Index of the current batch.
358 Returns
359 -------
360 torch.Tensor
361 Training loss for the batch.
362 """
363 inputs, targets = batch
364 predictions = self(inputs)
365 batch_size = self._infer_batch_size(predictions, targets)
366 if self.normalize_targets:
367 target_mean = self.target_mean
368 target_std = self.target_std
369 targets = (targets - target_mean) / target_std
370 predictions = (predictions - target_mean) / target_std
371 loss: torch.Tensor = self.loss_function(predictions, targets)
372 self.log(
373 "train_loss",
374 loss,
375 on_step=False,
376 on_epoch=True,
377 batch_size=batch_size,
378 )
379 return loss
381 def validation_step(
382 self, batch: tuple[InputsType, torch.Tensor], batch_idx: int
383 ) -> torch.Tensor:
384 """Run one validation step on a minibatch.
386 Parameters
387 ----------
388 batch : tuple[InputsType, torch.Tensor]
389 Batched inputs and batched ground-truth targets. Targets must have
390 the same shape as the model predictions, with leading dimension
391 equal to the minibatch size.
392 batch_idx : int
393 Index of the current batch.
395 Returns
396 -------
397 torch.Tensor
398 Validation loss for the batch.
399 """
400 inputs, targets = batch
401 predictions = self(inputs)
402 batch_size = self._infer_batch_size(predictions, targets)
403 if self.normalize_targets:
404 target_mean = self.target_mean
405 target_std = self.target_std
406 targets = (targets - target_mean) / target_std
407 predictions = (predictions - target_mean) / target_std
408 loss: torch.Tensor = self.loss_function(predictions, targets)
409 self.log(
410 "val_loss",
411 loss,
412 on_step=False,
413 on_epoch=True,
414 batch_size=batch_size,
415 prog_bar=True,
416 )
417 return loss
419 def test_step(
420 self, batch: tuple[InputsType, torch.Tensor], batch_idx: int
421 ) -> torch.Tensor:
422 """Run one test step on a minibatch and update regression metrics.
424 Parameters
425 ----------
426 batch : tuple[InputsType, torch.Tensor]
427 Batched inputs and batched ground-truth targets. Targets must have
428 the same shape as the model predictions, with leading dimension
429 equal to the minibatch size.
430 batch_idx : int
431 Index of the current batch.
433 Returns
434 -------
435 torch.Tensor
436 Test loss for the batch.
437 """
438 inputs, targets = batch
439 predictions = self(inputs)
440 batch_size = self._infer_batch_size(predictions, targets)
441 loss_predictions = predictions
442 loss_targets = targets
443 if self.normalize_targets:
444 target_mean = self.target_mean
445 target_std = self.target_std
446 loss_targets = (targets - target_mean) / target_std
447 loss_predictions = (predictions - target_mean) / target_std
448 loss: torch.Tensor = self.loss_function(loss_predictions, loss_targets)
450 self.test_mae.update(predictions, targets)
451 self.test_rmse.update(predictions, targets)
452 self.test_r2.update(predictions, targets)
454 self.log_dict(
455 {
456 "test_loss": loss,
457 "test_mae": self.test_mae,
458 "test_rmse": self.test_rmse,
459 "test_r2": self.test_r2,
460 },
461 on_step=False,
462 on_epoch=True,
463 batch_size=batch_size,
464 )
465 return loss