MixtureInferenceDataset
- class cosmolayer.MixtureInferenceDataset(mixtures, dtype)
Torch dataset wrapper for shape-compatible mixture datapoints in inference.
- Parameters:
mixtures (Sequence[MixtureDatapoint]) – Datapoints to expose through the dataset interface. All datapoints must share the same input shape (number of components and segment types).
dtype (torch.dtype) – Data type used when converting datapoints to tensors.
- Raises:
ValueError – If mixtures is empty or contains incompatible input shapes.
Examples
>>> import numpy as np
>>> import torch
>>> from cosmolayer.cosmodata import MixtureInferenceDataset, MixtureDatapoint
>>> dp = MixtureDatapoint(
...     temperature=298.15,
...     mole_fractions=np.array([0.5, 0.5]),
...     areas=np.array([1.0, 2.0]),
...     volumes=np.array([1.0, 2.0]),
...     probabilities=np.array([[0.5, 0.5], [0.4, 0.6]]),
...     targets=np.array([]),
... )
>>> dataset = MixtureInferenceDataset([dp], dtype=torch.float32)
>>> inputs = dataset[0]
>>> len(inputs)
5
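
The shape requirement on mixtures can be pictured with a small standalone sketch. It is not the library's implementation: check_shared_shape is a hypothetical helper, and it assumes each datapoint's input shape can be summarized as a (number of components, number of segment types) tuple. It only mirrors the two documented ValueError conditions, an empty sequence and mismatched shapes.

```python
from typing import Sequence, Tuple

# Assumed summary of a datapoint's input shape:
# (number of components, number of segment types).
Shape = Tuple[int, int]


def check_shared_shape(shapes: Sequence[Shape]) -> Shape:
    """Return the common shape or raise ValueError (hypothetical helper)."""
    if not shapes:
        raise ValueError("mixtures is empty")
    first = tuple(shapes[0])
    # Compare every later shape against the first one.
    for index, shape in enumerate(shapes[1:], start=1):
        if tuple(shape) != first:
            raise ValueError(
                f"datapoint {index} has shape {tuple(shape)}, expected {first}"
            )
    return first


# Two binary mixtures with two segment types each are compatible.
print(check_shared_shape([(2, 2), (2, 2)]))  # prints (2, 2)
```

Under the same assumption, passing [(2, 2), (3, 2)] raises ValueError because the second mixture has three components, so it could not be stacked with the first.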