Atomic Graphs

We describe atomic graphs using the AtomicGraph class. For convenient ways to create instances of such graphs from Atoms objects, see from_ase().

Definition

class graph_pes.AtomicGraph[source]

Bases: NamedTuple

An AtomicGraph represents an atomic structure. Each node corresponds to an atom, and each directed edge links a central atom to a “bonded” neighbour.

We implement such graphs as (immutable) NamedTuples. This allows for easy serialisation and compatibility with PyTorch, TorchScript and other libraries. These objects, and all functions that operate on them, are compatible with both isolated and periodic structures.

Batches of multiple structures are represented by a single AtomicGraph object containing multiple disjoint subgraphs. See to_batch() for more information.

Properties

Below we assume the graph contains N atoms, E edges (and S structures if the graph is batched).

We use two examples to illustrate the various properties:

Water molecule
>>> from ase.build import molecule
>>> from load_atoms import view
>>> water = molecule("H2O")
>>> view(water, show_bonds=True)

We can see that there are 3 atoms, with 2 bonds (and therefore 4 directed edges):

>>> from graph_pes import AtomicGraph
>>> water_graph = AtomicGraph.from_ase(water, cutoff=1.2)
>>> water_graph
AtomicGraph(atoms=3, edges=4, has_cell=False, cutoff=1.2)
Sodium crystal
>>> from ase.build import bulk
>>> from load_atoms import view
>>> sodium = bulk("Na")
>>> view(sodium.repeat(3), show_bonds=True)

This structure has a single atom within a periodic cell. If you look closely, you can see that this atom has 8 nearest neighbours. Only “source” atoms within the unit cell are included in the neighbour list, and hence there are 8 edges:

>>> sodium_graph = AtomicGraph.from_ase(sodium, cutoff=3.7)
>>> sodium_graph
AtomicGraph(atoms=1, edges=8, has_cell=True, cutoff=3.7)
Z: Tensor

The atomic numbers of the atoms in the graph, of shape (N,).

>>> water_graph.Z
tensor([8, 1, 1])
>>> sodium_graph.Z
tensor([11])
R: Tensor

The cartesian positions of the atoms in the graph, of shape (N, 3).

>>> water_graph.R
tensor([[ 0.0000,  0.0000,  0.1193],
        [ 0.0000,  0.7632, -0.4770],
        [ 0.0000, -0.7632, -0.4770]])
>>> sodium_graph.R
tensor([[0., 0., 0.]])
cell: Tensor

The unit cell vectors, of shape (3, 3) for a single structure, or (S, 3, 3) for a batched graph. If the structure is non-periodic, this will be all zeros.

>>> water_graph.cell
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
>>> sodium_graph.cell
tensor([[-2.1150,  2.1150,  2.1150],
        [ 2.1150, -2.1150,  2.1150],
        [ 2.1150,  2.1150, -2.1150]])
neighbour_list: Tensor

A neighbour list, of shape (2, E), where i, j = graph.neighbour_list[:, k] is the k’th directed edge in the graph, linking atom i to atom j.

>>> water_graph.neighbour_list
tensor([[0, 0, 1, 2],
        [1, 2, 0, 0]])
>>> sodium_graph.neighbour_list
tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]])
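
For an isolated structure, a neighbour list of this form can be rebuilt by hand. The following is a minimal brute-force sketch in plain Python. It does not call graph_pes, and build_neighbour_list is a hypothetical helper rather than part of the library. The use of "<=" at the cutoff is also an assumption. With the water positions listed under R and a cutoff of 1.2, it reproduces the four directed edges shown for water_graph:

```python
import math

# Cartesian positions of O, H, H, copied from water_graph.R
positions = [
    (0.0, 0.0, 0.1193),
    (0.0, 0.7632, -0.4770),
    (0.0, -0.7632, -0.4770),
]


def build_neighbour_list(positions, cutoff):
    """Hypothetical helper: all directed edges (i, j), i != j, with r_ij <= cutoff."""
    central, neighbour = [], []
    for i, r_i in enumerate(positions):
        for j, r_j in enumerate(positions):
            if i != j and math.dist(r_i, r_j) <= cutoff:
                central.append(i)
                neighbour.append(j)
    return [central, neighbour]  # two rows of length E, i.e. shape (2, E)


neighbour_list = build_neighbour_list(positions, cutoff=1.2)
print(neighbour_list)  # [[0, 0, 1, 2], [1, 2, 0, 0]]
```

This sketch scales as O(N²) and ignores periodic images, so it is only suitable for small, isolated structures.
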
neighbour_cell_offsets: Tensor

The offsets of the neighbours of each atom in units of the unit cell vectors, of shape (E, 3), such that:

# k-th edge
i, j = graph.neighbour_list[:, k]
kth_displacement_vector = (
    graph.R[j]
    + graph.neighbour_cell_offsets[k] @ graph.cell
    - graph.R[i]
)

In the case of an isolated, non-periodic structure, these will be all zeros.

>>> water_graph.neighbour_cell_offsets
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
>>> sodium_graph.neighbour_cell_offsets
tensor([[ 0.,  0.,  1.],
        [ 0.,  1.,  0.],
        [ 1.,  1.,  1.],
        [ 1.,  0.,  0.],
        [-1.,  0.,  0.],
        [-1., -1., -1.],
        [ 0., -1.,  0.],
        [ 0.,  0., -1.]])
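
As a worked check of the displacement formula, the sketch below evaluates it in plain Python for the sodium example, using the cell and offsets printed for sodium_graph. It is a standalone illustration that does not import graph_pes, and the displacement helper is hypothetical. All eight edges come out at the same length, just below the 3.7 cutoff:

```python
import math

# values copied from sodium_graph.cell, .R, .neighbour_list and .neighbour_cell_offsets
cell = [
    [-2.115, 2.115, 2.115],
    [2.115, -2.115, 2.115],
    [2.115, 2.115, -2.115],
]
R = [[0.0, 0.0, 0.0]]
neighbour_list = [[0] * 8, [0] * 8]
offsets = [
    [0, 0, 1], [0, 1, 0], [1, 1, 1], [1, 0, 0],
    [-1, 0, 0], [-1, -1, -1], [0, -1, 0], [0, 0, -1],
]


def displacement(k):
    """Hypothetical helper: R[j] + offsets[k] @ cell - R[i] for the k-th edge."""
    i, j = neighbour_list[0][k], neighbour_list[1][k]
    # offsets[k] @ cell is an integer combination of the cell vectors (rows of cell)
    shift = [sum(offsets[k][a] * cell[a][c] for a in range(3)) for c in range(3)]
    return [R[j][c] + shift[c] - R[i][c] for c in range(3)]


distances = [math.hypot(*displacement(k)) for k in range(8)]
print(round(distances[0], 4))  # 3.6633, i.e. 2.115 * sqrt(3)
```
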
properties: Dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], Tensor]

A dictionary containing potential energy surface (PES) related properties of the graph.

Key                 Shape                            Description
"local_energies"    (N,)                             contribution to the total energy from each atom
"energy"            () ((S,) if batched)             total energy of the structure
"forces"            (N, 3)                           force on each atom
"stress"            (3, 3) ((S, 3, 3) if batched)    stress tensor (see Theory)
"virial"            (3, 3) ((S, 3, 3) if batched)    virial stress tensor (see Theory)

cutoff: float

The cutoff distance used to create the neighbour list for this graph.

other: Dict[str, Tensor]

A dictionary containing any other additional information about the graph. Feel free to populate this as you wish.

batch: Tensor | None

A tensor of shape (N,) indicating the index i of the structure within the batch that each atom belongs to. Not present for a single structure.

ptr: Tensor | None

A tensor of shape (S + 1,), where ptr[i] is the index of the first atom of structure i within the batch, and the final entry is the total number of atoms. Not present for a single structure.
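
The relationship between batch and ptr can be illustrated without any tensors. The sketch below is plain Python under the assumption of a batch containing a 3-atom and a 5-atom structure. The variable names are illustrative only and nothing here calls graph_pes:

```python
from itertools import accumulate

structure_sizes = [3, 5]  # e.g. H2O followed by CH4

# batch[n] is the index of the structure that atom n belongs to
batch = [s for s, size in enumerate(structure_sizes) for _ in range(size)]

# ptr[s] is the index of the first atom of structure s; the final entry is
# the total number of atoms, which is why ptr has S + 1 entries
ptr = [0, *accumulate(structure_sizes)]

# the atoms of structure s occupy the half-open range ptr[s]:ptr[s + 1]
recovered_sizes = [ptr[s + 1] - ptr[s] for s in range(len(structure_sizes))]

print(batch)  # [0, 0, 0, 1, 1, 1, 1, 1]
print(ptr)    # [0, 3, 8]
```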

classmethod from_ase(
structure,
cutoff=5.0,
property_mapping=None,
others_to_include=None,
)[source]

Convert an ase.Atoms object to an AtomicGraph.

Parameters:
  • structure (Atoms) – The ase.Atoms object to convert.

  • cutoff (float) – The cutoff distance for neighbour finding.

  • property_mapping (Mapping[str, Literal['local_energies', 'forces', 'energy', 'stress', 'virial']] | None) – An optional mapping of the form {key_on_structure: key_for_graph} defining how relevant properties are labelled on the ase.Atoms object. If not provided, this function will extract any of "energy", "forces", "stress" and "virial" that are present in the .info and .arrays dicts.

  • others_to_include (Sequence[str] | None) – An optional list of other .info/.arrays keys to include in the graph’s other dict. The corresponding values will be converted to torch.Tensors.

Return type:

AtomicGraph

Example

>>> from ase.build import molecule
>>> from graph_pes import AtomicGraph

>>> # create a structure with some extra info
>>> atoms = molecule("H2O")
>>> atoms.info["DFT_energy"] = -10.0
>>> atoms.info["unique_id"] = 1234

>>> # default behaviour:
>>> AtomicGraph.from_ase(atoms)
AtomicGraph(atoms=3, edges=6, has_cell=False, cutoff=5.0)

>>> # specify how to map properties, and other things to include
>>> AtomicGraph.from_ase(
...     atoms,
...     property_mapping={
...         "DFT_energy": "energy",
...     },
...     others_to_include=["unique_id"],
... )
AtomicGraph(
    atoms=3,
    edges=6,
    has_cell=False,
    cutoff=5.0,
    properties=['energy'],
    other=['unique_id']
)
classmethod create_with_defaults(
Z,
R,
cell=None,
neighbour_list=None,
neighbour_cell_offsets=None,
properties=None,
other=None,
cutoff=0.0,
)[source]

Create an AtomicGraph, populating missing values with defaults.

Parameters:
  • Z (Tensor) – The atomic numbers.

  • R (Tensor) – The cartesian positions.

  • cell (Tensor | None) – The unit cell. Defaults to torch.zeros(3, 3).

  • neighbour_list (Tensor | None) – The neighbour list. Defaults to torch.zeros(2, 0).

  • neighbour_cell_offsets (Tensor | None) – The neighbour cell offsets. Defaults to torch.zeros(0, 3).

  • properties (Dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], ~torch.Tensor] | None) – The properties. Defaults to {}.

  • other (Dict[str, Tensor] | None) – The other information. Defaults to {}.

  • cutoff (float) – The cutoff distance. Defaults to 0.0.

Return type:

AtomicGraph

to(device)[source]

Move this graph to the specified device.

Return type:

AtomicGraph

graph_pes.atomic_graph.replace(
graph,
Z=None,
R=None,
cell=None,
neighbour_list=None,
neighbour_cell_offsets=None,
properties=None,
other=None,
cutoff=None,
)[source]

A TorchScript-compatible convenience function for replacing the values of an AtomicGraph (unlike the built-in ._replace namedtuple method, which is not TorchScript compatible).

Return type:

AtomicGraph

Batching

A batch of AtomicGraph instances is itself represented by a single AtomicGraph instance, containing multiple disjoint subgraphs.

AtomicGraph batches are created using to_batch():

graph_pes.atomic_graph.to_batch(graphs)[source]

Collate a sequence of atomic graphs into a single batch object.

The Z, R, neighbour_list, and neighbour_cell_offsets properties are concatenated along the first axis, while the cell property is stacked along a new batch dimension.

Values in the "other" dictionary are concatenated along the first axis if they appear to be a per-atom property (i.e. their first dimension matches the number of atoms in the structure). Otherwise, they are stacked along a new batch dimension.

Parameters:

graphs (Sequence[AtomicGraph]) – The graphs to collate.

Return type:

AtomicGraph

Examples

A basic example:

>>> from ase.build import molecule
>>> from graph_pes import AtomicGraph
>>> from graph_pes.atomic_graph import to_batch
>>> graphs = [
...     AtomicGraph.from_ase(molecule("H2O")),
...     AtomicGraph.from_ase(molecule("CH4")),
... ]
>>> batch = to_batch(graphs)
>>> batch.batch  # H2O has 3 atoms, CH4 has 5
tensor([0, 0, 0, 1, 1, 1, 1, 1])
>>> batch.ptr  # offset of first atom of each graph
tensor([0, 3, 8])
>>> batch.Z.shape
torch.Size([8])
>>> batch.R.shape
torch.Size([8, 3])
>>> batch.cell.shape
torch.Size([2, 3, 3])
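
Because a batch stores every atom in a single index space, the atom indices in each structure's neighbour list have to be shifted by the index of that structure's first atom when graphs are merged. The plain-Python sketch below illustrates this bookkeeping for two copies of the water neighbour list. It is not a copy of the library's implementation, and the (number of atoms, neighbour list) tuples are a hypothetical stand-in for AtomicGraph objects:

```python
# neighbour list of an isolated water graph (cutoff 1.2), of shape (2, 4)
water_edges = [[0, 0, 1, 2], [1, 2, 0, 0]]
graphs = [(3, water_edges), (3, water_edges)]  # (number of atoms, neighbour list)

batched = [[], []]
offset = 0  # index of the first atom of the current structure
for n_atoms, (central, neighbour) in graphs:
    batched[0] += [i + offset for i in central]
    batched[1] += [j + offset for j in neighbour]
    offset += n_atoms

print(batched)  # [[0, 0, 1, 2, 3, 3, 4, 5], [1, 2, 0, 0, 4, 5, 3, 3]]
```

No edge links atoms from different structures, which is what makes the subgraphs disjoint.
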
graph_pes.atomic_graph.is_batch(graph)[source]

Does graph represent a batch of atomic graphs?

Parameters:

graph (AtomicGraph) – The graph to check.

Return type:

bool

If you need to define custom batching logic for a field in the other property, you can use register_custom_batcher():

graph_pes.atomic_graph.register_custom_batcher(key)[source]

Register a custom batcher for a property in the other field.

The batcher should conform to the following protocol:

class graph_pes.atomic_graph.CustomPropertyBatcher[source]
__call__(batch, values)[source]

Batch the given values.

Parameters:
Return type:

Tensor

Parameters:

key (str) – The key of the property to register a custom batcher for.

Examples

>>> import torch
>>> from graph_pes.atomic_graph import register_custom_batcher, to_batch
>>> @register_custom_batcher("foo")
... def foo_batcher(batch, values):
...     # reduce the per-graph values to their element-wise maximum
...     return torch.max(torch.vstack(values), dim=0).values
>>> ... # create graphs
>>> graphs[0].other["foo"], graphs[1].other["foo"]
(tensor([1]), tensor([2]))
>>> ... # batch the graphs
>>> batch = to_batch(graphs)
>>> batch.other["foo"]
tensor([2])

Derived Properties

We define a number of derived properties of atomic graphs. These work for both isolated and batched AtomicGraph instances.

graph_pes.atomic_graph.number_of_atoms(graph)[source]

Get the number of atoms in the graph.

Return type:

int

graph_pes.atomic_graph.number_of_edges(graph)[source]

Get the number of edges in the graph.

Return type:

int

graph_pes.atomic_graph.has_cell(graph)[source]

Does graph represent a structure with a defined unit cell?

Return type:

bool

graph_pes.atomic_graph.neighbour_vectors(graph)[source]

Get the vector between each pair of atoms specified in the graph’s "neighbour_list" property, respecting periodic boundary conditions where present.

Return type:

Tensor

graph_pes.atomic_graph.neighbour_distances(graph)[source]

Get the distance between each pair of atoms specified in the graph’s neighbour_list property, respecting periodic boundary conditions where present.

Return type:

Tensor

graph_pes.atomic_graph.number_of_neighbours(graph, include_central_atom=True)[source]

Get a tensor, T, of shape (N,), where N is the number of atoms in the graph, such that T[i] gives the number of neighbours of atom i. If include_central_atom is True, then the central atom is included in the count.

Parameters:
  • graph (AtomicGraph) – The graph to get the number of neighbours for.

  • include_central_atom (bool) – Whether to include the central atom in the count.

Return type:

Tensor
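
The count can be reproduced from the first row of the neighbour list alone. The sketch below is a plain-Python illustration using the water neighbour list. The helper count_neighbours is hypothetical, and it assumes that including the central atom simply adds one to every count:

```python
neighbour_list = [[0, 0, 1, 2], [1, 2, 0, 0]]  # water_graph.neighbour_list
n_atoms = 3


def count_neighbours(neighbour_list, n_atoms, include_central_atom=True):
    """Hypothetical helper: number of outgoing edges per central atom."""
    counts = [1 if include_central_atom else 0] * n_atoms
    for i in neighbour_list[0]:  # first row holds the central atom of each edge
        counts[i] += 1
    return counts


print(count_neighbours(neighbour_list, n_atoms))  # [3, 2, 2]
print(count_neighbours(neighbour_list, n_atoms, include_central_atom=False))  # [2, 1, 1]
```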

graph_pes.atomic_graph.available_properties(graph)[source]

Get the labels that are available on the graph.

Return type:

List[Literal['local_energies', 'forces', 'energy', 'stress', 'virial']]

graph_pes.atomic_graph.number_of_structures(graph)[source]

Get the number of structures in the graph.

Return type:

int

graph_pes.atomic_graph.structure_sizes(batch)[source]

Get the number of atoms in each structure in the batch, of shape (S,) where S is the number of structures.

Parameters:

batch (AtomicGraph) – The batch to get the structure sizes for.

Return type:

Tensor

Examples

>>> len(graphs)
3
>>> [number_of_atoms(g) for g in graphs]
[3, 4, 5]
>>> structure_sizes(to_batch(graphs))
tensor([3, 4, 5])

Graph Operations

We define a number of operations that act on torch.Tensor instances conditioned on the graph structure. All of these are fully compatible with batched AtomicGraph instances, and with TorchScript compilation.

graph_pes.atomic_graph.is_local_property(x, graph)[source]

Is the property x local to each atom in the graph?

Parameters:
  • x (Tensor) – The property to check.

  • graph (AtomicGraph) – The graph to check the property for.

Return type:

bool

graph_pes.atomic_graph.index_over_neighbours(x, graph)[source]

Index a per-atom property, \(x\), over the neighbours of each atom in the graph.

Return type:

Tensor
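
As an illustration, the sketch below gathers the atomic numbers of the water example onto its edges in plain Python. It assumes that "the neighbours" are the atoms in the second row of the neighbour list, which this page does not state explicitly, and it does not call graph_pes:

```python
Z = [8, 1, 1]  # water_graph.Z
neighbour_list = [[0, 0, 1, 2], [1, 2, 0, 0]]  # water_graph.neighbour_list

# per-edge view of a per-atom property: one value per directed edge
Z_neighbour = [Z[j] for j in neighbour_list[1]]  # assumed meaning of "over neighbours"
Z_central = [Z[i] for i in neighbour_list[0]]    # the same gather over central atoms

print(Z_neighbour)  # [1, 1, 8, 8]
print(Z_central)    # [8, 8, 1, 1]
```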

graph_pes.atomic_graph.sum_over_central_atom_index(p, central_atom_index, graph)[source]

Efficient, shape-preserving sum of a property, \(p\), defined over a central_atom_index, to get a per-atom property, \(P\), such that:

# i in central_atom_index
P[i] == torch.sum(p[central_atom_index == i], dim=0)
# i not in central_atom_index
P[i] == torch.zeros_like(p[0])

See also

sum_over_neighbours() for the explicit case where p is a per-edge property.

Parameters:
  • p (Tensor) – The property to sum, of shape (Y, ...).

  • central_atom_index (Tensor) – The central atoms relevant to each element of p, of shape (Y,).

  • graph (AtomicGraph) – The graph to sum the property for.

Returns:

P – The summed property, of shape (N, ...).

Return type:

torch.Tensor
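
The semantics above amount to a scatter-add. The following plain-Python reference, restricted to scalar entries of p, is a sketch for checking results by hand. The name scatter_sum is hypothetical, and the library function itself operates on tensors:

```python
def scatter_sum(p, central_atom_index, n_atoms):
    """Hypothetical reference: P[i] is the sum of all p[y] with central_atom_index[y] == i."""
    P = [0.0] * n_atoms  # atoms that never appear in the index keep a zero
    for value, i in zip(p, central_atom_index):
        P[i] += value
    return P


p = [1.0, 2.0, 3.0, 4.0]
central_atom_index = [0, 0, 1, 3]  # atom 2 does not appear
P = scatter_sum(p, central_atom_index, n_atoms=4)
print(P)  # [3.0, 3.0, 0.0, 4.0]
```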

graph_pes.atomic_graph.sum_over_neighbours(p, graph)[source]

Shape-preserving sum over neighbours of a per-edge property, \(p_{ij}\), to get a per-atom property, \(P_i\):

\[P_i = \sum_{j \in \mathcal{N}_i} p_{ij}\]

where:

  • \(\mathcal{N}_i\) is the set of neighbours of atom \(i\).

  • \(p_{ij}\) is the property of the edge between atoms \(i\) and \(j\).

  • \(p\) is of shape (E, ...) and \(P\) is of shape (N, ...) where \(E\) is the number of edges and \(N\) is the number of atoms. ... denotes any number of additional dimensions, including none.

  • \(P_i\) = 0 if \(|\mathcal{N}_i| = 0\).

Parameters:
  • p (Tensor) – The per-edge property to sum.

  • graph (AtomicGraph) – The graph to sum the property for.

Return type:

Tensor

graph_pes.atomic_graph.sum_per_structure(x, graph)[source]

Shape-preserving sum of a per-atom property, \(p\), to get a per-structure property, \(P\):

If a single structure, containing N atoms, is used, then \(P = \sum_i p_i\), where:

  • \(p\) is of shape (N, ...)

  • \(P\) is of shape (...)

  • ... denotes any number of additional dimensions, including none.

If a batch of S structures, containing a total of N atoms, is used, then \(P_k = \sum_{i \in K} p_i\), where:

  • \(K\) is the collection of all atoms in structure \(k\)

  • \(p\) is of shape (N, ...)

  • \(P\) is of shape (S, ...)

  • ... denotes any number of additional dimensions, including none.

Parameters:
  • x (Tensor) – The per-atom property to sum.

  • graph (AtomicGraph) – The graph to sum the property for.

Return type:

Tensor

Examples

Single graph case:

>>> import torch
>>> from ase.build import molecule
>>> from graph_pes.atomic_graph import sum_per_structure, AtomicGraph
>>> water = molecule("H2O")
>>> graph = AtomicGraph.from_ase(water, cutoff=1.5)
>>> # summing over a vector gives a scalar
>>> sum_per_structure(torch.ones(3), graph)
tensor(3.)
>>> # summing over higher order tensors gives a tensor
>>> sum_per_structure(torch.ones(3, 2, 3), graph).shape
torch.Size([2, 3])

Batch case:

>>> import torch
>>> from ase.build import molecule
>>> from graph_pes.atomic_graph import sum_per_structure, AtomicGraph, to_batch
>>> water = molecule("H2O")
>>> graph = AtomicGraph.from_ase(water, cutoff=1.5)
>>> batch = to_batch([graph, graph])
>>> batch
AtomicGraphBatch(structures: 2, atoms: 6, edges: 8, has_cell: False)
>>> # summing over a vector gives a tensor
>>> sum_per_structure(torch.ones(6), batch)
tensor([3., 3.])
>>> # summing over higher order tensors gives a tensor
>>> sum_per_structure(torch.ones(6, 3, 4), batch).shape
torch.Size([2, 3, 4])
graph_pes.atomic_graph.divide_per_atom(x, graph)[source]

Divide a per-structure property, \(X\), by the number of atoms in each structure to get a per-atom property, \(x\):

\[x_i = \frac{X_k}{N_k}\]

where:

  • \(X\) is of shape (S, ...)

  • \(x\) is of shape (N, ...)

  • \(S\) is the number of structures

  • \(N\) is the total number of atoms, and \(N_k\) is the number of atoms in structure \(k\)

Return type:

Tensor
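
A common use of these two operations together is turning per-atom energies into a total energy and an energy per atom for each structure. The plain-Python sketch below uses made-up numbers for a batch of a 2-atom and a 4-atom structure. It is a hand-checkable illustration of the arithmetic, not the library implementation:

```python
local_energies = [-1.0, -3.0, -0.5, -0.5, -0.5, -0.5]  # made-up per-atom values
batch = [0, 0, 1, 1, 1, 1]  # structure index of each atom
n_structures = 2

# sum per structure: P_k is the sum of p_i over the atoms i in structure k
energies = [0.0] * n_structures
sizes = [0] * n_structures
for value, k in zip(local_energies, batch):
    energies[k] += value
    sizes[k] += 1

# divide per atom: X_k / N_k for each structure k
energy_per_atom = [E / n for E, n in zip(energies, sizes)]

print(energies)         # [-4.0, -2.0]
print(energy_per_atom)  # [-2.0, -0.5]
```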

graph_pes.atomic_graph.trim_edges(graph, cutoff)[source]

Return a new graph with edges trimmed to be no longer than the cutoff. Leaves the original graph unchanged.

Parameters:
  • graph (AtomicGraph) – The graph to trim the edges of.

  • cutoff (float) – The maximum distance between atoms to keep the edge.

Return type:

AtomicGraph
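
For an isolated structure, trimming reduces to filtering the columns of the neighbour list by edge length. The sketch below is a plain-Python illustration using the water positions listed under R. The helper trim is hypothetical, and a periodic structure would additionally need the cell offsets when computing distances. Starting from the 6 edges found with a cutoff of 5.0, it recovers the 4 edges found with a cutoff of 1.2:

```python
import math

positions = [
    (0.0, 0.0, 0.1193),       # O
    (0.0, 0.7632, -0.4770),   # H
    (0.0, -0.7632, -0.4770),  # H
]
# all 6 directed edges between the 3 atoms
neighbour_list = [[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]


def trim(neighbour_list, positions, cutoff):
    """Hypothetical helper: keep the edges no longer than cutoff, in a new list."""
    keep = [
        k
        for k, (i, j) in enumerate(zip(*neighbour_list))
        if math.dist(positions[i], positions[j]) <= cutoff
    ]
    return [[row[k] for k in keep] for row in neighbour_list]


trimmed = trim(neighbour_list, positions, cutoff=1.2)
print(trimmed)  # [[0, 0, 1, 2], [1, 2, 0, 0]]
```

The H-H edges (about 1.53 long) are dropped, while the O-H edges (about 0.97 long) are kept. The input list is left unchanged.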

Three-body operations

graph_pes.utils.threebody.triplet_edge_pairs(graph, three_body_cutoff)[source]

Find all the pairs of edges, \(a = (i, j), b = (i, k)\), such that:

  • \(i, j, k \in \{0, 1, \dots, N-1\}\) are indices of distinct (images of) atoms within the graph

  • \(j \neq k\)

  • \(r_{ij} \leq\) three_body_cutoff

  • \(r_{ik} \leq\) three_body_cutoff

Returns:

edge_pairs – A (Y, 2) shaped tensor indicating the edges, such that

a, b = edge_pairs[y]
i, j = graph.neighbour_list[:,a]
i, k = graph.neighbour_list[:,b]

Return type:

torch.Tensor
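
The pairing can be sketched in plain Python by grouping edges by their central atom. The example below uses a fully connected 3-atom neighbour list and assumes every edge lies within the three-body cutoff. Returning both orderings (a, b) and (b, a) is also an assumption, chosen here because the water example for triplet_bond_descriptors lists six angles:

```python
from itertools import permutations

# all 6 directed edges of a 3-atom structure; column k is edge k
neighbour_list = [[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]

# group edge indices by their central atom i
edges_by_central = {}
for edge, i in enumerate(neighbour_list[0]):
    edges_by_central.setdefault(i, []).append(edge)

# every ordered pair of distinct edges that share a central atom
edge_pairs = [
    pair
    for edges in edges_by_central.values()
    for pair in permutations(edges, 2)
]
print(edge_pairs)  # [(0, 1), (1, 0), (2, 3), (3, 2), (4, 5), (5, 4)]
```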

graph_pes.utils.threebody.triplet_bond_descriptors(graph)[source]

For each triplet \((i, j, k)\), get the bond angle \(\theta_{jik}\) (in radians) and the two bond lengths \(r_{ij}\) and \(r_{ik}\).

Returns:

  • triplet_idxs – The triplet indices, \((i, j, k)\), of shape (Y, 3).

  • angle – The bond angle \(\theta_{jik}\), shape (Y,).

  • r_ij – The bond length \(r_{ij}\), shape (Y,).

  • r_ik – The bond length \(r_{ik}\), shape (Y,).

Return type:

tuple[Tensor, Tensor, Tensor, Tensor]

Examples

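>>> import torch
>>> from ase.build import molecule
>>> from graph_pes import AtomicGraph
>>> from graph_pes.utils.threebody import triplet_bond_descriptors
>>> graph = AtomicGraph.from_ase(molecule("H2O"))
>>> triplet_idxs, angle, r_ij, r_ik = triplet_bond_descriptors(graph)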
>>> torch.rad2deg(angle)
tensor([103.9999, 103.9999,  38.0001,  38.0001,  38.0001,  38.0001])
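
The angles in this output can be checked by hand from the water positions listed under R. The sketch below is plain Python with a hypothetical bond_angle helper. It computes the angle between the vectors from the central atom i to atoms j and k, and does not call graph_pes:

```python
import math

positions = [
    (0.0, 0.0, 0.1193),       # O
    (0.0, 0.7632, -0.4770),   # H
    (0.0, -0.7632, -0.4770),  # H
]


def bond_angle(i, j, k):
    """Hypothetical helper: angle theta_jik in degrees, centred on atom i."""
    v_ij = [b - a for a, b in zip(positions[i], positions[j])]
    v_ik = [b - a for a, b in zip(positions[i], positions[k])]
    dot = sum(a * b for a, b in zip(v_ij, v_ik))
    cos_theta = dot / (math.hypot(*v_ij) * math.hypot(*v_ik))
    return math.degrees(math.acos(cos_theta))


angle_at_O = bond_angle(0, 1, 2)  # H-O-H, approximately 104 degrees
angle_at_H = bond_angle(1, 0, 2)  # O-H-H, approximately 38 degrees
print(round(angle_at_O), round(angle_at_H))  # 104 38
```

The three interior angles of the O-H-H triangle sum to 180 degrees, which gives a quick sanity check on the values.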