Coverage for cosmolayer / cosmodata.py: 100%
72 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 14:25 +0000
1"""
2.. module:: cosmolayer.cosmodata
3 :synopsis: Data tensors for COSMO-SAC calculations.
5.. moduleauthor:: Charlles Abreu <craabreu@gmail.com>
6"""
8from collections.abc import Sequence
9from dataclasses import dataclass, field
10from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar
12import numpy as np
13import torch
# Static aliases for the NumPy arrays stored on a datapoint: a 1-D and a 2-D
# array of any numeric dtype.
NumpyArray1D: TypeAlias = np.ndarray[tuple[int], np.dtype[np.number]]
NumpyArray2D: TypeAlias = np.ndarray[tuple[int, int], np.dtype[np.number]]

# All three aliases resolve to plain ``torch.Tensor``; the names only record
# the intended rank for readers and are not enforced by the type checker.
Tensor0D: TypeAlias = torch.Tensor
Tensor1D: TypeAlias = torch.Tensor
Tensor2D: TypeAlias = torch.Tensor

# Return type of ``MixtureDatapoint.get_inputs``: temperature, mole fractions,
# areas, volumes, and probabilities, in that order.
InputsType: TypeAlias = tuple[Tensor0D, Tensor1D, Tensor1D, Tensor1D, Tensor2D]

if TYPE_CHECKING:
    # Type-checking only: a minimal generic stand-in for the torch dataset
    # class, so that subscripting the base class type-checks.
    _DatasetItemT = TypeVar("_DatasetItemT")

    class _DatasetBase(Generic[_DatasetItemT]):
        """Mypy-only base to avoid torch stub variability in CI."""

else:
    # At runtime the dataset classes derive from the real torch Dataset.
    _DatasetBase = torch.utils.data.Dataset
@dataclass
class MixtureDatapoint:
    """Base dataclass for a mixture datapoint.

    Stores physical inputs (temperature, mole fractions, areas, volumes, and
    segment-type probabilities) and optional training targets. Shape metadata
    is computed and validated automatically on construction, and the stored
    arrays are made read-only.

    Parameters
    ----------
    temperature : float
        Temperature.
    mole_fractions : NumpyArray1D
        Mole fractions.
        Shape: ``(num_components,)``.
    areas : NumpyArray1D
        Segment surface areas per component.
        Shape: ``(num_components,)``.
    volumes : NumpyArray1D
        Molar volumes per component.
        Shape: ``(num_components,)``.
    probabilities : NumpyArray2D
        Sigma-profile probabilities.
        Shape: ``(num_components, num_segment_types)``.
    targets : NumpyArray1D
        Training targets. May be empty (e.g. for inference-only datapoints).
        Shape: ``(num_targets,)``.

    Attributes
    ----------
    num_components : int
        Number of components (first dimension of ``probabilities``).
    num_segment_types : int
        Number of segment types (second dimension of ``probabilities``).
    num_targets : int
        Number of training targets (length of ``targets``).

    Raises
    ------
    ValueError
        If array shapes are inconsistent.

    Notes
    -----
    The arrays are stored by reference and their ``writeable`` flag is cleared
    in place, so the objects passed by the caller become read-only as well.
    """

    temperature: float
    mole_fractions: NumpyArray1D = field(repr=False)
    areas: NumpyArray1D = field(repr=False)
    volumes: NumpyArray1D = field(repr=False)
    probabilities: NumpyArray2D = field(repr=False)
    targets: NumpyArray1D = field(repr=False)
    num_components: int = field(init=False)
    num_segment_types: int = field(init=False)
    num_targets: int = field(init=False)

    def __post_init__(self) -> None:
        """Validate array shapes and freeze stored numpy arrays.

        Validation uses explicit checks rather than ``assert`` so that it is
        still enforced when Python runs with optimizations enabled (``-O``),
        which strips ``assert`` statements.

        Raises
        ------
        ValueError
            If any stored array has an incompatible shape.
        """
        if self.probabilities.ndim != 2:  # noqa: PLR2004
            raise ValueError("Invalid array shapes")
        self.num_components, self.num_segment_types = self.probabilities.shape

        # Every per-component array must be 1-D with one entry per component.
        per_component = (self.num_components,)
        shapes_ok = (
            self.mole_fractions.shape == per_component
            and self.areas.shape == per_component
            and self.volumes.shape == per_component
            and self.targets.ndim == 1
        )
        if not shapes_ok:
            raise ValueError("Invalid array shapes")

        self.num_targets = len(self.targets)

        # Clear the writeable flag in place; this also affects the caller's
        # references to the same array objects.
        for array in (
            self.mole_fractions,
            self.areas,
            self.volumes,
            self.probabilities,
            self.targets,
        ):
            array.flags.writeable = False

    @property
    def shape(self) -> tuple[int, int, int]:
        """Return the structural shape metadata for the datapoint.

        Returns
        -------
        tuple[int, int, int]
            Tuple containing the number of components, segment types,
            and number of training targets.
        """
        return (
            self.num_components,
            self.num_segment_types,
            self.num_targets,
        )

    def get_inputs(self, dtype: torch.dtype = torch.float64) -> InputsType:
        """Convert physical inputs to torch tensors.

        Each call builds new tensors with ``torch.tensor``, which copies the
        data, so the returned tensors do not alias the stored numpy arrays.

        Parameters
        ----------
        dtype : torch.dtype
            Data type used for all returned tensors. Default is ``torch.float64``.

        Returns
        -------
        InputsType
            Temperature, mole fractions, areas, volumes, and probabilities
            as torch tensors, in that order.
        """
        return (
            torch.tensor(self.temperature, dtype=dtype),
            torch.tensor(self.mole_fractions, dtype=dtype),
            torch.tensor(self.areas, dtype=dtype),
            torch.tensor(self.volumes, dtype=dtype),
            torch.tensor(self.probabilities, dtype=dtype),
        )

    def get_targets(self, dtype: torch.dtype = torch.float64) -> Tensor1D:
        """Convert the target array to a torch tensor.

        Parameters
        ----------
        dtype : torch.dtype
            Data type used for the returned tensor. Default is ``torch.float64``.

        Returns
        -------
        Tensor1D
            Training targets as a new torch tensor.
        """
        return torch.tensor(self.targets, dtype=dtype)
class MixtureInferenceDataset(_DatasetBase[InputsType]):
    """Torch dataset wrapper for shape-compatible mixture datapoints in inference.

    Only the inputs of each datapoint are exposed; targets are ignored, so
    datapoints may differ in their number of targets.

    Parameters
    ----------
    mixtures : Sequence[MixtureDatapoint]
        Datapoints to expose through the dataset interface. All datapoints
        must share the same input shape (number of components and segment types).
        The sequence is stored by reference, not copied.
    dtype : torch.dtype
        Data type used when converting datapoints to tensors. Default is
        ``torch.float64``, matching ``MixtureTrainingDataset``.

    Raises
    ------
    ValueError
        If ``mixtures`` is empty or contains incompatible input shapes.

    Examples
    --------
    >>> from cosmolayer.cosmodata import MixtureInferenceDataset, MixtureDatapoint
    >>> dp = MixtureDatapoint(
    ...     temperature=298.15,
    ...     mole_fractions=np.array([0.5, 0.5]),
    ...     areas=np.array([1.0, 2.0]),
    ...     volumes=np.array([1.0, 2.0]),
    ...     probabilities=np.array([[0.5, 0.5], [0.4, 0.6]]),
    ...     targets=np.array([]),
    ... )
    >>> dataset = MixtureInferenceDataset([dp], dtype=torch.float32)
    >>> inputs = dataset[0]
    >>> len(inputs)
    5
    """

    def __init__(
        self,
        mixtures: Sequence[MixtureDatapoint],
        dtype: torch.dtype = torch.float64,
    ) -> None:
        if len(mixtures) == 0:
            raise ValueError(
                "MixtureInferenceDataset must contain at least one mixture"
            )
        # Compare only (num_components, num_segment_types); the third entry of
        # ``shape`` is the number of targets, which is irrelevant for inference.
        input_shape = mixtures[0].shape[:2]
        if any(mixture.shape[:2] != input_shape for mixture in mixtures[1:]):
            raise ValueError("All mixtures must have the same input shape")
        self._mixtures = mixtures
        self._dtype = dtype

    def __len__(self) -> int:
        """Return the number of datapoints in the dataset."""
        return len(self._mixtures)

    def __getitem__(self, index: int) -> InputsType:
        """Return one datapoint as input tensors.

        Parameters
        ----------
        index : int
            Position of the datapoint in the dataset.

        Returns
        -------
        InputsType
            Input tensors for the selected datapoint, converted to the
            dataset's dtype.
        """
        return self._mixtures[index].get_inputs(self._dtype)
class MixtureTrainingDataset(_DatasetBase[tuple[InputsType, Tensor1D]]):
    """Torch dataset of mixture datapoints that yields (inputs, targets) pairs.

    Parameters
    ----------
    mixtures : Sequence[MixtureDatapoint]
        Datapoints served by the dataset. Every datapoint must report the
        same ``shape`` (components, segment types, and targets).
    dtype : torch.dtype
        Data type of the tensors produced for each datapoint.

    Raises
    ------
    ValueError
        If ``mixtures`` is empty or its datapoints do not all share one shape.


    Examples
    --------
    >>> from cosmolayer.cosmodata import MixtureTrainingDataset, MixtureDatapoint
    >>> from cosmolayer.cosmosac import CosmoSac2002Model
    >>> from cosmolayer.cosmosac.datapoint import CosmoSacMixtureDatapoint
    >>> from importlib.resources import files
    >>> data = files("cosmolayer.data")
    >>> cosmo_files = [data / "C=C(N)O.cosmo", data / "NCCO.cosmo"]
    >>> mole_fractions = [0.5, 0.5]
    >>> temperature = 298.15
    >>> targets = [1.2]
    >>> dp = CosmoSacMixtureDatapoint(
    ...     cosmo_files,
    ...     mole_fractions,
    ...     temperature,
    ...     targets,
    ...     CosmoSac2002Model,
    ... )
    >>> dataset = MixtureTrainingDataset([dp], dtype=torch.float32)
    >>> len(dataset)
    1
    >>> inputs, targets = dataset[0]
    >>> len(inputs)
    5
    >>> len(targets)
    1
    """

    def __init__(
        self,
        mixtures: Sequence[MixtureDatapoint],
        dtype: torch.dtype = torch.float64,
    ):
        if len(mixtures) < 1:
            raise ValueError("MixtureTrainingDataset must contain at least one mixture")
        # The first datapoint sets the reference; all others must match it.
        reference_shape = mixtures[0].shape
        for candidate in mixtures[1:]:
            if candidate.shape != reference_shape:
                raise ValueError("All mixtures must have the same shape")
        self._mixtures = mixtures
        self._dtype = dtype

    def __len__(self) -> int:
        """Return how many datapoints the dataset holds."""
        return len(self._mixtures)

    def __getitem__(self, index: int) -> tuple[InputsType, Tensor1D]:
        """Fetch one datapoint as a pair of input tensors and target tensor.

        Parameters
        ----------
        index : int
            Position of the datapoint in the dataset.

        Returns
        -------
        tuple[InputsType, Tensor1D]
            The datapoint's input tensors followed by its target tensor,
            both converted to the dataset's dtype.
        """
        datapoint = self._mixtures[index]
        inputs = datapoint.get_inputs(self._dtype)
        targets = datapoint.get_targets(self._dtype)
        return inputs, targets

    def to_inference_dataset(self) -> MixtureInferenceDataset:
        """Build an inference dataset over the same datapoints.

        Returns
        -------
        MixtureInferenceDataset
            An inference dataset with the same mixtures and dtype.
        """
        return MixtureInferenceDataset(self._mixtures, self._dtype)