Dataset

The main entry point of load-atoms is the load_dataset function:

load_atoms.load_dataset(thing: str | list[ase.Atoms] | Path, root: str | Path | None = None) → AtomsDataset

Load a dataset by name or from a list of structures.

Parameters:
  • thing – A dataset id, a list of structures, or a path to a file.

  • root – The root directory to use when loading a dataset by id. If not provided, the default root directory (~/.load-atoms) will be used.
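
The way load_dataset handles thing and root can be pictured with the sketch below. It is a hypothetical illustration, not load-atoms' implementation. The helper name classify_source, the DEFAULT_ROOT constant, and the rule that an existing file path takes precedence over a dataset id are all assumptions.

```python
from pathlib import Path
from typing import List, Tuple, Union

# Hypothetical constant mirroring the documented default root directory.
DEFAULT_ROOT = Path.home() / ".load-atoms"


def classify_source(
    thing: Union[str, Path, List[object]],
    root: Union[str, Path, None] = None,
) -> Tuple[str, Path]:
    """Return (kind, root), where kind is "structures", "file" or "id".

    Illustrates the documented behaviour; not load-atoms' actual code.
    """
    resolved_root = Path(root) if root is not None else DEFAULT_ROOT
    if isinstance(thing, list):
        # A list of structures is wrapped directly; nothing is downloaded.
        return "structures", resolved_root
    if Path(thing).is_file():
        # Assumption: an existing file on disk takes precedence over an id.
        return "file", resolved_root
    # Anything else is treated as a dataset id such as "QM9".
    return "id", resolved_root
```

In this sketch, root only changes where a downloaded dataset would be stored; it does not affect how thing is classified.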

Examples

Load a dataset by id:

>>> from load_atoms import load_dataset
>>> dataset = load_dataset("QM9")
╭───────────────────────────────── QM9 ─────────────────────────────────╮
│                                                                       │
│   Downloading dsgdb9nsd.xyz.tar.bz2 ━━━━━━━━━━━━━━━━━━━━ 100% 00:09   │
│   Extracting dsgdb9nsd.xyz.tar.bz2  ━━━━━━━━━━━━━━━━━━━━ 100% 00:18   │
│   Processing files                  ━━━━━━━━━━━━━━━━━━━━ 100% 00:19   │
│   Caching to disk                   ━━━━━━━━━━━━━━━━━━━━ 100% 00:02   │
│                                                                       │
│            The QM9 dataset is covered by the CC0 license.             │
│        Please cite the QM9 dataset if you use it in your work.        │
│          For more information about the QM9 dataset, visit:           │
│                            load-atoms/QM9                             │
╰───────────────────────────────────────────────────────────────────────╯
>>> dataset
QM9:
    structures: 133,885
    atoms: 2,407,753
    species:
        H: 51.09%
        C: 35.16%
        O: 7.81%
        N: 5.80%
        F: 0.14%
    properties:
        per atom: (partial_charges)
        per structure: (
            A, B, C, Cv, G, H, U, U0, alpha,
            frequencies, gap, geometry, homo, inchi, index,
            lumo, mu, r2, smiles, zpve
        )

Optionally save a dataset to an explicit root directory:

>>> load_dataset("QM9", root="./my-datasets")

Wrap a list of structures in a dataset:

>>> from ase import Atoms
>>> load_dataset([Atoms("H2O"), Atoms("H2O2")])

Load a dataset from a file:

>>> load_dataset("path/to/file.xyz")

Note

As of ase==0.3.9, the "energy", "forces", and "stress" special keys are loaded into a SinglePointCalculator object and removed from the .info and .arrays dictionaries of the atoms object. load-atoms reverses this process when loading a dataset from a file, so these values are available in .info and .arrays as usual.
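
A minimal sketch of that reversal, assuming plain dicts stand in for atoms.info, atoms.arrays, and the calculator's results dictionary (ASE calculators expose computed values through a .results dict). The function name restore_special_keys is hypothetical, and the real loader may handle additional keys.

```python
from typing import Any, Dict


def restore_special_keys(
    info: Dict[str, Any],
    arrays: Dict[str, Any],
    calc_results: Dict[str, Any],
) -> None:
    """Copy calculator results back into info/arrays-style dicts, in place.

    "energy" and "stress" describe a whole structure, so they go to `info`;
    "forces" has one row per atom, so it goes to `arrays`.
    """
    for key in ("energy", "stress"):
        if key in calc_results:
            info[key] = calc_results[key]
    if "forces" in calc_results:
        arrays["forces"] = calc_results["forces"]
```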

class load_atoms.AtomsDataset

An abstract base class for datasets of ase.Atoms objects.

This class provides a common interface for interacting with datasets of atomic structures, abstracting over the underlying storage mechanism.

The two current concrete implementations are InMemoryAtomsDataset and LmdbAtomsDataset.

abstract property structure_sizes: ndarray

An array containing the number of atoms in each structure, such that:

for idx, structure in enumerate(dataset):
    assert len(structure) == dataset.structure_sizes[idx]
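
Because structure_sizes records how many atoms each structure has, its running total gives the boundaries of each structure's atoms within a concatenated per-atom array. The sketch below assumes plain Python lists in place of NumPy arrays; split_per_atom is a hypothetical helper. With NumPy, np.split(values, np.cumsum(sizes)[:-1]) does the same job.

```python
from itertools import accumulate
from typing import List, Sequence


def split_per_atom(values: Sequence, structure_sizes: Sequence[int]) -> List[list]:
    """Split concatenated per-atom values into one list per structure."""
    if sum(structure_sizes) != len(values):
        raise ValueError("structure_sizes must sum to the number of atoms")
    # Running totals give the end offset of each structure's atoms.
    ends = list(accumulate(structure_sizes))
    starts = [0] + ends[:-1]
    return [list(values[start:end]) for start, end in zip(starts, ends)]
```
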

abstract __len__() → int

The number of structures in the dataset.

__getitem__(index: int) → Atoms
__getitem__(index: list[int] | list[bool] | np.ndarray | slice) → Self

Get the structure(s) at the given index(es).

If a single int is provided, the corresponding structure is returned:

>>> QM7 = load_dataset("QM7")
>>> QM7[0]
Atoms(symbols='CH4', pbc=False)

If a slice is provided, a new AtomsDataset is returned containing the structures in the slice:

>>> QM7[:5]
Dataset:
    structures: 5
    atoms: 32
    species:
        H: 68.75%
        C: 28.12%
        O: 3.12%
    properties:
        per atom: ()
        per structure: (energy)

If a list or numpy.ndarray of ints is provided, a new AtomsDataset is returned containing the structures at the given indices:

>>> len(QM7[[0, 2, 4]])
3

If a list or numpy.ndarray of bools is provided with the same length as the dataset, a new AtomsDataset is returned containing the structures where the boolean is True (see also filter_by()):

>>> bool_idx = [
...     len(s) > 10 for s in QM7
... ]
>>> len(QM7[bool_idx]) == sum(bool_idx)
True

Parameters:

index – The index(es) to get the structure(s) at.
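
The indexing rules above can be summarised by a toy, list-backed class. This is a hypothetical sketch (ListDataset is not part of load-atoms) that handles plain Python ints, slices, and lists; NumPy arrays and NumPy scalar types are not handled.

```python
from typing import Any, Iterable


class ListDataset:
    """Hypothetical list-backed dataset following the documented indexing rules."""

    def __init__(self, structures: Iterable[Any]):
        self.structures = list(structures)

    def __len__(self) -> int:
        return len(self.structures)

    def __getitem__(self, index):
        # A single int returns one structure.
        if isinstance(index, int):
            return self.structures[index]
        # A slice returns a new dataset.
        if isinstance(index, slice):
            return ListDataset(self.structures[index])
        index = list(index)
        # A list of bools must match the dataset length and acts as a mask.
        if index and all(isinstance(i, bool) for i in index):
            if len(index) != len(self):
                raise IndexError("boolean index must match the dataset length")
            return ListDataset(s for s, keep in zip(self.structures, index) if keep)
        # Otherwise, treat the entries as integer positions.
        return ListDataset(self.structures[i] for i in index)
```

Checking for an all-bool list before falling back to integer positions matters because bool is a subclass of int in Python.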

write(path: Path | str, format: str | None = None, append: bool = False, **kwargs: Any)

Write the dataset to a file, using ase.io.write().

Parameters:
  • path – The path to write the dataset to.

  • format – The format to write the dataset in. If None, ase.io.write() infers the format from the file extension.

  • append – Whether to append to the file.

  • kwargs – Additional keyword arguments to pass to ase.io.write().

abstract property info: Mapping[str, Any]

Get a mapping from keys that are shared across all structures’ .info attributes to the concatenated corresponding values.

The returned mapping conforms to:

for key, value in dataset.info.items():
    for i, structure in enumerate(dataset):
        assert structure.info[key] == value[i]
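
A sketch of how such a mapping could be built from the individual .info dictionaries, assuming plain dicts as input and returning Python lists rather than whatever array type the real property uses. shared_info is a hypothetical name.

```python
from typing import Any, Dict, List, Mapping, Sequence


def shared_info(infos: Sequence[Mapping[str, Any]]) -> Dict[str, List[Any]]:
    """Map each key present in every info dict to its values, in order."""
    if not infos:
        return {}
    shared = set(infos[0]).intersection(*infos[1:])
    # Follow the first dict's key order so the result is deterministic.
    return {key: [info[key] for info in infos] for key in infos[0] if key in shared}
```

Keys that appear in only some structures (here, a config_type missing from one structure) are dropped, which is what keeps the invariant above well defined.
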

abstract property arrays: Mapping[str, ndarray]

Get a mapping from each structure’s .arrays keys to the corresponding per-atom arrays of all structures, concatenated along the first axis.

The returned mapping conforms to:

import numpy as np

for key, value in dataset.arrays.items():
    assert value.shape[0] == dataset.n_atoms
    assert np.array_equal(
        value,
        np.concatenate([structure.arrays[key] for structure in dataset]),
    )

abstract classmethod save(path: Path, structures: Iterable[Atoms], description: DatabaseEntry | None = None)

Save the dataset to a file.

Parameters:
  • path – The path to save the dataset to.

  • structures – The structures to save to the dataset.

  • description – The description of the dataset.

abstract classmethod load(path: Path) → Self

Load the dataset from a file.

Parameters:

path – The path to load the dataset from.
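
The save/load pair forms a round-trip contract that each concrete implementation fills in. The sketch below illustrates the abstract-classmethod pattern with a hypothetical JSON-backed subclass. It is not how InMemoryAtomsDataset or LmdbAtomsDataset store data: structures are plain dicts and the description argument is omitted.

```python
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Iterable, List


class Dataset(ABC):
    """Stand-in for the abstract save/load contract."""

    @classmethod
    @abstractmethod
    def save(cls, path: Path, structures: Iterable[Dict[str, Any]]) -> None:
        """Persist structures to `path`."""

    @classmethod
    @abstractmethod
    def load(cls, path: Path) -> "Dataset":
        """Reconstruct a dataset from `path`."""


class JsonDataset(Dataset):
    """Hypothetical subclass that stores structures as a JSON list."""

    def __init__(self, structures: List[Dict[str, Any]]):
        self.structures = structures

    @classmethod
    def save(cls, path: Path, structures: Iterable[Dict[str, Any]]) -> None:
        Path(path).write_text(json.dumps(list(structures)))

    @classmethod
    def load(cls, path: Path) -> "JsonDataset":
        return cls(json.loads(Path(path).read_text()))
```

Stacking @classmethod over @abstractmethod means the base class cannot be instantiated until a subclass overrides both methods.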

species_counts() → Mapping[str, int]

Get the number of atoms of each species in the dataset.

property n_atoms: int

The total number of atoms in the dataset.

This is equivalent to the sum of the number of atoms in each structure.
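
Both quantities reduce to simple counts over the atoms' chemical symbols. In this sketch each structure is represented by its list of symbols (as returned by ase.Atoms.get_chemical_symbols()); the function names mirror the documented API but are standalone, hypothetical helpers.

```python
from collections import Counter
from typing import Dict, List


def species_counts(structures: List[List[str]]) -> Dict[str, int]:
    """Count atoms of each chemical species across all structures."""
    counts: Counter = Counter()
    for symbols in structures:
        counts.update(symbols)
    return dict(counts)


def n_atoms(structures: List[List[str]]) -> int:
    """Total number of atoms: the sum of each structure's size."""
    return sum(len(symbols) for symbols in structures)
```

By construction, the species counts always sum to the total number of atoms.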

filter_by(*functions: Callable[[Atoms], bool], **info_kwargs: Any) → Self

Return a new dataset containing only the structures that match the given criteria.

Parameters:
  • functions – Functions to filter the dataset by. Each function should take an ASE Atoms object as input and return a boolean; a structure is kept only if every function returns True.

  • info_kwargs – Keyword arguments to filter the dataset by. Only structures whose .info contains every given key with the matching value are kept.
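
The semantics described above (a structure must pass every function and match every info keyword) can be sketched as follows. This assumes structures are plain dicts with an "info" entry standing in for ase.Atoms; filter_structures is a hypothetical helper, not the library's code.

```python
from typing import Any, Callable, Dict, Iterable, List

Structure = Dict[str, Any]  # hypothetical stand-in for ase.Atoms


def filter_structures(
    structures: Iterable[Structure],
    *functions: Callable[[Structure], bool],
    **info_kwargs: Any,
) -> List[Structure]:
    """Keep structures that pass every function and match every info kwarg."""
    kept = []
    for structure in structures:
        info = structure["info"]
        # Every requested info key must be present with an equal value.
        if any(key not in info or info[key] != value for key, value in info_kwargs.items()):
            continue
        # Every filter function must return True.
        if all(fn(structure) for fn in functions):
            kept.append(structure)
    return kept
```

Checking the cheap info comparisons first avoids calling potentially expensive filter functions on structures that are already excluded.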

Example

Get small, amorphous structures with large forces:

>>> from load_atoms import load_dataset
>>> dataset = load_dataset("C-GAP-17")
>>> dataset.filter_by(
...     lambda structure: len(structure) < 50,
...     lambda structure: structure.arrays["force"].max() > 5,
...     config_type="bulk_amo"
... )
Dataset:
    structures: 609
    atoms: 23,169
    species:
        C: 100.00%
    properties:
        per atom: (force)
        per structure: (config_type, detailed_ct, split, energy)

random_split(splits: Sequence[float] | Sequence[int], seed: int = 42, keep_ratio: str | None = None) → list[Self]

Randomly split the dataset into multiple, disjoint parts.

Parameters:
  • splits – The number of structures to put in each split. If a sequence of floats is given, each value is interpreted as a fraction of the dataset size.

  • seed – The random seed to use for shuffling the dataset.

  • keep_ratio – If not None, the name of an .info key. The splits are then generated so that each one preserves the relative proportions of that key's values in the full dataset.

Returns:

A list of new datasets, each containing a disjoint subset of the original dataset.

Return type:

list[Self]
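
One way to implement a seeded, optionally stratified split is sketched below. It is an illustration under several assumptions, not load-atoms' algorithm: it works on indices, accepts only fractional splits, floors each split size (so a few structures may be left unassigned), and stratifies by splitting each label group separately. random_split_indices is a hypothetical name.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence


def random_split_indices(
    n: int,
    splits: Sequence[float],
    seed: int = 42,
    labels: Optional[Sequence[Hashable]] = None,
) -> List[List[int]]:
    """Return disjoint lists of indices, one per fractional split.

    If `labels` is given (e.g. one .info value per structure), each label
    group is split separately, so every split keeps similar label ratios.
    """
    rng = random.Random(seed)
    groups: Dict[Hashable, List[int]] = defaultdict(list)
    for i in range(n):
        # Without labels, all indices fall into a single group.
        groups[labels[i] if labels is not None else None].append(i)

    result: List[List[int]] = [[] for _ in splits]
    for members in groups.values():
        rng.shuffle(members)
        start = 0
        for k, fraction in enumerate(splits):
            # Floor each size; any leftover indices stay unassigned.
            size = int(fraction * len(members))
            result[k].extend(members[start:start + size])
            start += size
    return result
```

Each returned index list could then be passed to the dataset's list-of-ints indexing to build the corresponding split, with labels taken from, for example, dataset.info["config_type"].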

Examples

Split a dataset into 80% training and 20% test sets:

>>> train, test = dataset.random_split([0.8, 0.2])

Split a dataset into three parts with explicit sizes (1,000, 100, and 100 structures):

>>> train, val, test = dataset.random_split([1_000, 100, 100])

Maintain the ratio of config_type values in each split:

>>> from load_atoms import load_dataset
>>> import numpy as np
>>> # helper function
>>> def ratios(thing):
...     values, counts = np.unique(thing, return_counts=True)
...     max_len = max(len(str(v)) for v in values)
...     for v, c in zip(values, counts / counts.sum()):
...         print(f"{v:>{max_len}}: {c:>6.2%}")
...
>>> dataset = load_dataset("C-GAP-17")
>>> ratios(dataset.info["config_type"])
  bulk_amo: 75.28%
bulk_cryst:  8.83%
     dimer:  0.66%
  surf_amo: 15.23%
>>> train, val, test = dataset.random_split(
...     [0.6, 0.2, 0.2],
...     keep_ratio="config_type"
... )
>>> ratios(train.info["config_type"])
  bulk_amo: 75.28%
bulk_cryst:  8.83%
     dimer:  0.66%
  surf_amo: 15.23%

k_fold_split(k: int = 5, fold: int = 0, shuffle: bool = True, seed: int = 42, keep_ratio: str | None = None) → tuple[Self, Self]

Generate an (optionally shuffled) train/test split for k-fold cross-validation.

Parameters:
  • k – The number of folds to use.

  • fold – The index of the fold to use for testing, from 0 to k - 1.

  • shuffle – Whether to shuffle the dataset before splitting.

  • seed – The random seed to use for shuffling the dataset.

  • keep_ratio – If not None, the name of an .info key. The splits are then generated so that each one preserves the relative proportions of that key's values in the full dataset.

Returns:

The train and test datasets.

Return type:

tuple[Self, Self]
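
The fold boundaries for k-fold cross-validation can be computed as in this sketch, which partitions the (optionally shuffled) indices into k contiguous, nearly equal chunks and uses chunk fold as the test set. It is a hypothetical helper (k_fold_indices) that ignores keep_ratio; load-atoms may assign structures to folds differently.

```python
import random
from typing import List, Tuple


def k_fold_indices(
    n: int, k: int = 5, fold: int = 0, shuffle: bool = True, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Return (train, test) index lists for one fold of k-fold CV."""
    if not 0 <= fold < k:
        raise ValueError("fold must satisfy 0 <= fold < k")
    indices = list(range(n))
    if shuffle:
        # The same seed gives the same permutation for every fold.
        random.Random(seed).shuffle(indices)
    # Integer boundaries spread any remainder across the folds.
    start, end = fold * n // k, (fold + 1) * n // k
    test = indices[start:end]
    train = indices[:start] + indices[end:]
    return train, test
```

Because the shuffle depends only on seed, looping fold from 0 to k - 1 yields test sets that together cover every structure exactly once.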

Example

Basic usage:

>>> for i in range(5):
...     train, test = dataset.k_fold_split(k=5, fold=i)
...     ...  # do something, e.g. train a model

Maintain the ratio of config_type values in each split (see also random_split() for a more detailed example of this feature):

>>> train, test = dataset.k_fold_split(
...     k=5, fold=0, keep_ratio="config_type"
... )

class load_atoms.atoms_dataset.InMemoryAtomsDataset

An in-memory implementation of AtomsDataset.

Internally, this class wraps a list of ase.Atoms objects, all of which are stored in RAM. Suitable for small to moderately large datasets.

class load_atoms.atoms_dataset.LmdbAtomsDataset

An LMDB-backed implementation of AtomsDataset.

Internally, this class wraps an lmdb.Environment object, which stores the dataset in an LMDB database on disk. Accessing data from this dataset type is (marginally) slower than from an InMemoryAtomsDataset, but it allows efficient processing of extremely large datasets that cannot fit in memory.

Warning

The ase.Atoms objects in an LMDB dataset are read-only: attempting to modify the .info or .arrays of one of these objects raises an error rather than changing the stored data.
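
One way to obtain this behaviour in plain Python is a read-only view such as types.MappingProxyType, sketched below. This is an illustration only: the class is hypothetical, and the exact mechanism and error type used by LmdbAtomsDataset are not documented here.

```python
from types import MappingProxyType
from typing import Any, Dict, Mapping


class ReadOnlyStructure:
    """Hypothetical structure whose info mapping cannot be modified."""

    def __init__(self, info: Dict[str, Any]):
        # Copy first so later changes to the caller's dict do not leak in.
        self._info = MappingProxyType(dict(info))

    @property
    def info(self) -> Mapping[str, Any]:
        # Item assignment on a mappingproxy raises TypeError.
        return self._info
```

For per-atom NumPy arrays, a comparable effect comes from setting array.flags.writeable = False, after which in-place assignment raises a ValueError.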