Coverage for cosmolayer / cosmosolver.py: 95%
65 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
"""
.. module:: cosmolayer.cosmosolver
    :synopsis: Solves the self-consistent equation for the segment activity coefficients.

.. moduleauthor:: Charlles Abreu <craabreu@gmail.com>
"""
8from __future__ import annotations
10from typing import Any
12import torch
13from torch.autograd.function import FunctionCtx, NestedIOFunction
15from .utils import log_matmul_exp
# Convergence thresholds for the log-space Newton iteration, keyed by tensor
# dtype (tighter for float64).  The step tolerance bounds the largest absolute
# Newton update; the residual tolerance bounds the largest absolute residual of
# the self-consistent equations.  Both criteria must hold for convergence.
NEWTON_STEP_TOLERANCE = {torch.float32: 1e-5, torch.float64: 1e-10}
NEWTON_RESIDUAL_TOLERANCE = {torch.float32: 1e-6, torch.float64: 1e-12}
class CosmoSolver(torch.autograd.Function):
    r"""COSMO self-consistent equation solver.

    Solves the COSMO self-consistent equations for the logarithm of the activity
    coefficient vector, :math:`\ln \boldsymbol{\gamma}`, given the nonnegative
    probability distribution vector :math:`\mathbf{p}` and the reduced interaction
    energy matrix :math:`\mathbf{U}/(RT)`.

    The self-consistent equations are:

    .. math::

        \boldsymbol{\gamma}\circ \left(
            \mathbf{B} ({\mathbf p} \circ \boldsymbol{\gamma})
        \right) = t \mathbf{1},

    where :math:`\mathbf{B} = \exp(-\mathbf{U}/(RT))` is the matrix of Boltzmann
    factors, :math:`t=\mathbf{1}^T \mathbf{p}` is the sum of the probabilities, and
    :math:`\circ` represents an elementwise product.

    The solution satisfies
    :math:`\boldsymbol{\gamma}^\mathsf{T} \mathbf{M} \boldsymbol{\gamma} = t`,
    where :math:`\mathbf{M} = \mathbf{B} \circ (\mathbf{p}\mathbf{p}^T)`.

    .. note::
        Supports batching: :math:`\mathbf{p}` and :math:`\mathbf{U}/(RT)` may have
        broadcastable leading dimensions, in which case all computations are
        performed in a single vectorized operation.

    Parameters
    ----------
    p : torch.Tensor
        Segment-type probability distribution vector. Must be nonnegative.
        Shape: (..., num_segment_types).
    U_RT : torch.Tensor
        Reduced interaction energy matrix :math:`\mathbf{U}/(RT)`.
        Shape: (..., num_segment_types, num_segment_types).
    max_iter : int
        Maximum number of iterations.

    Returns
    -------
    log_gamma : torch.Tensor
        The logarithm of the segment activity coefficient vector.
        Shape: (..., num_segment_types).
    converged : torch.Tensor
        Boolean flag per batch entry indicating whether the Newton iteration
        satisfied both tolerance criteria within ``max_iter`` iterations.
        Non-convergence is reported through this flag, not by an exception.
        Shape: (...,).

    Raises
    ------
    ValueError
        If ``max_iter`` is not positive, or if the probabilities are invalid
        (any negative entry, or an all-zero distribution in any batch entry).

    Examples
    --------
    >>> import torch
    >>> from cosmolayer.cosmosac import Component, CosmoSac2002Model
    >>> from importlib.resources import files
    >>> cosmo_strings = [
    ...     (files("cosmolayer.data") / f"{species}.cosmo").read_text()
    ...     for species in ("C=C(N)O", "NCCO")
    ... ]
    >>> probabilities = [
    ...     CosmoSac2002Model.create_component(cosmo_string).probabilities
    ...     for cosmo_string in cosmo_strings
    ... ]
    >>> p = torch.stack(
    ...     [torch.tensor(prob, dtype=torch.float32) for prob in probabilities],
    ... ).requires_grad_(True)
    >>> U_RT = torch.tensor(
    ...     CosmoSac2002Model.create_interaction_matrices(298.15)[0],
    ...     dtype=torch.float32,
    ...     requires_grad=True,
    ... )
    >>> log_gamma, converged = CosmoSolver.apply(p, U_RT)
    >>> converged.all().item()
    True
    >>> log_gamma
    tensor([[-4.5...e+00, -4.0...e+00, ... -1.3...e+01],
            [-2.1...e+01, -1.9...e+01, ... -5.3...e+00]], grad_fn=<CosmoSolverBackward>)
    >>> loss = (2 * log_gamma).exp().sum()
    >>> loss.backward()
    >>> p.grad
    tensor([[ 2.1...e+02,  2.1...e+02, ... -7.4...e+05],
            [-6.6...e+02, -6.3...e+02, ...  7.4...e+02]])
    """

    @staticmethod
    def _logspace_newton_solver(
        p: torch.Tensor,
        U_RT: torch.Tensor,
        max_iter: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Solve the self-consistent equations via Newton's method in log space.

        Works entirely with logarithms for numerical stability, so zero entries
        in ``p`` (log -> -inf) are handled gracefully by ``logsumexp``.

        Returns ``(log_gamma, converged)`` where ``log_gamma`` is a column
        vector of shape (..., n, 1) and ``converged`` is boolean of shape (...).
        """
        step_tol = NEWTON_STEP_TOLERANCE[p.dtype]
        resid_tol = NEWTON_RESIDUAL_TOLERANCE[p.dtype]
        with torch.no_grad():
            # log t as (..., 1, 1) so it broadcasts against column vectors
            log_t = p.sum(dim=-1, keepdim=True).log().unsqueeze(-1)
            # log A with A_ij = p_j * exp(-U_ij/(RT)); rows index segments
            log_A = p.log().unsqueeze(-2) - U_RT
            Id = torch.eye(log_A.shape[-1], dtype=log_A.dtype, device=log_A.device)
            # Initial guess: gamma_i = sqrt(t) / sum_j A_ij
            log_gamma = -torch.logsumexp(log_A, dim=-1, keepdim=True) + 0.5 * log_t
            log_A_gamma = log_matmul_exp(log_A, log_gamma)
            # Residual of the self-consistent equations in log form
            f = log_gamma + log_A_gamma - log_t
            # Not converged until proven otherwise (also covers max_iter == 0)
            batch_shape = torch.broadcast_shapes(p.shape[:-1], U_RT.shape[:-2])
            converged = torch.zeros(
                batch_shape, dtype=torch.bool, device=log_A.device
            )
            for _ in range(max_iter):
                # Jacobian wrt log_gamma: J_ik = delta_ik + A_ik*gamma_k/(A@gamma)_i
                J = torch.exp(log_gamma.mT + log_A - log_A_gamma) + Id
                delta = torch.linalg.solve(J, -f)
                log_gamma += delta
                log_A_gamma = log_matmul_exp(log_A, log_gamma)
                f = log_gamma + log_A_gamma - log_t
                # Both the step size and the residual must be small
                delta_norm = delta.abs().amax(dim=(-2, -1))
                f_norm = f.abs().amax(dim=(-2, -1))
                converged = (delta_norm < step_tol) & (f_norm < resid_tol)
                if bool(converged.all()):
                    break
        return log_gamma, converged

    @staticmethod
    def forward(
        ctx: FunctionCtx,
        p: torch.Tensor,
        U_RT: torch.Tensor,
        max_iter: int = 100,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Validate the inputs, run the solver, and stash state for backward."""
        # Record the original (pre-broadcast) shapes so that backward can
        # reduce broadcasted gradients back down with sum_to_size.
        ctx_any: Any = ctx
        ctx_any.p_shape = tuple(p.shape)
        ctx_any.u_shape = tuple(U_RT.shape)

        if max_iter <= 0:
            raise ValueError("Maximum number of iterations must be positive")

        # p must be elementwise nonnegative and not identically zero in any
        # batch entry (an all-zero row would make log t = -inf).
        invalid = (p < 0).any() | (p == 0).all(dim=-1).any()
        if bool(invalid):
            raise ValueError("Segment-type probabilities are invalid")

        log_gamma, converged = CosmoSolver._logspace_newton_solver(
            p, U_RT, max_iter=max_iter
        )
        ctx.save_for_backward(log_gamma, p, U_RT)
        # log_gamma is a column vector (..., n, 1); expose it as (..., n)
        return log_gamma.squeeze(-1), converged

    @staticmethod
    def backward(
        ctx: NestedIOFunction,
        grad_log_gamma: torch.Tensor | None,
        grad_converged: torch.Tensor | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, None]:
        """Implicit-function (adjoint) gradients of the fixed-point solution.

        With F(log_gamma; p, U_RT) = 0 at the solution, solves the adjoint
        system J^T v = dL/d(log_gamma) once and reuses v for both gradients.
        The boolean ``converged`` output carries no gradient.
        """
        if grad_log_gamma is None:
            return None, None, None

        log_gamma, p, U_RT = ctx.saved_tensors

        gamma = log_gamma.exp()
        B = torch.exp(-U_RT)
        t = p.sum(dim=-1, keepdim=True)

        # Rebuild log_A, log_A_gamma, and J exactly as in the forward solver
        log_A = p.log().unsqueeze(-2) - U_RT
        log_A_gamma = log_matmul_exp(log_A, log_gamma)
        Id = torch.eye(log_A.shape[-1], dtype=log_A.dtype, device=log_A.device)
        J = torch.exp(log_gamma.mT + log_A - log_A_gamma) + Id

        # Solve (dF/dlog_gamma)^T v = dL/dlog_gamma
        v = torch.linalg.solve(J.mT, grad_log_gamma.unsqueeze(-1))

        # r = v / (A @ gamma)
        r = v / log_A_gamma.exp()

        ctx_any: Any = ctx
        grad_p: torch.Tensor | None = None
        grad_U_RT: torch.Tensor | None = None

        if ctx.needs_input_grad[0]:
            # grad_p: -gamma * (B^T r) + (sum(v)/t)
            grad_p = -(gamma * (B.mT @ r)).squeeze(-1) + v.sum(dim=-2) / t
            # Reduce to the original shape if broadcasting happened
            grad_p = grad_p.sum_to_size(ctx_any.p_shape)

        if ctx.needs_input_grad[1]:
            # grad_U_RT: r_i * B_ij * (p_j * gamma_j)
            pg = p * gamma.squeeze(-1)
            grad_U_RT = r * B * pg.unsqueeze(-2)
            # Reduce to the original shape if broadcasting happened
            grad_U_RT = grad_U_RT.sum_to_size(ctx_any.u_shape)

        return grad_p, grad_U_RT, None