Atomic Graphs
We describe atomic graphs using the AtomicGraph class. For convenient ways to create instances of such graphs from ase.Atoms objects, see from_ase().
Definition
- class graph_pes.AtomicGraph
Bases: NamedTuple
An AtomicGraph represents an atomic structure. Each node corresponds to an atom, and each directed edge links a central atom to a “bonded” neighbour.
We implement such graphs as (immutable) NamedTuples. This allows for easy serialisation and compatibility with PyTorch, TorchScript and other libraries. These objects, and all functions that operate on them, are compatible with both isolated and periodic structures.
Batches of multiple structures are represented by a single AtomicGraph object containing multiple disjoint subgraphs. See to_batch() for more information.
Properties
Below we assume the graph contains N atoms, E edges (and S structures if the graph is batched). We use two examples to illustrate the various properties:
Water molecule
>>> from ase.build import molecule
>>> from load_atoms import view
>>> water = molecule("H2O")
>>> view(water, show_bonds=True)
We can see that there are 3 atoms, with 2 bonds (and therefore 4 directed edges):
>>> from graph_pes import AtomicGraph
>>> water_graph = AtomicGraph.from_ase(water, cutoff=1.2)
>>> water_graph
AtomicGraph(atoms=3, edges=4, has_cell=False, cutoff=1.2)
Sodium crystal
>>> from ase.build import bulk
>>> from load_atoms import view
>>> sodium = bulk("Na")
>>> view(sodium.repeat(3), show_bonds=True)
This structure has a single atom within a periodic cell. If you look closely, you can see that this atom has 8 nearest neighbours. Only “source” atoms within the unit cell are included in the neighbour list, and hence there are 8 edges:
>>> sodium_graph = AtomicGraph.from_ase(sodium, cutoff=3.7)
>>> sodium_graph
AtomicGraph(atoms=1, edges=8, has_cell=True, cutoff=3.7)
- Z: Tensor
The atomic numbers of the atoms in the graph, of shape (N,).
>>> water_graph.Z
tensor([8, 1, 1])
>>> sodium_graph.Z
tensor([11])
- R: Tensor
The cartesian positions of the atoms in the graph, of shape (N, 3).
>>> water_graph.R
tensor([[ 0.0000, 0.0000, 0.1193],
        [ 0.0000, 0.7632, -0.4770],
        [ 0.0000, -0.7632, -0.4770]])
>>> sodium_graph.R
tensor([[0., 0., 0.]])
- cell: Tensor
The unit cell vectors, of shape (3, 3) for a single structure, or (S, 3, 3) for a batched graph. If the structure is non-periodic, this will be all zeros.
>>> water_graph.cell
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
>>> sodium_graph.cell
tensor([[-2.1150, 2.1150, 2.1150],
        [ 2.1150, -2.1150, 2.1150],
        [ 2.1150, 2.1150, -2.1150]])
- neighbour_list: Tensor
A neighbour list, of shape (2, E), where i, j = graph.neighbour_list[:, k] is the k’th directed edge in the graph, linking atom i to atom j.
>>> water_graph.neighbour_list
tensor([[0, 0, 1, 2],
        [1, 2, 0, 0]])
>>> sodium_graph.neighbour_list
tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]])
- neighbour_cell_offsets: Tensor
The periodic-image offset of the neighbour atom of each edge, in units of the unit cell vectors, of shape (E, 3), such that:
# k-th edge
i, j = graph.neighbour_list[:, k]
kth_displacement_vector = (
    graph.R[j]
    + graph.neighbour_cell_offsets[k] @ graph.cell
    - graph.R[i]
)
(For a batched graph, graph.cell here stands for the cell of the structure that edge k belongs to.) In the case of an isolated, non-periodic structure, these offsets will be all zeros.
>>> water_graph.neighbour_cell_offsets
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
>>> sodium_graph.neighbour_cell_offsets
tensor([[ 0., 0., 1.],
        [ 0., 1., 0.],
        [ 1., 1., 1.],
        [ 1., 0., 0.],
        [-1., 0., 0.],
        [-1., -1., -1.],
        [ 0., -1., 0.],
        [ 0., 0., -1.]])
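The displacement formula above can be checked numerically. The following minimal sketch uses plain PyTorch tensors rather than graph_pes itself, and `edge_vectors` is a hypothetical helper name. It rebuilds the eight edge vectors of the sodium example from the values printed above. Under that formula every neighbour should sit at the same distance, \(\sqrt{3} \times 2.115\) Å ≈ 3.663 Å, just inside the 3.7 Å cutoff:

```python
import torch


def edge_vectors(R, cell, neighbour_list, neighbour_cell_offsets):
    """Hypothetical helper: displacement vector for every directed edge."""
    i, j = neighbour_list  # row 0: central atoms, row 1: neighbours
    # shift each neighbour by an integer combination of the cell vectors
    return R[j] + neighbour_cell_offsets @ cell - R[i]


# values copied from the sodium example above
R = torch.tensor([[0.0, 0.0, 0.0]])
cell = torch.tensor(
    [[-2.1150, 2.1150, 2.1150],
     [2.1150, -2.1150, 2.1150],
     [2.1150, 2.1150, -2.1150]]
)
neighbour_list = torch.zeros(2, 8, dtype=torch.long)  # atom 0 -> images of atom 0
neighbour_cell_offsets = torch.tensor(
    [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 0.0, 0.0],
     [-1.0, 0.0, 0.0], [-1.0, -1.0, -1.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]
)

vectors = edge_vectors(R, cell, neighbour_list, neighbour_cell_offsets)
distances = torch.linalg.norm(vectors, dim=-1)
print(distances)  # all eight distances are about 3.663
```

This is the quantity that neighbour_vectors() describes, although the library's own implementation may differ, for example in how batched cells are handled.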
- properties: Dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], Tensor]
A dictionary containing potential energy surface (PES) related properties of the graph.
Key               Shape                            Description
"local_energies"  (N,)                             contribution to the total energy from each atom
"energy"          () ((S,) if batched)             total energy of the structure
"forces"          (N, 3)                           force on each atom
"stress"          (3, 3) ((S, 3, 3) if batched)    stress tensor (see Theory)
"virial"          (3, 3) ((S, 3, 3) if batched)    virial stress tensor (see Theory)
- other: Dict[str, Tensor]
A dictionary containing any additional information about the graph. Feel free to populate this as you wish.
- batch: Tensor | None
For a batched graph, a tensor of shape (N,) giving the index of the structure within the batch that each atom belongs to. None for a single structure.
- ptr: Tensor | None
For a batched graph, a tensor of shape (S + 1,) whose first S entries give the index of the first atom of each structure within the batch, and whose last entry is the total number of atoms, N. None for a single structure.
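The relationship between batch and ptr can be illustrated with a short sketch in plain PyTorch. `batch_and_ptr` is a hypothetical helper, not part of graph_pes; it derives both tensors from per-structure atom counts (here the 3-atom water and 5-atom methane of the to_batch() example further down) and then checks the reverse direction:

```python
import torch


def batch_and_ptr(sizes: torch.Tensor):
    """Hypothetical helper: build `batch` and `ptr` from per-structure atom counts."""
    structure_index = torch.arange(len(sizes))
    # repeat structure index k once for every atom in structure k
    batch = torch.repeat_interleave(structure_index, sizes)
    # ptr[k] is the index of the first atom of structure k; ptr[-1] is N
    ptr = torch.cat([torch.zeros(1, dtype=sizes.dtype), torch.cumsum(sizes, dim=0)])
    return batch, ptr


sizes = torch.tensor([3, 5])  # e.g. H2O and CH4
batch, ptr = batch_and_ptr(sizes)
print(batch.tolist())  # [0, 0, 0, 1, 1, 1, 1, 1]
print(ptr.tolist())  # [0, 3, 8]

# going the other way: per-structure sizes from ptr, or from batch
assert torch.equal(ptr[1:] - ptr[:-1], sizes)
assert torch.equal(torch.bincount(batch), sizes)
```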
- classmethod from_ase(structure, cutoff=5.0, property_mapping=None, others_to_include=None)
Convert an ase.Atoms object to an AtomicGraph.
- Parameters:
structure (ase.Atoms) – The structure to convert.
cutoff (float) – The cutoff distance for neighbour finding.
property_mapping (Mapping[str, Literal['local_energies', 'forces', 'energy', 'stress', 'virial']] | None) – An optional mapping of the form {key_on_structure: key_for_graph} defining how relevant properties are labelled on the ase.Atoms object. If not provided, this function will extract any of "energy", "forces", "stress" and "virial" that are present in the .info and .arrays dicts.
others_to_include (Sequence[str] | None) – An optional list of other .info/.arrays keys to include in the graph’s other dict. The corresponding values will be converted to torch.Tensors.
- Return type: AtomicGraph
Example
>>> from ase.build import molecule
>>> from graph_pes import AtomicGraph
>>> # create a structure with some extra info
>>> atoms = molecule("H2O")
>>> atoms.info["DFT_energy"] = -10.0
>>> atoms.info["unique_id"] = 1234
>>> # default behaviour:
>>> AtomicGraph.from_ase(atoms)
AtomicGraph(atoms=3, edges=6, has_cell=False, cutoff=5.0)
>>> # specify how to map properties, and other things to include
>>> AtomicGraph.from_ase(
...     atoms,
...     property_mapping={
...         "DFT_energy": "energy",
...     },
...     others_to_include=["unique_id"],
... )
AtomicGraph(
    atoms=3,
    edges=6,
    has_cell=False,
    cutoff=5.0,
    properties=['energy'],
    other=['unique_id']
)
- classmethod create_with_defaults(Z, R, cell=None, neighbour_list=None, neighbour_cell_offsets=None, properties=None, other=None, cutoff=0.0)
Create an AtomicGraph, populating missing values with defaults.
- Parameters:
Z (Tensor) – The atomic numbers.
R (Tensor) – The cartesian positions.
cell (Tensor | None) – The unit cell. Defaults to torch.zeros(3, 3).
neighbour_list (Tensor | None) – The neighbour list. Defaults to torch.zeros(2, 0).
neighbour_cell_offsets (Tensor | None) – The neighbour cell offsets. Defaults to torch.zeros(0, 3).
properties (Dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], Tensor] | None) – The properties. Defaults to {}.
other (Dict[str, Tensor] | None) – The other information. Defaults to {}.
cutoff (float) – The cutoff associated with the graph. Defaults to 0.0.
- Return type: AtomicGraph
- graph_pes.atomic_graph.replace(graph, Z=None, R=None, cell=None, neighbour_list=None, neighbour_cell_offsets=None, properties=None, other=None, cutoff=None)
A TorchScript-compatible convenience function for replacing values of an AtomicGraph, for use where the built-in NamedTuple ._replace method cannot be compiled. Fields left as None keep their values from graph.
- Return type: AtomicGraph
Batching
A batch of AtomicGraph instances is itself represented by a single AtomicGraph instance, containing multiple disjoint subgraphs. AtomicGraph batches are created using to_batch():
- graph_pes.atomic_graph.to_batch(graphs)
Collate a sequence of atomic graphs into a single batch object.
The Z, R and neighbour_cell_offsets properties are concatenated along the first axis, and neighbour_list is concatenated along its edge axis, with each graph’s atom indices shifted by the number of atoms in the graphs that precede it. The cell property is stacked along a new batch dimension.
Values in the "other" dictionary are concatenated along the first axis if they appear to be a per-atom property (i.e. their first dimension matches the number of atoms in the structure). Otherwise, they are stacked along a new batch dimension.
- Parameters:
graphs (Sequence[AtomicGraph]) – The graphs to collate.
- Return type: AtomicGraph
Examples
A basic example:
>>> from ase.build import molecule
>>> from graph_pes import AtomicGraph, to_batch
>>> graphs = [
...     AtomicGraph.from_ase(molecule("H2O")),
...     AtomicGraph.from_ase(molecule("CH4")),
... ]
>>> batch = to_batch(graphs)
>>> batch.batch  # H2O has 3 atoms, CH4 has 5
tensor([0, 0, 0, 1, 1, 1, 1, 1])
>>> batch.ptr  # offset of first atom of each graph
tensor([0, 3, 8])
>>> batch.Z.shape
torch.Size([8])
>>> batch.R.shape
torch.Size([8, 3])
>>> batch.cell.shape
torch.Size([2, 3, 3])
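The collation rules described above can be sketched in a few lines of plain PyTorch. This is not graph_pes's implementation: `MiniGraph` and `collate` are hypothetical names, only the structural fields are handled, and the properties and other dictionaries are ignored. The step that turns the inputs into disjoint subgraphs is shifting each structure's neighbour_list by the index of its first atom:

```python
from typing import List, NamedTuple

import torch


class MiniGraph(NamedTuple):
    """Hypothetical stand-in for AtomicGraph, holding only structural fields."""
    Z: torch.Tensor  # (N,)
    R: torch.Tensor  # (N, 3)
    cell: torch.Tensor  # (3, 3), or (S, 3, 3) once batched
    neighbour_list: torch.Tensor  # (2, E)
    neighbour_cell_offsets: torch.Tensor  # (E, 3)


def collate(graphs: List[MiniGraph]) -> MiniGraph:
    sizes = torch.tensor([len(g.Z) for g in graphs])
    # index of the first atom of each structure within the batch
    first_atom = torch.cumsum(sizes, dim=0) - sizes
    return MiniGraph(
        Z=torch.cat([g.Z for g in graphs]),
        R=torch.cat([g.R for g in graphs]),
        # cells gain a new leading batch dimension
        cell=torch.stack([g.cell for g in graphs]),
        # edges are joined along the edge axis, with atom indices shifted
        neighbour_list=torch.cat(
            [g.neighbour_list + int(o) for g, o in zip(graphs, first_atom)], dim=1
        ),
        neighbour_cell_offsets=torch.cat([g.neighbour_cell_offsets for g in graphs]),
    )


# two toy structures with 2 and 3 atoms
a = MiniGraph(torch.tensor([1, 1]), torch.zeros(2, 3), torch.zeros(3, 3),
              torch.tensor([[0, 1], [1, 0]]), torch.zeros(2, 3))
b = MiniGraph(torch.tensor([8, 1, 1]), torch.zeros(3, 3), torch.zeros(3, 3),
              torch.tensor([[0, 0, 1, 2], [1, 2, 0, 0]]), torch.zeros(4, 3))
batch = collate([a, b])
print(batch.neighbour_list.tolist())  # [[0, 1, 2, 2, 3, 4], [1, 0, 3, 4, 2, 2]]
```

Without the shift, the edges of the second structure would point at atoms 0 to 2 of the batch, which belong to the first structure.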
- graph_pes.atomic_graph.is_batch(graph)
Does graph represent a batch of atomic graphs?
- Parameters:
graph (AtomicGraph) – The graph to check.
- Return type: bool
If you need to define custom batching logic for a field in the other property, you can use register_custom_batcher():
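- graph_pes.atomic_graph.register_custom_batcher(key)
Register a custom batcher for a property in the other field. As in the example below, a batcher takes two arguments, batch and values (the tensors stored under key on each of the graphs being collated), and returns the single tensor to store under key on the batch.
- Parameters:
key (str) – The key of the property to register a custom batcher for.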
Examples
>>> import torch
>>> from graph_pes.atomic_graph import register_custom_batcher, to_batch
>>> @register_custom_batcher("foo")
... def foo_batcher(batch, values):
...     return torch.max(torch.vstack(values), dim=0).values
>>> ... # create graphs
>>> graphs[0].other["foo"], graphs[1].other["foo"]
(tensor([1]), tensor([2]))
>>> ... # batch the graphs
>>> batch = to_batch(graphs)
>>> batch.other["foo"]
tensor([2])
Derived Properties
We define a number of derived properties of atomic graphs. These work for both isolated and batched AtomicGraph instances.
- graph_pes.atomic_graph.number_of_atoms(graph)
Get the number of atoms in the graph.
- Return type: int
- graph_pes.atomic_graph.number_of_edges(graph)
Get the number of edges in the graph.
- Return type: int
- graph_pes.atomic_graph.has_cell(graph)
Does graph represent a structure with a defined unit cell?
- Return type: bool
- graph_pes.atomic_graph.neighbour_vectors(graph)
Get the vector between each pair of atoms specified in the graph’s neighbour_list property, respecting periodic boundary conditions where present. The result has shape (E, 3).
- Return type: Tensor
- graph_pes.atomic_graph.neighbour_distances(graph)
Get the distance between each pair of atoms specified in the graph’s neighbour_list property, respecting periodic boundary conditions where present. The result has shape (E,).
- Return type: Tensor
- graph_pes.atomic_graph.number_of_neighbours(graph, include_central_atom=True)
Get a tensor, T, of shape (N,), where N is the number of atoms in the graph, such that T[i] gives the number of neighbours of atom i. If include_central_atom is True, then the central atom is included in the count.
- Parameters:
graph (AtomicGraph) – The graph to get the number of neighbours for.
include_central_atom (bool) – Whether to include the central atom in the count.
- Return type: Tensor
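As a minimal sketch of what this count means (plain PyTorch; `count_neighbours` is a hypothetical name, not the library's implementation): each directed edge (i, j) contributes one neighbour to its central atom i = neighbour_list[0], so a bincount over that row gives the per-atom count. Using the water neighbour list from above (cutoff 1.2 Å):

```python
import torch


def count_neighbours(neighbour_list: torch.Tensor, n_atoms: int,
                     include_central_atom: bool = True) -> torch.Tensor:
    # one count per edge, attributed to the central atom i of edge (i, j);
    # minlength keeps atoms without any edges in the output, as zeros
    counts = torch.bincount(neighbour_list[0], minlength=n_atoms)
    return counts + 1 if include_central_atom else counts


water_neighbour_list = torch.tensor([[0, 0, 1, 2], [1, 2, 0, 0]])
print(count_neighbours(water_neighbour_list, 3).tolist())  # [3, 2, 2]
print(count_neighbours(water_neighbour_list, 3, include_central_atom=False).tolist())  # [2, 1, 1]
```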
- graph_pes.atomic_graph.available_properties(graph)
Get the property labels that are available on the graph.
- graph_pes.atomic_graph.number_of_structures(graph)
Get the number of structures in the graph.
- Return type: int
- graph_pes.atomic_graph.structure_sizes(batch)
Get the number of atoms in each structure in the batch, as a tensor of shape (S,), where S is the number of structures.
- Parameters:
batch (AtomicGraph) – The batch to get the structure sizes for.
- Return type: Tensor
Examples
>>> len(graphs)
3
>>> [number_of_atoms(g) for g in graphs]
[3, 4, 5]
>>> structure_sizes(to_batch(graphs))
tensor([3, 4, 5])
Graph Operations
We define a number of operations that act on torch.Tensor instances conditioned on the graph structure. All of these are fully compatible with batched AtomicGraph instances, and with TorchScript compilation.
- graph_pes.atomic_graph.is_local_property(x, graph)
Is the property x local to each atom in the graph?
- Parameters:
x (Tensor) – The property to check.
graph (AtomicGraph) – The graph to check the property for.
- Return type: bool
- graph_pes.atomic_graph.index_over_neighbours(x, graph)
Index a per-atom property, \(x\), over the neighbours of each atom in the graph, giving a per-edge property of shape (E, ...).
- Return type: Tensor
- graph_pes.atomic_graph.sum_over_central_atom_index(p, central_atom_index, graph)
Efficient, shape-preserving sum of a property, \(p\), defined over a central_atom_index, to get a per-atom property, \(P\), such that:
# i in central_atom_index
P[i] == torch.sum(p[central_atom_index == i], dim=0)
# i not in central_atom_index
P[i] == torch.zeros_like(p[0])
See also sum_over_neighbours() for the explicit case where p is a per-edge property.
- Parameters:
p (Tensor) – The property to sum, of shape (Y, ...).
central_atom_index (Tensor) – The central atoms relevant to each element of p, of shape (Y,).
graph (AtomicGraph) – The graph to sum the property for.
- Returns:
P – The summed property, of shape (N, ...).
- Return type: Tensor
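The two relations above can be satisfied with a single scatter-add. In this plain-PyTorch sketch (`scatter_sum` is a hypothetical name), a zero tensor of shape (N, ...) is allocated and Tensor.index_add_ accumulates each p[y] into row central_atom_index[y], so atoms that never appear in the index keep a zero row:

```python
import torch


def scatter_sum(p: torch.Tensor, central_atom_index: torch.Tensor, n_atoms: int) -> torch.Tensor:
    # output has shape (N, ...): one row per atom, trailing dimensions preserved
    P = torch.zeros((n_atoms,) + tuple(p.shape[1:]), dtype=p.dtype, device=p.device)
    # P[central_atom_index[y]] += p[y] for every y; repeated indices accumulate
    return P.index_add_(0, central_atom_index, p)


p = torch.tensor([[1.0, 1.0], [2.0, 2.0], [4.0, 4.0]])  # Y = 3, one trailing dim of size 2
central_atom_index = torch.tensor([0, 0, 2])
P = scatter_sum(p, central_atom_index, n_atoms=4)
print(P.tolist())  # [[3.0, 3.0], [0.0, 0.0], [4.0, 4.0], [0.0, 0.0]]
```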
- graph_pes.atomic_graph.sum_over_neighbours(p, graph)
Shape-preserving sum over neighbours of a per-edge property, \(p_{ij}\), to get a per-atom property, \(P_i\):
\[P_i = \sum_{j \in \mathcal{N}_i} p_{ij}\]
where:
\(\mathcal{N}_i\) is the set of neighbours of atom \(i\).
\(p_{ij}\) is the property of the edge between atoms \(i\) and \(j\).
\(p\) is of shape (E, ...) and \(P\) is of shape (N, ...), where \(E\) is the number of edges and \(N\) is the number of atoms. Here ... denotes any number of additional dimensions, including none.
\(P_i = 0\) if \(|\mathcal{N}_i| = 0\).
- Parameters:
p (Tensor) – The per-edge property to sum.
graph (AtomicGraph) – The graph to sum the property for.
- Return type: Tensor
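Per-atom and per-edge tensors are often combined in a gather-then-scatter pattern. The following plain-PyTorch sketch assumes that indexing a per-atom property over neighbours corresponds to gathering at neighbour_list[1] (the \(j\) of each edge), and that the sum over \(\mathcal{N}_i\) corresponds to scattering onto neighbour_list[0]. It sums the atomic numbers of each atom's neighbours in the water graph from above:

```python
import torch

Z = torch.tensor([8, 1, 1])  # water: O, H, H
neighbour_list = torch.tensor([[0, 0, 1, 2], [1, 2, 0, 0]])  # cutoff 1.2 graph
i, j = neighbour_list

# gather: per-atom property -> per-edge property, p_ij = Z[j], shape (E,)
per_edge = Z[j]

# scatter: per-edge property -> per-atom property, P_i = sum over j in N_i of p_ij
per_atom = torch.zeros_like(Z).index_add_(0, i, per_edge)
print(per_atom.tolist())  # [2, 8, 8]: O sees two H atoms, each H sees one O
```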
- graph_pes.atomic_graph.sum_per_structure(x, graph)
Shape-preserving sum of a per-atom property, \(p\) (passed as x), to get a per-structure property, \(P\):
If a single structure, containing N atoms, is used, then \(P = \sum_i p_i\), where:
\(p\) is of shape (N, ...)
\(P\) is of shape (...)
... denotes any number of additional dimensions, including none.
If a batch of S structures, containing a total of N atoms, is used, then \(P_k = \sum_{i \in K_k} p_i\), where:
\(K_k\) is the set of atoms in structure \(k\)
\(p\) is of shape (N, ...)
\(P\) is of shape (S, ...)
... denotes any number of additional dimensions, including none.
- Parameters:
x (Tensor) – The per-atom property to sum.
graph (AtomicGraph) – The graph to sum the property for.
- Return type: Tensor
Examples
Single graph case:
>>> import torch
>>> from ase.build import molecule
>>> from graph_pes.atomic_graph import sum_per_structure, AtomicGraph
>>> water = molecule("H2O")
>>> graph = AtomicGraph.from_ase(water, cutoff=1.5)
>>> # summing over a vector gives a scalar
>>> sum_per_structure(torch.ones(3), graph)
tensor(3.)
>>> # summing over higher order tensors gives a tensor
>>> sum_per_structure(torch.ones(3, 2, 3), graph).shape
torch.Size([2, 3])
Batch case:
>>> import torch
>>> from ase.build import molecule
>>> from graph_pes.atomic_graph import sum_per_structure, AtomicGraph, to_batch
>>> water = molecule("H2O")
>>> graph = AtomicGraph.from_ase(water, cutoff=1.5)
>>> batch = to_batch([graph, graph])
>>> batch
AtomicGraphBatch(structures: 2, atoms: 6, edges: 8, has_cell: False)
>>> # summing over a vector gives a tensor
>>> sum_per_structure(torch.ones(6), batch)
tensor([3., 3.])
>>> # summing over higher order tensors gives a tensor
>>> sum_per_structure(torch.ones(6, 3, 4), batch).shape
torch.Size([2, 3, 4])
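For a batch, the per-structure sum can be written as a scatter-add over the batch index. A minimal plain-PyTorch sketch (`per_structure_sum` is a hypothetical name; passing batch=None stands in for the single-structure case) that reproduces the shapes from the examples above:

```python
from typing import Optional

import torch


def per_structure_sum(x: torch.Tensor, batch: Optional[torch.Tensor] = None) -> torch.Tensor:
    if batch is None:
        # single structure: reduce over the atom axis, giving shape (...)
        return x.sum(dim=0)
    # assumes the last structure has at least one atom
    n_structures = int(batch.max()) + 1
    out = torch.zeros((n_structures,) + tuple(x.shape[1:]), dtype=x.dtype)
    # out[batch[i]] += x[i] for every atom i, giving shape (S, ...)
    return out.index_add_(0, batch, x)


batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two 3-atom structures
print(per_structure_sum(torch.ones(3)))  # tensor(3.)
print(per_structure_sum(torch.ones(6), batch))  # tensor([3., 3.])
print(per_structure_sum(torch.ones(6, 3, 4), batch).shape)  # torch.Size([2, 3, 4])
```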
- graph_pes.atomic_graph.divide_per_atom(x, graph)
Divide a per-structure property, \(X\) (passed as x), by the number of atoms in each structure to get a per-atom property, \(x\):
\[x_i = \frac{X_k}{N_k}\]
where:
\(k\) is the structure that atom \(i\) belongs to, and \(N_k\) is the number of atoms in that structure
\(X\) is of shape (S, ...)
\(x\) is of shape (N, ...)
\(S\) is the number of structures
\(N\) is the total number of atoms
- Return type: Tensor
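In code, \(N_k\) can be looked up per atom through the batch index and broadcast over any trailing dimensions. A minimal plain-PyTorch sketch (`spread_per_atom` is a hypothetical helper) for two structures of 3 and 5 atoms, which also checks that summing x back over each structure recovers X:

```python
import torch


def spread_per_atom(X: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x_i = X_k / N_k, where atom i is in structure k."""
    sizes = torch.bincount(batch)  # N_k for each structure k
    n = sizes[batch].to(X.dtype)  # N_k looked up for every atom, shape (N,)
    # reshape to (N, 1, 1, ...) so it broadcasts over any trailing dimensions
    n = n.reshape((-1,) + (1,) * (X.dim() - 1))
    return X[batch] / n


batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])  # structures of 3 and 5 atoms
X = torch.tensor([6.0, 5.0])  # shape (S,)
x = spread_per_atom(X, batch)
print(x.tolist())  # [2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0]

# summing x back over each structure recovers X
recovered = torch.zeros_like(X).index_add_(0, batch, x)
print(recovered.tolist())  # [6.0, 5.0]
```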
- graph_pes.atomic_graph.trim_edges(graph, cutoff)
Return a new graph with edges trimmed to be no longer than the cutoff. Leaves the original graph unchanged.
- Parameters:
graph (AtomicGraph) – The graph to trim the edges of.
cutoff (float) – The maximum distance between atoms for an edge to be kept.
- Return type: AtomicGraph
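Trimming amounts to a boolean mask over edges. The following plain-PyTorch sketch (`trim_edges_sketch` is a hypothetical name; edges exactly at the cutoff are kept, matching "no longer than") starts from the full 6-edge water graph and trims it to 1.2 Å, recovering the 4-edge water neighbour list shown earlier:

```python
import torch


def trim_edges_sketch(R, cell, neighbour_list, neighbour_cell_offsets, cutoff):
    """Hypothetical helper: keep only edges of length <= cutoff."""
    i, j = neighbour_list
    vectors = R[j] + neighbour_cell_offsets @ cell - R[i]
    keep = torch.linalg.norm(vectors, dim=-1) <= cutoff
    # boolean indexing returns new tensors, so the inputs are left unchanged
    return neighbour_list[:, keep], neighbour_cell_offsets[keep]


# water positions from the R example above, with all 6 directed edges
R = torch.tensor([[0.0, 0.0, 0.1193], [0.0, 0.7632, -0.4770], [0.0, -0.7632, -0.4770]])
cell = torch.zeros(3, 3)  # isolated molecule
neighbour_list = torch.tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]])
neighbour_cell_offsets = torch.zeros(6, 3)

trimmed_list, trimmed_offsets = trim_edges_sketch(
    R, cell, neighbour_list, neighbour_cell_offsets, cutoff=1.2
)
print(trimmed_list.tolist())  # [[0, 0, 1, 2], [1, 2, 0, 0]]: H-H edges (~1.53 Å) dropped
print(neighbour_list.shape)  # torch.Size([2, 6]): the original is untouched
```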
Three-body operations
- graph_pes.utils.threebody.triplet_edge_pairs(graph, three_body_cutoff)
Find all the pairs of edges, \(a = (i, j), b = (i, k)\), such that:
\(i, j, k \in \{0, 1, \dots, N-1\}\) are indices of distinct (images of) atoms within the graph
\(j \neq k\)
\(r_{ij} \leq\) three_body_cutoff
\(r_{ik} \leq\) three_body_cutoff
- Returns:
edge_pairs – A (Y, 2) shaped tensor, where Y is the number of such pairs, indicating the edges, such that:
a, b = edge_pairs[y]
i, j = graph.neighbour_list[:, a]
i, k = graph.neighbour_list[:, b]
- Return type: Tensor
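A brute-force way to enumerate these pairs is to compare every ordered pair of edges. This plain-PyTorch sketch (`triplet_edge_pairs_sketch` is a hypothetical name) uses O(E²) memory and is meant only for illustration. It requires the two edges to differ (a ≠ b), which for an isolated structure such as this water molecule is equivalent to j ≠ k:

```python
import torch


def triplet_edge_pairs_sketch(neighbour_list, distances, three_body_cutoff):
    """Hypothetical brute-force helper: O(E^2) memory, for small graphs only."""
    E = neighbour_list.shape[1]
    central = neighbour_list[0]
    short = distances <= three_body_cutoff
    # every ordered pair of edge indices (a, b)
    a, b = torch.meshgrid(torch.arange(E), torch.arange(E), indexing="ij")
    a, b = a.flatten(), b.flatten()
    # same central atom i, two different edges, both edges within the cutoff
    mask = (central[a] == central[b]) & (a != b) & short[a] & short[b]
    return torch.stack([a[mask], b[mask]], dim=1)


R = torch.tensor([[0.0, 0.0, 0.1193], [0.0, 0.7632, -0.4770], [0.0, -0.7632, -0.4770]])
neighbour_list = torch.tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]])
i, j = neighbour_list
distances = torch.linalg.norm(R[j] - R[i], dim=-1)  # isolated molecule: no offsets

print(triplet_edge_pairs_sketch(neighbour_list, distances, 5.0).tolist())
# [[0, 1], [1, 0], [2, 3], [3, 2], [4, 5], [5, 4]]
print(triplet_edge_pairs_sketch(neighbour_list, distances, 1.2).tolist())
# [[0, 1], [1, 0]]: only the pair of O-H edges centred on oxygen survives
```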
- graph_pes.utils.threebody.triplet_bond_descriptors(graph)
For each triplet \((i, j, k)\), get the bond angle \(\theta_{jik}\) (in radians) and the two bond lengths \(r_{ij}\) and \(r_{ik}\).
- Returns:
triplet_idxs – The triplet indices, \((i, j, k)\), of shape (Y, 3).
angle – The bond angle \(\theta_{jik}\), of shape (Y,).
r_ij – The bond length \(r_{ij}\), of shape (Y,).
r_ik – The bond length \(r_{ik}\), of shape (Y,).
Here Y is the number of triplets.
- Return type: tuple[Tensor, Tensor, Tensor, Tensor]
Examples
>>> import torch
>>> from ase.build import molecule
>>> from graph_pes import AtomicGraph
>>> from graph_pes.utils.threebody import triplet_bond_descriptors
>>> graph = AtomicGraph.from_ase(molecule("H2O"))
>>> triplet_idxs, angle, r_ij, r_ik = triplet_bond_descriptors(graph)
>>> torch.rad2deg(angle)
tensor([103.9999, 103.9999, 38.0001, 38.0001, 38.0001, 38.0001])
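The angle follows from the dot product of the two bond vectors, \(\cos\theta_{jik} = \mathbf{r}_{ij} \cdot \mathbf{r}_{ik} / (r_{ij} r_{ik})\), with \(\mathbf{r}_{ij} = \mathbf{R}_j - \mathbf{R}_i\). The sketch below is plain PyTorch: `bond_angles_sketch` is a hypothetical name, it ignores cell offsets (so it only applies to isolated structures), and it enumerates edge pairs with a simple Python loop. For the full 6-edge water graph it reproduces the angles in the example above:

```python
import torch


def bond_angles_sketch(R, neighbour_list):
    """Hypothetical helper for isolated structures (cell offsets ignored)."""
    E = neighbour_list.shape[1]
    # all ordered pairs of distinct edges that share a central atom
    pairs = [(a, b) for a in range(E) for b in range(E)
             if a != b and neighbour_list[0, a] == neighbour_list[0, b]]
    a = torch.tensor([p[0] for p in pairs], dtype=torch.long)
    b = torch.tensor([p[1] for p in pairs], dtype=torch.long)
    i, j = neighbour_list[:, a]
    k = neighbour_list[1, b]
    v_ij, v_ik = R[j] - R[i], R[k] - R[i]
    r_ij = torch.linalg.norm(v_ij, dim=-1)
    r_ik = torch.linalg.norm(v_ik, dim=-1)
    # clamp guards acos against rounding pushing the cosine outside [-1, 1]
    cos_theta = ((v_ij * v_ik).sum(dim=-1) / (r_ij * r_ik)).clamp(-1.0, 1.0)
    return torch.stack([i, j, k], dim=1), torch.acos(cos_theta), r_ij, r_ik


R = torch.tensor([[0.0, 0.0, 0.1193], [0.0, 0.7632, -0.4770], [0.0, -0.7632, -0.4770]])
neighbour_list = torch.tensor([[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]])
triplet_idxs, angle, r_ij, r_ik = bond_angles_sketch(R, neighbour_list)
print(torch.rad2deg(angle))  # roughly 104 degrees at O, 38 degrees at each H
```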