MixtureTrainingDataset
- class cosmolayer.MixtureTrainingDataset(mixtures, dtype=torch.float64)
Torch dataset wrapper for shape-compatible mixture datapoints.
- Parameters:
mixtures (Sequence[MixtureDatapoint]) – Datapoints to expose through the dataset interface. All datapoints must share the same structural shape.
dtype (torch.dtype) – Data type used when converting datapoints to tensors.
- Raises:
ValueError – If mixtures is empty or contains datapoints with incompatible shapes.
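The validation described under Raises can be pictured with a small stand-alone sketch. It is not cosmolayer's implementation. `ToyDatapoint`, `_shape` and `ToyMixtureDataset` are hypothetical names, and "structural shape" is assumed here to mean the tuple of tensor shapes of a datapoint's inputs and targets. The sketch rejects an empty sequence, compares every datapoint's shape against the first one, and converts to tensors of the requested `dtype` only when an item is accessed.

```python
from dataclasses import dataclass
from typing import Sequence

import torch
from torch.utils.data import Dataset


@dataclass
class ToyDatapoint:
    """Hypothetical stand-in for a mixture datapoint: raw inputs plus targets."""

    inputs: tuple  # each entry is a scalar or nested list accepted by torch.tensor
    targets: list


def _shape(dp: ToyDatapoint) -> tuple:
    # Assumed notion of "structural shape": the shape of every input plus the target shape.
    return tuple(torch.tensor(x).shape for x in dp.inputs) + (torch.tensor(dp.targets).shape,)


class ToyMixtureDataset(Dataset):
    """Sketch of a shape-validating dataset wrapper (hypothetical, for illustration)."""

    def __init__(self, mixtures: Sequence[ToyDatapoint], dtype: torch.dtype = torch.float64):
        if len(mixtures) == 0:
            raise ValueError("mixtures must not be empty")
        reference = _shape(mixtures[0])
        for i, dp in enumerate(mixtures[1:], start=1):
            shape = _shape(dp)
            if shape != reference:
                raise ValueError(f"datapoint {i} has shape {shape}, expected {reference}")
        self._mixtures = list(mixtures)
        self._dtype = dtype

    def __len__(self) -> int:
        return len(self._mixtures)

    def __getitem__(self, index: int):
        dp = self._mixtures[index]
        # Convert lazily, using the dtype chosen at construction time.
        inputs = tuple(torch.tensor(x, dtype=self._dtype) for x in dp.inputs)
        targets = torch.tensor(dp.targets, dtype=self._dtype)
        return inputs, targets
```

Checking shapes once, up front, means a mismatch surfaces at construction time. Without that check it would only fail later, in the middle of batching.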
Examples
>>> import torch
>>> from cosmolayer.cosmodata import MixtureTrainingDataset, MixtureDatapoint
>>> from cosmolayer.cosmosac import CosmoSac2002Model
>>> from cosmolayer.cosmosac.datapoint import CosmoSacMixtureDatapoint
>>> from importlib.resources import files
>>> data = files("cosmolayer.data")
>>> cosmo_files = [data / "C=C(N)O.cosmo", data / "NCCO.cosmo"]
>>> mole_fractions = [0.5, 0.5]
>>> temperature = 298.15
>>> targets = [1.2]
>>> dp = CosmoSacMixtureDatapoint(
...     cosmo_files,
...     mole_fractions,
...     temperature,
...     targets,
...     CosmoSac2002Model,
... )
>>> dataset = MixtureTrainingDataset([dp], dtype=torch.float32)
>>> len(dataset)
1
>>> inputs, targets = dataset[0]
>>> len(inputs)
5
>>> len(targets)
1
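Each index yields an (inputs, targets) pair, so the dataset can presumably be handed to a standard `torch.utils.data.DataLoader`. That assumes the entries of inputs are tensors, which the documented tensor conversion suggests but the page does not state outright. The shape requirement matters because PyTorch's default collation stacks each input position across samples with `torch.stack`, and stacking only works when the shapes agree. The sketch below uses a hypothetical `ToySamples` dataset and `make_sample` helper in place of real COSMO-SAC datapoints.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToySamples(Dataset):
    """Hypothetical dataset yielding ((mole_fractions, temperature), targets) pairs."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]


def make_sample(x1, temperature, target):
    # Binary mixture: mole fractions (x1, 1 - x1), a scalar temperature, one target value.
    inputs = (torch.tensor([x1, 1.0 - x1]), torch.tensor(temperature))
    return inputs, torch.tensor([target])


ds = ToySamples([
    make_sample(0.5, 298.15, 1.2),
    make_sample(0.3, 310.0, 0.9),
    make_sample(0.1, 320.0, 1.1),
])
loader = DataLoader(ds, batch_size=2, shuffle=False)
for batch_inputs, batch_targets in loader:
    # default_collate stacks each input position separately along a new batch axis.
    print([t.shape for t in batch_inputs], batch_targets.shape)
```

With `batch_size=2`, the first batch stacks the mole-fraction tensors to shape `(2, 2)`, the temperatures to `(2,)` and the targets to `(2, 1)`. If one sample had three mole fractions instead of two, `torch.stack` would raise a `RuntimeError` at batch time. Validating shapes when the dataset is constructed is meant to rule out that late failure.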
Methods