MixtureInferenceDataset

class cosmolayer.MixtureInferenceDataset(mixtures, dtype)

PyTorch dataset wrapper that exposes shape-compatible mixture datapoints for inference.

Parameters:
  • mixtures (Sequence[MixtureDatapoint]) – Datapoints to expose through the dataset interface. All datapoints must share the same input shape (number of components and segment types).

  • dtype (torch.dtype) – Data type used when converting datapoints to tensors.

Raises:
  • ValueError – If mixtures is empty or its datapoints do not all share the same input shape.
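The ValueError condition amounts to a shape check over all datapoints. The following is a minimal sketch of such a check, not cosmolayer's implementation. It assumes a datapoint's input shape is (number of components, number of segment types), read from the length of `mole_fractions` and the columns of a 2-D `probabilities` array. The helpers `input_shape` and `check_shape_compatible` are hypothetical and operate on plain `(mole_fractions, probabilities)` pairs rather than MixtureDatapoint objects.

```python
import numpy as np


def input_shape(mole_fractions, probabilities):
    """Return (n_components, n_segment_types) for one datapoint.

    Hypothetical helper. Assumption: `probabilities` has one row per
    component and one column per segment type.
    """
    probabilities = np.asarray(probabilities)
    n_components = len(mole_fractions)
    if probabilities.ndim != 2 or probabilities.shape[0] != n_components:
        raise ValueError("probabilities must have shape (n_components, n_segment_types)")
    return n_components, probabilities.shape[1]


def check_shape_compatible(datapoints):
    """Return the shared input shape, or raise ValueError.

    Hypothetical sketch: raises if `datapoints` is empty or if the
    (mole_fractions, probabilities) pairs do not all share one input shape.
    """
    if len(datapoints) == 0:
        raise ValueError("mixtures must not be empty")
    # Collect the distinct input shapes; a compatible set has exactly one.
    shapes = {input_shape(x, p) for x, p in datapoints}
    if len(shapes) > 1:
        raise ValueError(f"incompatible input shapes: {sorted(shapes)}")
    return shapes.pop()


# Usage: two binary mixtures share a shape; a ternary mixture does not.
binary = (np.array([0.5, 0.5]), np.array([[0.5, 0.5], [0.4, 0.6]]))
ternary = (np.array([0.2, 0.3, 0.5]), np.ones((3, 4)) / 4)
print(check_shape_compatible([binary, binary]))  # (2, 2)
try:
    check_shape_compatible([binary, ternary])
except ValueError as err:
    print("rejected:", err)
```

Checking every datapoint up front lets a batch be stacked into fixed-size tensors later without per-item shape handling.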

Examples

>>> import numpy as np
>>> import torch
>>> from cosmolayer.cosmodata import MixtureInferenceDataset, MixtureDatapoint
>>> dp = MixtureDatapoint(
...     temperature=298.15,
...     mole_fractions=np.array([0.5, 0.5]),
...     areas=np.array([1.0, 2.0]),
...     volumes=np.array([1.0, 2.0]),
...     probabilities=np.array([[0.5, 0.5], [0.4, 0.6]]),
...     targets=np.array([]),
... )
>>> dataset = MixtureInferenceDataset([dp], dtype=torch.float32)
>>> inputs = dataset[0]
>>> len(inputs)
5