Coverage for cosmolayer / cosmodata.py: 100%

72 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 14:25 +0000

1""" 

2.. module:: cosmolayer.cosmodata 

3 :synopsis: Data tensors for COSMO-SAC calculations. 

4 

5.. moduleauthor:: Charlles Abreu <craabreu@gmail.com> 

6""" 

7 

8from collections.abc import Sequence 

9from dataclasses import dataclass, field 

10from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar 

11 

12import numpy as np 

13import torch 

14 

15NumpyArray1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.number]] 

16NumpyArray2D: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.number]] 

17 

18Tensor0D: TypeAlias = torch.Tensor 

19Tensor1D: TypeAlias = torch.Tensor 

20Tensor2D: TypeAlias = torch.Tensor 

21 

22InputsType: TypeAlias = tuple[Tensor0D, Tensor1D, Tensor1D, Tensor1D, Tensor2D] 

23 

if TYPE_CHECKING:
    # Seen by static type checkers only; never executed at runtime.
    _DatasetItemT = TypeVar("_DatasetItemT")

    class _DatasetBase(Generic[_DatasetItemT]):
        """Generic stand-in base used for type checking only.

        Shields mypy results from differences between torch stub versions
        in CI.
        """

else:
    # At runtime the dataset classes derive from the real torch base class.
    from torch.utils.data import Dataset as _DatasetBase

32 

33 

@dataclass
class MixtureDatapoint:
    """Base dataclass for a mixture datapoint.

    Stores physical inputs (temperature, mole fractions, areas, volumes, and
    segment-type probabilities) and optional training targets. Shape metadata
    is computed and validated automatically on construction.

    Parameters
    ----------
    temperature : float
        Temperature.
    mole_fractions : NumpyArray1D
        Mole fractions.
        Shape: ``(num_components,)``.
    areas : NumpyArray1D
        Segment surface areas per component.
        Shape: ``(num_components,)``.
    volumes : NumpyArray1D
        Molar volumes per component.
        Shape: ``(num_components,)``.
    probabilities : NumpyArray2D
        Sigma-profile probabilities.
        Shape: ``(num_components, num_segment_types)``.
    targets : NumpyArray1D
        Training targets.
        Shape: ``(num_targets,)``.

    Attributes
    ----------
    num_components : int
        Number of components.
    num_segment_types : int
        Number of segment-type probabilities.
    num_targets : int
        Number of training targets.

    Raises
    ------
    ValueError
        If array shapes are inconsistent.

    Notes
    -----
    The arrays are stored by reference and marked read-only in place, so the
    caller's own array objects become non-writeable after construction.
    """

    temperature: float
    mole_fractions: NumpyArray1D = field(repr=False)
    areas: NumpyArray1D = field(repr=False)
    volumes: NumpyArray1D = field(repr=False)
    probabilities: NumpyArray2D = field(repr=False)
    targets: NumpyArray1D = field(repr=False)
    num_components: int = field(init=False)
    num_segment_types: int = field(init=False)
    num_targets: int = field(init=False)

    def __post_init__(self) -> None:
        """Validate array shapes, derive size metadata, and freeze the arrays.

        Explicit checks are used instead of ``assert`` statements so that
        validation still runs when Python is started with ``-O``.

        Raises
        ------
        ValueError
            If any stored array has an incompatible shape.
        """
        if self.probabilities.ndim != 2:  # noqa: PLR2004
            raise ValueError(
                "Invalid array shapes: probabilities must be 2-D, "
                f"got {self.probabilities.ndim}-D"
            )
        self.num_components, self.num_segment_types = self.probabilities.shape

        # Every per-component array must match the first axis of probabilities.
        per_component = (self.num_components,)
        for name in ("mole_fractions", "areas", "volumes"):
            shape = getattr(self, name).shape
            if shape != per_component:
                raise ValueError(
                    f"Invalid array shapes: {name} has shape {shape}, "
                    f"expected {per_component}"
                )

        if self.targets.ndim != 1:
            raise ValueError(
                "Invalid array shapes: targets must be 1-D, "
                f"got {self.targets.ndim}-D"
            )
        self.num_targets = len(self.targets)

        # Freeze the stored arrays so validated shapes and values cannot be
        # changed through this datapoint afterwards.
        for array in (
            self.mole_fractions,
            self.areas,
            self.volumes,
            self.probabilities,
            self.targets,
        ):
            array.flags.writeable = False

    @property
    def shape(self) -> tuple[int, int, int]:
        """Return the structural shape metadata for the datapoint.

        Returns
        -------
        tuple[int, int, int]
            Tuple containing the number of components, segment types,
            and number of training targets.
        """
        return (
            self.num_components,
            self.num_segment_types,
            self.num_targets,
        )

    def get_inputs(self, dtype: torch.dtype = torch.float64) -> InputsType:
        """Convert physical inputs to torch tensors.

        Parameters
        ----------
        dtype : torch.dtype
            Data type used for all returned tensors. Default is ``torch.float64``.

        Returns
        -------
        InputsType
            Temperature, mole fractions, areas, volumes, and probabilities
            as torch tensors.
        """
        return (
            torch.tensor(self.temperature, dtype=dtype),
            torch.tensor(self.mole_fractions, dtype=dtype),
            torch.tensor(self.areas, dtype=dtype),
            torch.tensor(self.volumes, dtype=dtype),
            torch.tensor(self.probabilities, dtype=dtype),
        )

    def get_targets(self, dtype: torch.dtype = torch.float64) -> Tensor1D:
        """Convert the target array to a torch tensor.

        Parameters
        ----------
        dtype : torch.dtype
            Data type of the returned tensor. Default is ``torch.float64``.

        Returns
        -------
        Tensor1D
            Training targets as a torch tensor.
        """
        return torch.tensor(self.targets, dtype=dtype)

166 

167 

class MixtureInferenceDataset(_DatasetBase[InputsType]):
    """Torch dataset wrapper for shape-compatible mixture datapoints in inference.

    Parameters
    ----------
    mixtures : Sequence[MixtureDatapoint]
        Datapoints to expose through the dataset interface. All datapoints
        must share the same input shape (number of components and segment types).
        The sequence is stored by reference, not copied.
    dtype : torch.dtype
        Data type used when converting datapoints to tensors. Default is
        ``torch.float64``, matching :class:`MixtureTrainingDataset`.

    Raises
    ------
    ValueError
        If ``mixtures`` is empty or contains incompatible input shapes.

    Examples
    --------
    >>> import numpy as np
    >>> import torch
    >>> from cosmolayer.cosmodata import MixtureInferenceDataset, MixtureDatapoint
    >>> dp = MixtureDatapoint(
    ...     temperature=298.15,
    ...     mole_fractions=np.array([0.5, 0.5]),
    ...     areas=np.array([1.0, 2.0]),
    ...     volumes=np.array([1.0, 2.0]),
    ...     probabilities=np.array([[0.5, 0.5], [0.4, 0.6]]),
    ...     targets=np.array([]),
    ... )
    >>> dataset = MixtureInferenceDataset([dp], dtype=torch.float32)
    >>> inputs = dataset[0]
    >>> len(inputs)
    5
    """

    def __init__(
        self,
        mixtures: Sequence[MixtureDatapoint],
        dtype: torch.dtype = torch.float64,
    ) -> None:
        if len(mixtures) == 0:
            raise ValueError(
                "MixtureInferenceDataset must contain at least one mixture"
            )
        # Only the input part of the shape matters here: the number of targets
        # (third entry) may differ between datapoints used for inference.
        input_shape = mixtures[0].shape[:2]
        # Iterate rather than slice: slicing copies the sequence and is not
        # supported by every Sequence implementation.
        if any(mixture.shape[:2] != input_shape for mixture in mixtures):
            raise ValueError("All mixtures must have the same input shape")
        self._mixtures = mixtures
        self._dtype = dtype

    def __len__(self) -> int:
        """Return the number of datapoints in the dataset."""
        return len(self._mixtures)

    def __getitem__(self, index: int) -> InputsType:
        """Return one datapoint as input tensors.

        Parameters
        ----------
        index : int
            Position of the datapoint in the dataset.

        Returns
        -------
        InputsType
            Input tensors for the selected datapoint.
        """
        return self._mixtures[index].get_inputs(self._dtype)

234 

235 

class MixtureTrainingDataset(_DatasetBase[tuple[InputsType, Tensor1D]]):
    """Torch dataset wrapper for shape-compatible mixture datapoints.

    Parameters
    ----------
    mixtures : Sequence[MixtureDatapoint]
        Datapoints to expose through the dataset interface. All datapoints
        must share the same structural shape (number of components, segment
        types, and targets). The sequence is stored by reference, not copied.
    dtype : torch.dtype
        Data type used when converting datapoints to tensors. Default is
        ``torch.float64``.

    Raises
    ------
    ValueError
        If ``mixtures`` is empty or contains incompatible datapoint shapes.

    Examples
    --------
    >>> import torch
    >>> from cosmolayer.cosmodata import MixtureTrainingDataset, MixtureDatapoint
    >>> from cosmolayer.cosmosac import CosmoSac2002Model
    >>> from cosmolayer.cosmosac.datapoint import CosmoSacMixtureDatapoint
    >>> from importlib.resources import files
    >>> data = files("cosmolayer.data")
    >>> cosmo_files = [data / "C=C(N)O.cosmo", data / "NCCO.cosmo"]
    >>> mole_fractions = [0.5, 0.5]
    >>> temperature = 298.15
    >>> targets = [1.2]
    >>> dp = CosmoSacMixtureDatapoint(
    ...     cosmo_files,
    ...     mole_fractions,
    ...     temperature,
    ...     targets,
    ...     CosmoSac2002Model,
    ... )
    >>> dataset = MixtureTrainingDataset([dp], dtype=torch.float32)
    >>> len(dataset)
    1
    >>> inputs, targets = dataset[0]
    >>> len(inputs)
    5
    >>> len(targets)
    1
    """

    def __init__(
        self,
        mixtures: Sequence[MixtureDatapoint],
        dtype: torch.dtype = torch.float64,
    ) -> None:
        if len(mixtures) == 0:
            raise ValueError("MixtureTrainingDataset must contain at least one mixture")
        shape = mixtures[0].shape
        # Iterate rather than slice: slicing copies the sequence and is not
        # supported by every Sequence implementation.
        if any(mixture.shape != shape for mixture in mixtures):
            raise ValueError("All mixtures must have the same shape")
        self._mixtures = mixtures
        self._dtype = dtype

    def __len__(self) -> int:
        """Return the number of datapoints in the dataset."""
        return len(self._mixtures)

    def __getitem__(self, index: int) -> tuple[InputsType, Tensor1D]:
        """Return one datapoint as a pair of input tensors and target tensor.

        Parameters
        ----------
        index : int
            Position of the datapoint in the dataset.

        Returns
        -------
        tuple[InputsType, Tensor1D]
            Input tensors and target tensor for the selected datapoint.
        """
        mixture = self._mixtures[index]
        return mixture.get_inputs(self._dtype), mixture.get_targets(self._dtype)

    def to_inference_dataset(self) -> MixtureInferenceDataset:
        """Convert the training dataset to an inference dataset.

        Returns
        -------
        MixtureInferenceDataset
            An inference dataset with the same mixtures and dtype.
        """
        return MixtureInferenceDataset(self._mixtures, self._dtype)