Coverage for cosmolayer / cosmosolver.py: 95%

65 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 14:25 +0000

1""" 

2.. module:: cosmolayer.cosmosolver 

3 :synopsis: Solves the self-consistent equation for the segment activity coefficients. 

4 

5.. moduleauthor:: Charlles Abreu <craabreu@gmail.com>

6""" 

7 

8from __future__ import annotations 

9 

10from typing import Any 

11 

12import torch 

13from torch.autograd.function import FunctionCtx, NestedIOFunction 

14 

15from .utils import log_matmul_exp 

16 

# Per-dtype convergence tolerances for the log-space Newton solver
# (CosmoSolver._logspace_newton_solver): the step tolerance bounds the largest
# Newton update magnitude and the residual tolerance bounds the fixed-point
# residual; both must be satisfied. float64 gets much tighter thresholds
# than float32, matching the precision each dtype can actually resolve.
NEWTON_STEP_TOLERANCE = {torch.float32: 1e-5, torch.float64: 1e-10}
NEWTON_RESIDUAL_TOLERANCE = {torch.float32: 1e-6, torch.float64: 1e-12}

19 

20 

class CosmoSolver(torch.autograd.Function):
    r"""COSMO self-consistent equation solver.

    Solves the COSMO self-consistent equations for the logarithm of the activity
    coefficient vector, :math:`\ln \boldsymbol{\gamma}`, given the nonnegative
    probability distribution vector :math:`\mathbf{p}` and the reduced interaction
    energy matrix :math:`\mathbf{U}/(RT)`.

    The self-consistent equations are:

    .. math::

        \boldsymbol{\gamma}\circ \left(
            \mathbf{B} ({\mathbf p} \circ \boldsymbol{\gamma})
        \right) = t \mathbf{1},

    where :math:`\mathbf{B} = \exp(-\mathbf{U}/(RT))` is the matrix of Boltzmann
    factors, :math:`t=\mathbf{1}^T \mathbf{p}` is the sum of the probabilities, and
    :math:`\circ` represents an elementwise product.

    The solution satisfies
    :math:`\boldsymbol{\gamma}^\mathsf{T} \mathbf{M} \boldsymbol{\gamma} = t`,
    where :math:`\mathbf{M} = \mathbf{B} \circ (\mathbf{p}\mathbf{p}^T)`.

    .. note::
        Supports batching: :math:`\mathbf{p}` and :math:`\mathbf{U}/(RT)` may
        have broadcastable leading dimensions, in which case all computations
        are performed in a single vectorized operation.

    Parameters
    ----------
    p : torch.Tensor
        Segment-type probability distribution vector. Must be nonnegative.
        Shape: (..., num_segment_types).
    U_RT : torch.Tensor
        Reduced interaction energy matrix :math:`\mathbf{U}/(RT)`.
        Shape: (..., num_segment_types, num_segment_types).
    max_iter : int
        Maximum number of iterations.

    Returns
    -------
    log_gamma : torch.Tensor
        The logarithm of the segment activity coefficient vector.
        Shape: (..., num_segment_types).
    converged : torch.Tensor
        Boolean tensor flagging, per batch element, whether the Newton solver
        satisfied both the step and residual tolerances within ``max_iter``
        iterations. Non-differentiable.

    Raises
    ------
    ValueError
        If ``max_iter`` is not positive, if any probability is negative, or if
        all probabilities of some batch element are zero.

    Examples
    --------
    >>> import torch
    >>> from cosmolayer.cosmosac import Component, CosmoSac2002Model
    >>> from importlib.resources import files
    >>> cosmo_strings = [
    ...     (files("cosmolayer.data") / f"{species}.cosmo").read_text()
    ...     for species in ("C=C(N)O", "NCCO")
    ... ]
    >>> probabilities = [
    ...     CosmoSac2002Model.create_component(cosmo_string).probabilities
    ...     for cosmo_string in cosmo_strings
    ... ]
    >>> p = torch.stack(
    ...     [torch.tensor(prob, dtype=torch.float32) for prob in probabilities],
    ... ).requires_grad_(True)
    >>> U_RT = torch.tensor(
    ...     CosmoSac2002Model.create_interaction_matrices(298.15)[0],
    ...     dtype=torch.float32,
    ...     requires_grad=True,
    ... )
    >>> log_gamma, converged = CosmoSolver.apply(p, U_RT)
    >>> converged.all().item()
    True
    >>> log_gamma
    tensor([[-4.5...e+00, -4.0...e+00, ... -1.3...e+01],
            [-2.1...e+01, -1.9...e+01, ... -5.3...e+00]], grad_fn=<CosmoSolverBackward>)
    >>> loss = (2 * log_gamma).exp().sum()
    >>> loss.backward()
    >>> p.grad
    tensor([[ 2.1...e+02,  2.1...e+02, ... -7.4...e+05],
            [-6.6...e+02, -6.3...e+02, ...  7.4...e+02]])
    """

    @staticmethod
    def _logspace_newton_solver(
        p: torch.Tensor,
        U_RT: torch.Tensor,
        max_iter: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Solve the fixed-point equations via Newton's method in log space.

        Iterating on ``log_gamma`` directly keeps every iterate strictly
        positive in gamma space and avoids overflow/underflow of the
        exponentials.

        Parameters
        ----------
        p : torch.Tensor
            Probability vector, shape (..., n). Assumed pre-validated by
            ``forward`` (nonnegative, not all-zero per batch element).
        U_RT : torch.Tensor
            Reduced interaction energy matrix, shape (..., n, n).
        max_iter : int
            Maximum number of Newton iterations; must be >= 1 so that
            ``converged`` is always assigned before the return.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``log_gamma`` with shape (..., n, 1) and a boolean ``converged``
            flag with the batch shape (...,).
        """
        step_tol = NEWTON_STEP_TOLERANCE[p.dtype]
        resid_tol = NEWTON_RESIDUAL_TOLERANCE[p.dtype]
        # The iteration is detached from autograd; gradients are obtained
        # analytically in `backward` via the implicit function theorem.
        with torch.no_grad():
            # log t with two trailing singleton dims -> broadcastable
            # against the (..., n, 1) column vectors used below.
            log_t = p.sum(dim=-1, keepdim=True).log().unsqueeze(-1)
            # log A_ij = log p_j - U_ij/(RT), i.e. A = B * (1 p^T).
            # Zero probabilities produce -inf entries, which logsumexp and
            # log_matmul_exp handle gracefully.
            log_A = p.log().unsqueeze(-2) - U_RT
            Id = torch.eye(log_A.shape[-1], dtype=log_A.dtype, device=log_A.device)
            # Initial guess chosen so that A @ gamma ~ sqrt(t) elementwise.
            log_gamma = -torch.logsumexp(log_A, dim=-1, keepdim=True) + 0.5 * log_t
            log_A_gamma = log_matmul_exp(log_A, log_gamma)
            # Residual of the log-space equations:
            # F(log_gamma) = log_gamma + log(A gamma) - log t.
            f = log_gamma + log_A_gamma - log_t
            for _ in range(max_iter):
                # Jacobian of F: J_ik = delta_ik + A_ik gamma_k / (A gamma)_i.
                J = torch.exp(log_gamma.mT + log_A - log_A_gamma) + Id
                delta = torch.linalg.solve(J, -f)
                log_gamma += delta
                log_A_gamma = log_matmul_exp(log_A, log_gamma)
                f = log_gamma + log_A_gamma - log_t
                # Converged when both the Newton step and the residual are
                # small (max-norm over the last two dims, per batch element).
                delta_norm = delta.abs().amax(dim=(-2, -1))
                f_norm = f.abs().amax(dim=(-2, -1))
                converged = (delta_norm < step_tol) & (f_norm < resid_tol)
                if bool(converged.all()):
                    break
        return log_gamma, converged

    @staticmethod
    def forward(
        ctx: FunctionCtx,
        p: torch.Tensor,
        U_RT: torch.Tensor,
        max_iter: int = 100,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Validate inputs, run the solver, and stash data for `backward`.

        Returns ``(log_gamma, converged)`` with ``log_gamma`` of shape
        (..., n) and ``converged`` a boolean batch-shaped flag. Raises
        ``ValueError`` on a nonpositive ``max_iter`` or invalid ``p``.
        """
        # Original (possibly broadcast-incompatible) shapes, needed in
        # `backward` to reduce broadcast gradients back via sum_to_size.
        ctx_any: Any = ctx
        ctx_any.p_shape = tuple(p.shape)
        ctx_any.u_shape = tuple(U_RT.shape)

        if max_iter <= 0:
            raise ValueError("Maximum number of iterations must be positive")

        # p must be nonnegative, and no batch element may be entirely zero
        # (its log-space system would be degenerate).
        invalid = (p < 0).any() | (p == 0).all(dim=-1).any()
        if bool(invalid):
            raise ValueError("Segment-type probabilities are invalid")

        log_gamma, converged = CosmoSolver._logspace_newton_solver(
            p, U_RT, max_iter=max_iter
        )
        ctx.save_for_backward(log_gamma, p, U_RT)
        # The boolean flag carries no gradient; mark it explicitly so autograd
        # never requests one for it.
        ctx.mark_non_differentiable(converged)

        return log_gamma.squeeze(-1), converged

    @staticmethod
    def backward(
        ctx: NestedIOFunction,
        grad_log_gamma: torch.Tensor | None,
        grad_converged: torch.Tensor | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, None]:
        """Implicit-function-theorem backward pass.

        Differentiates the converged solution of F(log_gamma; p, U/RT) = 0
        without backpropagating through the Newton iterations: a single
        adjoint solve with the transposed Jacobian yields both gradients.
        ``grad_converged`` is ignored (boolean, non-differentiable output).
        """
        if grad_log_gamma is None:
            return None, None, None

        log_gamma, p, U_RT = ctx.saved_tensors

        gamma = log_gamma.exp()
        B = torch.exp(-U_RT)  # Boltzmann factors

        t = p.sum(dim=-1, keepdim=True)

        # Rebuild log_A, log_A_gamma, and J (same as forward)
        log_A = p.log().unsqueeze(-2) - U_RT
        log_A_gamma = log_matmul_exp(log_A, log_gamma)
        Id = torch.eye(log_A.shape[-1], dtype=log_A.dtype, device=log_A.device)
        J = torch.exp(log_gamma.mT + log_A - log_A_gamma) + Id

        # Adjoint solve: (dF/dlog_gamma)^T v = dL/dlog_gamma
        v = torch.linalg.solve(J.mT, grad_log_gamma.unsqueeze(-1))

        # r = v / (A @ gamma), the factor shared by both gradients
        r = v / log_A_gamma.exp()

        # grad_p: -gamma * (B^T r) + sum(v)/t, from dF_i/dp_k
        grad_p = -(gamma * (B.mT @ r)).squeeze(-1) + v.sum(dim=-2) / t

        # grad_U_RT: r_i * B_ij * (p_j * gamma_j), from dF_i/dU_ij
        pg = p * gamma.squeeze(-1)
        grad_U_RT = r * B * pg.unsqueeze(-2)

        # Reduce to original shapes if broadcasting happened
        ctx_any: Any = ctx
        grad_p = grad_p.sum_to_size(ctx_any.p_shape)
        grad_U_RT = grad_U_RT.sum_to_size(ctx_any.u_shape)

        return grad_p, grad_U_RT, None