Coverage for cosmolayer / cosmolightning.py: 79%

156 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 14:25 +0000

1""" 

2.. module:: cosmolayer.cosmolightning 

3 :synopsis: PyTorch Lightning module for batched CosmoLayer training. 

4""" 

5 

6from __future__ import annotations 

7 

8from collections.abc import Sequence 

9 

10import numpy as np 

11import torch 

12from lightning import pytorch as pl 

13from numpy.typing import NDArray 

14from torch import distributed as td 

15from torch.nn import functional as F 

16from torchmetrics import MeanAbsoluteError, MeanSquaredError, R2Score 

17 

18from .cosmodata import InputsType 

19from .cosmolayer import CosmoLayer 

20from .utils import is_loss_function 

21 

# Small positive constant added to the target variance before the square
# root in _compute_target_statistics, keeping target_std strictly positive.
EPSILON = 1e-8

23 

24 

class LogGammaLightningModule(pl.LightningModule):
    """PyTorch Lightning module for batched training of a learnable
    :class:`~cosmolayer.CosmoLayer`.

    This class is the canonical high-level training interface for CosmoLayer.
    It constructs an internal :class:`~cosmolayer.CosmoLayer` with learnable
    interaction matrices and defines the optimization, training, validation,
    test, and prediction logic.

    The targets are the log-activity coefficients of the components. In order to
    handle other tasks, the user must subclass :class:`LogGammaLightningModule` and
    override the :meth:`~LogGammaLightningModule.predict_from_log_gamma` method. For
    instance:

    .. code-block:: python

        from scipy.constants import R

        class ExcessGibbsLightningModule(LogGammaLightningModule):
            def predict_from_log_gamma(self, T, x, log_gamma):
                return (R * T * (x * log_gamma).sum(dim=-1)).unsqueeze(-1)

    The module is batch-first throughout. All inputs must represent a minibatch
    of ``b`` datapoints, and the returned predictions must have leading
    dimension ``b``. Targets must have the same shape as the predictions.

    Parameters
    ----------
    num_segment_types : int
        Number of COSMO segment types.
    temperature_exponents : tuple[int, ...]
        Exponents defining the temperature dependence of the interaction
        matrices.
    area_per_segment : float
        Area associated with one segment.
    reference_temperature : float, optional
        Reference temperature used by :class:`CosmoLayer`.
        Default is ``298.15``.
    max_iter : int, optional
        Maximum number of internal fixed-point or iterative solver steps used
        by :class:`CosmoLayer`. Default is ``100``.
    learning_rate : float, optional
        Learning rate for the Adam optimizer. Default is ``1e-3``.
    weight_decay : float, optional
        Weight decay for the Adam optimizer. Default is ``0.0``.
    normalize_targets : bool, optional
        If ``True``, the mean and standard deviation of the training targets
        are computed from the training dataloader at fit start and used to
        standardize both targets and predictions before the loss is
        evaluated. Test metrics are still computed on the raw scale.
        Default is ``False``.
    loss_function : str, optional
        Loss function used in training, validation, and test steps. Must be a
        valid loss function from :mod:`torch.nn.functional`.
        Default is ``"mse_loss"``.
    initialization : Sequence[NDArray[np.float64]] | int, optional
        Initialization for the learnable interaction matrices.

        - If an ``int`` is provided, it is interpreted as the random seed used
          to sample one matrix per temperature exponent from a standard normal
          distribution.
        - If a sequence of NumPy arrays is provided, it must contain exactly
          one array per temperature exponent, and each array must have shape
          ``(num_segment_types, num_segment_types)``.

        Default is ``42``.

    Examples
    --------
    >>> import torch
    >>> from importlib.resources import files
    >>> import cosmolayer as cl
    >>> from cosmolayer import cosmosac
    >>> model = cosmosac.CosmoSac2010Model
    >>> module = LogGammaLightningModule(
    ...     num_segment_types=model.num_segment_types,
    ...     temperature_exponents=model.temperature_exponents,
    ...     area_per_segment=model.area_per_segment,
    ... )
    >>> solute_path = files("cosmolayer.data") / "NCCO.cosmo"
    >>> solvent_path = files("cosmolayer.data") / "O.cosmo"
    >>> datapoint = cosmosac.CosmoSacMixtureDatapoint(
    ...     cosmo_files=[solute_path, solvent_path],
    ...     mole_fractions=[0.2, 0.8],
    ...     temperature=298.15,
    ...     targets=[-0.2, 0.02],
    ...     model=model,
    ... )
    >>> single_inputs = datapoint.get_inputs()
    >>> batched_inputs = tuple(x.unsqueeze(0) for x in single_inputs)
    >>> preds = module(batched_inputs)
    >>> preds.shape
    torch.Size([1, 2])
    """

113 

114 def __init__( # noqa: PLR0913 

115 self, 

116 num_segment_types: int, 

117 temperature_exponents: Sequence[int], 

118 area_per_segment: float, 

119 reference_temperature: float = 298.15, 

120 max_iter: int = 100, 

121 learning_rate: float = 1e-3, 

122 weight_decay: float = 0.0, 

123 normalize_targets: bool = False, 

124 loss_function: str = "mse_loss", 

125 initialization: Sequence[NDArray[np.float64]] | int = 42, 

126 ) -> None: 

127 super().__init__() 

128 

129 if num_segment_types <= 0: 

130 raise ValueError("num_segment_types must be a positive integer") 

131 if len(temperature_exponents) == 0: 

132 raise ValueError("temperature_exponents must not be empty") 

133 if area_per_segment <= 0.0: 

134 raise ValueError("area_per_segment must be positive") 

135 if reference_temperature <= 0.0: 

136 raise ValueError("reference_temperature must be positive") 

137 if max_iter <= 0: 

138 raise ValueError("max_iter must be a positive integer") 

139 if learning_rate <= 0.0: 

140 raise ValueError("learning_rate must be positive") 

141 if weight_decay < 0.0: 

142 raise ValueError("weight_decay must be non-negative") 

143 loss_callable = getattr(F, loss_function, None) 

144 if not is_loss_function(loss_callable): 

145 raise ValueError(f"Unsupported loss_function '{loss_function}'.") 

146 

147 self.save_hyperparameters(ignore=["initialization"]) 

148 self.normalize_targets = normalize_targets 

149 self.learning_rate = learning_rate 

150 self.weight_decay = weight_decay 

151 self.loss_function = loss_callable 

152 

153 initial_matrices = self._build_initial_matrices( 

154 initialization=initialization, 

155 num_segment_types=num_segment_types, 

156 num_matrices=len(temperature_exponents), 

157 ) 

158 

159 self.cosmo_layer = CosmoLayer( 

160 interaction_matrices=initial_matrices, 

161 exponents=temperature_exponents, 

162 area_per_segment=area_per_segment, 

163 reference_temperature=reference_temperature, 

164 max_iter=max_iter, 

165 learn_matrices=True, 

166 ) 

167 

168 self.test_mae = MeanAbsoluteError() 

169 self.test_rmse = MeanSquaredError(squared=False) 

170 self.test_r2 = R2Score() 

171 

172 self.register_buffer("target_mean", torch.tensor(0.0)) 

173 self.register_buffer("target_std", torch.tensor(1.0)) 

174 

175 @staticmethod 

176 def _build_initial_matrices( 

177 initialization: Sequence[NDArray[np.float64]] | int, 

178 num_segment_types: int, 

179 num_matrices: int, 

180 ) -> list[NDArray[np.float64]]: 

181 """Create and validate the initial interaction matrices.""" 

182 if isinstance(initialization, int): 

183 rng = np.random.default_rng(initialization) 

184 return [ 

185 rng.normal(size=(num_segment_types, num_segment_types)) 

186 for _ in range(num_matrices) 

187 ] 

188 

189 matrices = [np.asarray(matrix, dtype=np.float64) for matrix in initialization] 

190 

191 if len(matrices) != num_matrices: 

192 raise ValueError( 

193 "initialization must contain exactly one matrix per temperature " 

194 f"exponent: expected {num_matrices}, got {len(matrices)}" 

195 ) 

196 

197 expected_shape = (num_segment_types, num_segment_types) 

198 for index, matrix in enumerate(matrices): 

199 if matrix.shape != expected_shape: 

200 raise ValueError( 

201 "Each initialization matrix must have shape " 

202 f"{expected_shape}; matrix {index} has shape {matrix.shape}" 

203 ) 

204 if not np.isfinite(matrix).all(): 

205 raise ValueError( 

206 f"Initialization matrix {index} contains non-finite values" 

207 ) 

208 

209 return matrices 

210 

211 @staticmethod 

212 def _infer_batch_size(predictions: torch.Tensor, targets: torch.Tensor) -> int: 

213 """Infer the minibatch size from prediction and target tensors.""" 

214 if predictions.ndim == 0 or targets.ndim == 0: 

215 raise ValueError( 

216 "Predictions and targets must be batched tensors with a leading " 

217 "batch dimension" 

218 ) 

219 if predictions.shape != targets.shape: 

220 raise ValueError( 

221 "Predictions and targets must have the same shape; " 

222 f"got {predictions.shape} and {targets.shape}" 

223 ) 

224 return int(targets.shape[0]) 

225 

226 @torch.no_grad() 

227 def _compute_target_statistics(self) -> None: 

228 trainer = self.trainer 

229 if hasattr(trainer, "datamodule") and trainer.datamodule is not None: 

230 dataloader = trainer.datamodule.train_dataloader() 

231 else: 

232 dataloader = trainer.train_dataloader 

233 

234 if dataloader is None: 

235 raise ValueError( 

236 "Training dataloader is unavailable; cannot normalize targets" 

237 ) 

238 

239 count = torch.tensor(0.0) 

240 target_sum: torch.Tensor | None = None 

241 target_sumsq: torch.Tensor | None = None 

242 

243 for batch in dataloader: 

244 _, targets = batch 

245 targets = targets.detach() 

246 

247 batch_count = torch.tensor(float(targets.shape[0]), device=targets.device) 

248 batch_sum = targets.sum(dim=0) 

249 batch_sumsq = (targets**2).sum(dim=0) 

250 

251 if target_sum is None: 

252 count = count.to(targets.device) 

253 target_sum = torch.zeros_like(batch_sum) 

254 target_sumsq = torch.zeros_like(batch_sumsq) 

255 

256 count = count + batch_count 

257 target_sum = target_sum + batch_sum 

258 target_sumsq = target_sumsq + batch_sumsq 

259 

260 if target_sum is None or count.item() == 0: 

261 raise ValueError("Training dataloader is empty; cannot normalize targets") 

262 

263 if td.is_available() and td.is_initialized(): 

264 td.all_reduce(count, op=td.ReduceOp.SUM) 

265 td.all_reduce(target_sum, op=td.ReduceOp.SUM) 

266 td.all_reduce(target_sumsq, op=td.ReduceOp.SUM) 

267 

268 if target_sum is None or target_sumsq is None: 

269 raise ValueError("Training dataloader is empty; cannot normalize targets") 

270 

271 mean = target_sum / count 

272 variance = torch.clamp(target_sumsq / count - mean**2, min=0.0) 

273 std = torch.sqrt(variance + EPSILON) 

274 

275 self.target_mean = mean.to(self.device) 

276 self.target_std = std.to(self.device) 

277 

278 def forward(self, inputs: InputsType) -> torch.Tensor: 

279 """Compute predictions for a minibatch of datapoints. 

280 

281 Parameters 

282 ---------- 

283 inputs : InputsType 

284 Batched input tuple ``(temperature, mole_fractions, areas, volumes, 

285 probabilities)``. All tensors must be batch-first and represent the 

286 same minibatch of size ``b``. 

287 

288 Returns 

289 ------- 

290 torch.Tensor 

291 Batched predictions with leading dimension ``b``. 

292 """ 

293 temperature, mole_fractions, areas, volumes, probabilities = inputs 

294 log_gamma: torch.Tensor = self.cosmo_layer( 

295 temperature, mole_fractions, areas, volumes, probabilities 

296 ) 

297 return self.predict_from_log_gamma(temperature, mole_fractions, log_gamma) 

298 

299 def predict_from_log_gamma( 

300 self, 

301 T: torch.Tensor, 

302 x: torch.Tensor, 

303 log_gamma: torch.Tensor, 

304 ) -> torch.Tensor: 

305 """Convert log-activity coefficients to final predictions. 

306 

307 Parameters 

308 ---------- 

309 T : torch.Tensor 

310 Temperature in the same units as the reference temperature. 

311 Shape: (...,). 

312 x : torch.Tensor 

313 Mole fractions of the components. Must sum to 1. 

314 Shape: (..., num_components). 

315 log_gamma : torch.Tensor 

316 Logarithms of the activity coefficients. 

317 Shape: (..., num_components). 

318 

319 Returns 

320 ------- 

321 torch.Tensor 

322 Final predictions. 

323 """ 

324 return log_gamma 

325 

326 def configure_optimizers(self) -> torch.optim.Optimizer: 

327 """Configure the optimizer used during training. 

328 

329 Returns 

330 ------- 

331 torch.optim.Optimizer 

332 Adam optimizer over all module parameters. 

333 """ 

334 return torch.optim.Adam( 

335 self.parameters(), 

336 lr=self.learning_rate, 

337 weight_decay=self.weight_decay, 

338 ) 

339 

340 def on_fit_start(self) -> None: 

341 if self.normalize_targets: 

342 self._compute_target_statistics() 

343 

344 def training_step( 

345 self, batch: tuple[InputsType, torch.Tensor], batch_idx: int 

346 ) -> torch.Tensor: 

347 """Run one training step on a minibatch. 

348 

349 Parameters 

350 ---------- 

351 batch : tuple[InputsType, torch.Tensor] 

352 Batched inputs and batched ground-truth targets. Targets must have 

353 the same shape as the model predictions, with leading dimension 

354 equal to the minibatch size. 

355 batch_idx : int 

356 Index of the current batch. 

357 

358 Returns 

359 ------- 

360 torch.Tensor 

361 Training loss for the batch. 

362 """ 

363 inputs, targets = batch 

364 predictions = self(inputs) 

365 batch_size = self._infer_batch_size(predictions, targets) 

366 if self.normalize_targets: 

367 target_mean = self.target_mean 

368 target_std = self.target_std 

369 targets = (targets - target_mean) / target_std 

370 predictions = (predictions - target_mean) / target_std 

371 loss: torch.Tensor = self.loss_function(predictions, targets) 

372 self.log( 

373 "train_loss", 

374 loss, 

375 on_step=False, 

376 on_epoch=True, 

377 batch_size=batch_size, 

378 ) 

379 return loss 

380 

381 def validation_step( 

382 self, batch: tuple[InputsType, torch.Tensor], batch_idx: int 

383 ) -> torch.Tensor: 

384 """Run one validation step on a minibatch. 

385 

386 Parameters 

387 ---------- 

388 batch : tuple[InputsType, torch.Tensor] 

389 Batched inputs and batched ground-truth targets. Targets must have 

390 the same shape as the model predictions, with leading dimension 

391 equal to the minibatch size. 

392 batch_idx : int 

393 Index of the current batch. 

394 

395 Returns 

396 ------- 

397 torch.Tensor 

398 Validation loss for the batch. 

399 """ 

400 inputs, targets = batch 

401 predictions = self(inputs) 

402 batch_size = self._infer_batch_size(predictions, targets) 

403 if self.normalize_targets: 

404 target_mean = self.target_mean 

405 target_std = self.target_std 

406 targets = (targets - target_mean) / target_std 

407 predictions = (predictions - target_mean) / target_std 

408 loss: torch.Tensor = self.loss_function(predictions, targets) 

409 self.log( 

410 "val_loss", 

411 loss, 

412 on_step=False, 

413 on_epoch=True, 

414 batch_size=batch_size, 

415 prog_bar=True, 

416 ) 

417 return loss 

418 

419 def test_step( 

420 self, batch: tuple[InputsType, torch.Tensor], batch_idx: int 

421 ) -> torch.Tensor: 

422 """Run one test step on a minibatch and update regression metrics. 

423 

424 Parameters 

425 ---------- 

426 batch : tuple[InputsType, torch.Tensor] 

427 Batched inputs and batched ground-truth targets. Targets must have 

428 the same shape as the model predictions, with leading dimension 

429 equal to the minibatch size. 

430 batch_idx : int 

431 Index of the current batch. 

432 

433 Returns 

434 ------- 

435 torch.Tensor 

436 Test loss for the batch. 

437 """ 

438 inputs, targets = batch 

439 predictions = self(inputs) 

440 batch_size = self._infer_batch_size(predictions, targets) 

441 loss_predictions = predictions 

442 loss_targets = targets 

443 if self.normalize_targets: 

444 target_mean = self.target_mean 

445 target_std = self.target_std 

446 loss_targets = (targets - target_mean) / target_std 

447 loss_predictions = (predictions - target_mean) / target_std 

448 loss: torch.Tensor = self.loss_function(loss_predictions, loss_targets) 

449 

450 self.test_mae.update(predictions, targets) 

451 self.test_rmse.update(predictions, targets) 

452 self.test_r2.update(predictions, targets) 

453 

454 self.log_dict( 

455 { 

456 "test_loss": loss, 

457 "test_mae": self.test_mae, 

458 "test_rmse": self.test_rmse, 

459 "test_r2": self.test_r2, 

460 }, 

461 on_step=False, 

462 on_epoch=True, 

463 batch_size=batch_size, 

464 ) 

465 return loss