MixtureTrainingDataset

class cosmolayer.MixtureTrainingDataset(mixtures, dtype=torch.float64)

Torch dataset wrapper for shape-compatible mixture datapoints.

Parameters:
  • mixtures (Sequence[MixtureDatapoint]) – Datapoints to expose through the dataset interface. All datapoints must share the same structural shape.

  • dtype (torch.dtype) – Data type used when converting datapoints to tensors.

Raises:

ValueError – If mixtures is empty or contains datapoints whose shapes are incompatible with one another.

Examples

>>> import torch
>>> from cosmolayer.cosmodata import MixtureTrainingDataset, MixtureDatapoint
>>> from cosmolayer.cosmosac import CosmoSac2002Model
>>> from cosmolayer.cosmosac.datapoint import CosmoSacMixtureDatapoint
>>> from importlib.resources import files
>>> data = files("cosmolayer.data")
>>> cosmo_files = [data / "C=C(N)O.cosmo", data / "NCCO.cosmo"]
>>> mole_fractions = [0.5, 0.5]
>>> temperature = 298.15
>>> targets = [1.2]
>>> dp = CosmoSacMixtureDatapoint(
...     cosmo_files,
...     mole_fractions,
...     temperature,
...     targets,
...     CosmoSac2002Model,
... )
>>> dataset = MixtureTrainingDataset([dp], dtype=torch.float32)
>>> len(dataset)
1
>>> inputs, targets = dataset[0]
>>> len(inputs)
5
>>> len(targets)
1
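
The contract described above (reject an empty or shape-inconsistent sequence, then serve `(inputs, targets)` pairs by index) can be illustrated with a small pure-Python sketch. This is not cosmolayer's implementation: `ToyDatapoint`, `ToyTrainingDataset`, and the definition of "shape" as the number of inputs and targets are all hypothetical stand-ins chosen for illustration.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ToyDatapoint:
    """Hypothetical stand-in for a mixture datapoint (not a cosmolayer class)."""

    inputs: tuple   # e.g. mole fractions and temperature
    targets: tuple  # e.g. measured property values

    def shape(self) -> tuple:
        # Assumed notion of "structural shape": counts of inputs and targets.
        return (len(self.inputs), len(self.targets))


class ToyTrainingDataset:
    """Map-style dataset sketch: validate shapes once, then index datapoints."""

    def __init__(self, mixtures: Sequence[ToyDatapoint]):
        if not mixtures:
            raise ValueError("mixtures must not be empty")
        reference = mixtures[0].shape()
        for index, dp in enumerate(mixtures):
            if dp.shape() != reference:
                raise ValueError(
                    f"datapoint {index} has shape {dp.shape()}, "
                    f"expected {reference}"
                )
        self._mixtures = list(mixtures)

    def __len__(self) -> int:
        return len(self._mixtures)

    def __getitem__(self, index: int):
        dp = self._mixtures[index]
        return dp.inputs, dp.targets
```

Validating in the constructor means a shape mismatch surfaces when the dataset is built rather than partway through a training loop, which is presumably why the real class raises ValueError at construction time.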

Methods

to_inference_dataset()

Convert the training dataset to an inference dataset.

Returns:

An inference dataset with the same mixtures and dtype.

Return type:

MixtureInferenceDataset