Datasets

GraphDatasets are collections of AtomicGraphs. We provide a base class, GraphDataset, together with several implementations. The most common way to get a dataset of graphs is to use load_atoms_dataset() or file_dataset().

Useful Datasets

graph_pes.data.load_atoms_dataset(
id,
cutoff,
n_train,
n_valid,
n_test=None,
split='random',
seed=42,
pre_transform=True,
property_map=None,
others_to_include=None,
)[source]

Load a dataset of ase.Atoms objects using load-atoms, convert them to AtomicGraph instances, and split them into training, validation and (optionally) test sets.

Parameters:
  • id (str | pathlib.Path) – The dataset identifier. Can be a load-atoms id, or a path to an ase-readable data file.

  • cutoff (float) – The cutoff radius for the neighbor list.

  • n_train (int) – The number of training structures.

  • n_valid (int) – The number of validation structures.

  • n_test (int | None) – The number of test structures. If None, no test set is created.

  • split (Literal['random', 'sequential']) – The split method. "random" shuffles the structures before choosing a non-overlapping split, while "sequential" takes the first n_train structures for training and the next n_valid structures for validation.

  • seed (int) – The random seed.

  • pre_transform (bool) – Whether to pre-calculate the neighbour lists for each structure.

  • root – The root directory

  • property_map (dict[str, PropertyKey] | None) – A mapping from properties as named on the atoms objects to graph-pes property keys, e.g. {"U0": "energy"}.

  • others_to_include (list[str] | None) – A list of properties to include in the graph.other field that are present as per-atom or per-structure properties on the ase.Atoms objects.
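The two split strategies can be illustrated with a small standalone sketch. The helper `split_indices` is hypothetical and is not part of graph-pes. It only mirrors the behaviour described for `split`. Two details are assumptions: for `"sequential"`, the test set is taken from the structures that follow the validation set, and a `ValueError` is raised when too few structures are available.

```python
import random
from typing import List, Literal, Optional, Tuple


def split_indices(
    n_structures: int,
    n_train: int,
    n_valid: int,
    n_test: Optional[int] = None,
    split: Literal["random", "sequential"] = "random",
    seed: int = 42,
) -> Tuple[List[int], List[int], Optional[List[int]]]:
    """Hypothetical helper: non-overlapping train/valid/(test) index lists."""
    n_needed = n_train + n_valid + (n_test or 0)
    if n_needed > n_structures:  # assumed behaviour, not taken from graph-pes
        raise ValueError(
            f"requested {n_needed} structures, but only {n_structures} available"
        )

    indices = list(range(n_structures))
    if split == "random":
        # a dedicated RNG keeps the split reproducible for a given seed
        random.Random(seed).shuffle(indices)
    elif split != "sequential":
        raise ValueError(f"unknown split method: {split!r}")

    # consecutive slices of one (possibly shuffled) list cannot overlap
    train = indices[:n_train]
    valid = indices[n_train : n_train + n_valid]
    test = (
        indices[n_train + n_valid : n_train + n_valid + n_test]
        if n_test is not None
        else None
    )
    return train, valid, test
```

For example, `split_indices(10, 6, 2, 2, split="sequential")` returns `[0, ..., 5]`, `[6, 7]` and `[8, 9]`. With `split="random"`, the same seed always yields the same partition.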

Returns:

A collection of training, validation, and optional test datasets.

Return type:

DatasetCollection

Examples

Load a subset of the QM9 dataset. Ensure that the U0 property is mapped to energy:

>>> from graph_pes.data import load_atoms_dataset
>>> load_atoms_dataset(
...     "QM9",
...     cutoff=5.0,
...     n_train=1_000,
...     n_valid=100,
...     n_test=100,
...     property_map={"U0": "energy"},
... )

graph_pes.data.file_dataset(
path,
cutoff,
n=None,
shuffle=True,
seed=42,
pre_transform=True,
property_map=None,
others_to_include=None,
)[source]

Load an ASE dataset from a file that is either:

  • any plain-text file that can be read by ase.io.read(), e.g. an .xyz file

  • a .db file containing a SQLite database of ase.Atoms objects that is readable as an ASE database. Under the hood, this uses the ASEDatabase class; see there for more details.
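The file-type rule above amounts to a dispatch on the file extension. The hypothetical `choose_reader` below only names the backend and does not open anything. It assumes that the decision depends solely on a literal `.db` suffix. graph-pes's actual check may differ.

```python
from pathlib import Path
from typing import Union


def choose_reader(path: Union[str, Path]) -> str:
    """Hypothetical helper: name the backend that would read `path`."""
    if Path(path).suffix == ".db":
        # SQLite ASE database, wrapped for random access by index
        return "ASEDatabase"
    # any other file is handed to ase.io.read(), which infers the format
    return "ase.io.read"
```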

Parameters:
  • path (str | pathlib.Path) – The path to the file.

  • cutoff (float) – The cutoff radius for the neighbour list.

  • n (int | None) – The number of structures to load. If None, all structures are loaded.

  • shuffle (bool) – Whether to shuffle the structures.

  • seed (int) – The random seed used for shuffling.

  • pre_transform (bool) – Whether to pre-calculate the neighbour lists for each structure.

  • property_map (dict[str, PropertyKey] | None) – A mapping from properties as named on the atoms objects to graph-pes property keys, e.g. {"U0": "energy"}.

  • others_to_include (list[str] | None) – A list of properties to include in the graph.other field that are present as per-atom or per-structure properties on the ase.Atoms objects.
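Here is a sketch of what `n`, `shuffle`, `seed` and `property_map` do to the loaded data. Both helpers are hypothetical. The sketch assumes two things: shuffling happens before the first `n` structures are kept, and labels not named in `property_map` pass through unchanged. Neither is confirmed by this page.

```python
import random
from typing import Any, Dict, List, Mapping, Optional, Sequence


def select_structures(
    structures: Sequence[Any],
    n: Optional[int] = None,
    shuffle: bool = True,
    seed: int = 42,
) -> List[Any]:
    """Hypothetical: optionally shuffle, then keep the first `n` (all if None)."""
    selected = list(structures)
    if shuffle:
        random.Random(seed).shuffle(selected)  # reproducible for a given seed
    return selected if n is None else selected[:n]


def rename_properties(
    labels: Mapping[str, Any], property_map: Optional[Mapping[str, str]] = None
) -> Dict[str, Any]:
    """Hypothetical: rename label keys (atoms-side name -> graph-pes key)."""
    property_map = property_map or {}
    # keys without a mapping keep their original name (assumption)
    return {property_map.get(key, key): value for key, value in labels.items()}
```

With `property_map={"U0": "energy"}`, a structure labelled `{"U0": -1.5}` ends up with an `energy` label of `-1.5`.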

Returns:

A dataset that converts the loaded ase.Atoms structures into AtomicGraph instances.

Return type:

ASEToGraphDataset

Example

Load a dataset from a file, ensuring that the U0 property is mapped to energy:

>>> from graph_pes.data import file_dataset
>>> file_dataset(
...     "training_data.xyz",
...     cutoff=5.0,
...     property_map={"U0": "energy"},
... )

Base Classes

class graph_pes.data.GraphDataset[source]

Bases: Dataset, ABC

A dataset of AtomicGraph instances.

Parameters:

graphs (Sequence[AtomicGraph]) – The collection of AtomicGraph instances.

prepare_data()[source]

Make general preparations for loading the data for the dataset.

Called on rank-0 only: don’t set any state here. May be called multiple times.

setup()[source]

Set-up the data for this specific instance of the dataset.

Called on every process in the distributed setup. May be called multiple times.
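Here is a toy illustration of how `prepare_data()` and `setup()` divide the work. The class `CachedToyDataset` is hypothetical and does not subclass graph-pes's `GraphDataset`, and its JSON cache file is an illustrative choice. `prepare_data()` writes shared artefacts to disk and sets no attributes on the instance. `setup()` loads those artefacts into each process's own instance and is safe to call repeatedly.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Optional


class CachedToyDataset:
    """Hypothetical dataset illustrating the prepare_data()/setup() split."""

    def __init__(self, raw: List[dict], cache_dir: Path):
        self.raw = raw
        self.cache_file = cache_dir / "graphs.json"
        self.graphs: Optional[List[dict]] = None  # populated by setup()

    def prepare_data(self) -> None:
        # rank 0 only: write shared artefacts, but set no state on self
        if not self.cache_file.exists():
            self.cache_file.write_text(json.dumps(self.raw))

    def setup(self) -> None:
        # every process: load into this instance; repeated calls are harmless
        self.graphs = json.loads(self.cache_file.read_text())

    def __len__(self) -> int:
        if self.graphs is None:
            raise RuntimeError("call setup() before using the dataset")
        return len(self.graphs)


with tempfile.TemporaryDirectory() as tmp:
    dataset = CachedToyDataset([{"n_atoms": 3}, {"n_atoms": 5}], Path(tmp))
    dataset.prepare_data()  # in a real distributed run, only rank 0 does this
    dataset.setup()  # every rank does this
    dataset.setup()  # calling again gives the same result
    n_graphs = len(dataset)
```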

property properties: list[Literal['local_energies', 'forces', 'energy', 'stress', 'virial']]

The properties that are available to train on with this dataset.

class graph_pes.data.ASEToGraphDataset[source]

Bases: GraphDataset

A dataset that wraps a Sequence of ase.Atoms, and converts them to AtomicGraph instances.

Parameters:
  • structures (Sequence[ase.Atoms]) – The collection of ase.Atoms objects to convert to AtomicGraph instances.

  • cutoff (float) – The cutoff to use when creating neighbour indexes for the graphs.

  • pre_transform (bool) – Whether to precompute the AtomicGraph objects, or only do so on-the-fly when the dataset is accessed. This pre-computation stores the graphs in memory, and so the memory cost can become prohibitive for large datasets.

  • property_mapping (Mapping[str, PropertyKey] | None) – A mapping from properties defined on the ase.Atoms objects to their appropriate names in graph-pes, see from_ase().

  • others_to_include (list[str] | None) – A list of properties to include in the graph.other field that are present as per-atom or per-structure properties on the ase.Atoms objects.
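This sketch shows the trade-off behind `pre_transform`. `ConvertingDataset` is hypothetical, and `toy_convert` stands in for the real ase.Atoms to AtomicGraph conversion. The eager mode converts every structure once, up front, and keeps the results in memory. The lazy mode stores nothing extra but repeats the conversion on every access.

```python
from typing import Callable, Generic, List, Optional, Sequence, TypeVar

S = TypeVar("S")  # raw structure type, e.g. ase.Atoms
G = TypeVar("G")  # converted type, e.g. AtomicGraph


class ConvertingDataset(Generic[S, G]):
    """Hypothetical sketch of eager (pre_transform=True) vs lazy conversion."""

    def __init__(
        self,
        structures: Sequence[S],
        convert: Callable[[S], G],
        pre_transform: bool = True,
    ):
        self.structures = structures
        self.convert = convert
        # eager: convert everything now and keep all results in memory
        self.graphs: Optional[List[G]] = (
            [convert(s) for s in structures] if pre_transform else None
        )

    def __len__(self) -> int:
        return len(self.structures)

    def __getitem__(self, index: int) -> G:
        if self.graphs is not None:
            return self.graphs[index]
        # lazy: convert on every access, trading repeated work for memory
        return self.convert(self.structures[index])


conversions: List[int] = []


def toy_convert(x: int) -> int:
    conversions.append(x)  # record each conversion so the cost is visible
    return x * 10


eager = ConvertingDataset([1, 2, 3], toy_convert, pre_transform=True)
lazy = ConvertingDataset([1, 2, 3], toy_convert, pre_transform=False)
```

After construction, only the eager dataset has done any conversions. Each indexing operation on the lazy dataset triggers a fresh one.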

class graph_pes.data.DatasetCollection[source]

Bases: object

A convenience container for training, validation, and optional test sets.
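A stand-in with the shape this description suggests is sketched below. `ToyDatasetCollection` is hypothetical, and its field names `train`, `valid` and `test` are assumptions rather than a statement of `DatasetCollection`'s actual attributes.

```python
from dataclasses import dataclass
from typing import Dict, Generic, Optional, Sequence, TypeVar

T = TypeVar("T")


@dataclass
class ToyDatasetCollection(Generic[T]):
    """Hypothetical stand-in for DatasetCollection (field names assumed)."""

    train: Sequence[T]
    valid: Sequence[T]
    test: Optional[Sequence[T]] = None  # the test split is optional

    def sizes(self) -> Dict[str, int]:
        """Number of items in each split that is present."""
        out = {"train": len(self.train), "valid": len(self.valid)}
        if self.test is not None:
            out["test"] = len(self.test)
        return out
```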

Utilities

class graph_pes.data.ase_db.ASEDatabase(path)[source]

Bases: Sequence[Atoms]

A class that wraps an ASE database file, allowing for indexing into the database to obtain ase.Atoms objects.

We assume that each row contains labels in the data attribute, as a mapping from property names to values, and that units are “standard” ASE units, e.g. eV, eV/Å, etc.

Fully compatible with SchNetPack Dataset Files.

See the ASE documentation for more details about this file format.
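The sketch below shows the general pattern of exposing a SQLite file as a read-only `Sequence`, with one query per index access. It uses a deliberately simplified toy table, `rows(id, data)`, where `data` holds a JSON mapping of labels. This is not ASE's database schema. `ToyRowDatabase` is hypothetical and is not the implementation of ASEDatabase.

```python
import json
import sqlite3
import tempfile
from collections.abc import Sequence


class ToyRowDatabase(Sequence):
    """Hypothetical read-only Sequence over a toy SQLite table (not ASE's schema)."""

    def __init__(self, path: str):
        self._conn = sqlite3.connect(path)

    def __len__(self) -> int:
        (count,) = self._conn.execute("SELECT COUNT(*) FROM rows").fetchone()
        return count

    def __getitem__(self, index):
        if isinstance(index, slice):
            return [self[i] for i in range(*index.indices(len(self)))]
        n = len(self)
        if index < 0:
            index += n
        if not 0 <= index < n:
            raise IndexError(index)
        # one disk read per access: simple, but slow on network-attached storage
        (blob,) = self._conn.execute(
            "SELECT data FROM rows ORDER BY id LIMIT 1 OFFSET ?", (index,)
        ).fetchone()
        return json.loads(blob)

    def close(self) -> None:
        self._conn.close()


with tempfile.TemporaryDirectory() as tmp:
    db_path = f"{tmp}/toy.db"
    conn = sqlite3.connect(db_path)
    conn.execute("CREATE TABLE rows (id INTEGER PRIMARY KEY, data TEXT)")
    conn.executemany(
        "INSERT INTO rows (data) VALUES (?)",
        [(json.dumps({"energy": -1.5 - i}),) for i in range(3)],
    )
    conn.commit()
    conn.close()

    db = ToyRowDatabase(db_path)
    first, last, middle, n_rows = db[0], db[-1], db[1:2], len(db)
    db.close()
```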

Warning

This dataset indexes into a database, performing many random access reads from disk. This can be very slow! If you are using a distributed compute cluster, ensure you copy your database file to somewhere with fast local storage (as opposed to network-attached storage).

Similarly, consider using several workers when loading the dataset, e.g. fitting/loader_kwargs/num_workers=8.

Parameters:

path (str | pathlib.Path) – The path to the database.
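The override `fitting/loader_kwargs/num_workers=8` in the warning above reads as a path into a nested configuration. If each `/` separates one level of nesting (an assumption about the notation, not a documented guarantee), it corresponds to the nested mapping built by the hypothetical helper below. graph-pes's own override parsing may differ.

```python
import ast
from typing import Any, Dict


def apply_override(config: Dict[str, Any], override: str) -> Dict[str, Any]:
    """Hypothetical: set a nested key from an 'a/b/c=value' string."""
    key_path, _, raw_value = override.partition("=")
    *parents, leaf = key_path.split("/")
    try:
        value = ast.literal_eval(raw_value)  # "8" -> 8, "True" -> True
    except (ValueError, SyntaxError):
        value = raw_value  # fall back to the plain string
    node = config
    for key in parents:
        node = node.setdefault(key, {})  # create intermediate levels as needed
    node[leaf] = value
    return config
```

For example, `apply_override({}, "fitting/loader_kwargs/num_workers=8")` produces `{"fitting": {"loader_kwargs": {"num_workers": 8}}}`.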