# Source code for graph_pes.atomic_graph

import warnings
from typing import (
    TYPE_CHECKING,
    Dict,
    Final,
    List,
    Literal,
    Mapping,
    NamedTuple,
    Optional,
    Protocol,
    Sequence,
    Union,
    cast,
)

import ase
import numpy as np
import torch
import torch.multiprocessing
import torch.utils.data
import vesin
from ase.stress import voigt_6_to_full_3x3_stress
from load_atoms.utils import remove_calculator
from typing_extensions import TypeAlias

from graph_pes.utils.misc import (
    all_equal,
    is_being_documented,
    left_aligned_div,
    left_aligned_mul,
    to_significant_figures,
    uniform_repr,
)

# Default cutoff radius for neighbour finding; used as the default ``cutoff``
# argument of ``AtomicGraph.from_ase``.
DEFAULT_CUTOFF: Final[float] = 5.0


# The keys that may appear in ``AtomicGraph.properties`` (static type).
PropertyKey: TypeAlias = Literal[
    "local_energies", "forces", "energy", "stress", "virial"
]
# Every valid property key, available as a runtime list.
ALL_PROPERTY_KEYS: Final[List[PropertyKey]] = [
    "local_energies",
    "forces",
    "energy",
    "stress",
    "virial",
]

# NB: this must come *after* the definitions above, which are annotated
# with the Literal version of ``PropertyKey``. Only the runtime (non type-
# checking, non documentation) path re-binds the aliases below.
if not TYPE_CHECKING and not is_being_documented():
    # torchscript doesn't handle TypedDicts or Literal types:
    # at run-time, we just use less specific, but still correct, types
    Properties: TypeAlias = Dict[str, torch.Tensor]
    PropertyKey: TypeAlias = str


class AtomicGraph(NamedTuple):
    r"""
    An :class:`AtomicGraph` represents an atomic structure.
    Each node corresponds to an atom, and each directed edge links a central
    atom to a "bonded" neighbour.

    We implement such graphs as (immutable) :class:`~typing.NamedTuple`\ s.
    This allows for easy serialisation and compatibility with PyTorch,
    TorchScript and other libraries. These objects, and all functions that
    operate on them, are compatible with both isolated and periodic
    structures. Batches of multiple structures are represented by a single
    :class:`AtomicGraph` object containing multiple disjoint subgraphs. See
    :func:`~graph_pes.atomic_graph.to_batch` for more information.

    Properties
    ++++++++++

    Below we assume the graph contains ``N`` atoms, ``E`` edges (and ``S``
    structures if the graph is batched). We use two examples to illustrate
    the various properties:

    .. dropdown:: Water molecule

        .. code-block:: python

            >>> from ase.build import molecule
            >>> from load_atoms import view
            >>> water = molecule("H2O")
            >>> view(water, show_bonds=True)

        .. raw:: html
            :file: ../../docs/source/_static/water.html

        We can see that there are 3 atoms, with 2 bonds (and therefore 4
        directed edges):

        .. code-block:: python

            >>> from graph_pes import AtomicGraph
            >>> water_graph = AtomicGraph.from_ase(water, cutoff=1.2)
            >>> water_graph
            AtomicGraph(atoms=3, edges=4, has_cell=False, cutoff=1.2)

    .. dropdown:: Sodium crystal

        .. code-block:: python

            >>> from ase.build import bulk
            >>> from load_atoms import view
            >>> sodium = bulk("Na")
            >>> view(sodium.repeat(3), show_bonds=True)

        .. raw:: html
            :file: ../../docs/source/_static/Na.html

        This structure has a single atom within a periodic cell. If you look
        closely, you can see that this atom has 8 nearest neighbours. Only
        "source" atoms within the unit cell are included in the neighbour
        list, and hence there are 8 edges:

        .. code-block:: python

            >>> sodium_graph = AtomicGraph.from_ase(sodium, cutoff=3.7)
            >>> sodium_graph
            AtomicGraph(atoms=1, edges=8, has_cell=True, cutoff=3.7)
    """

    Z: torch.Tensor
    """
    The atomic numbers of the atoms in the graph, of shape ``(N,)``.

    .. code-block:: python

        >>> water_graph.Z
        tensor([8, 1, 1])

    .. code-block:: python

        >>> sodium_graph.Z
        tensor([11])
    """

    R: torch.Tensor
    """
    The cartesian positions of the atoms in the graph, of shape ``(N, 3)``.

    .. code-block:: python

        >>> water_graph.R
        tensor([[ 0.0000,  0.0000,  0.1193],
                [ 0.0000,  0.7632, -0.4770],
                [ 0.0000, -0.7632, -0.4770]])

    .. code-block:: python

        >>> sodium_graph.R
        tensor([[0., 0., 0.]])
    """

    cell: torch.Tensor
    """
    The unit cell vectors, of shape ``(3, 3)`` for a single structure, or
    ``(S, 3, 3)`` for a batched graph. If the structure is non-periodic, this
    will be all zeros.

    .. code-block:: python

        >>> water_graph.cell
        tensor([[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]])

    .. code-block:: python

        >>> sodium_graph.cell
        tensor([[-2.1150,  2.1150,  2.1150],
                [ 2.1150, -2.1150,  2.1150],
                [ 2.1150,  2.1150, -2.1150]])
    """

    neighbour_list: torch.Tensor
    """
    A neighbour list, of shape ``(2, E)``, where
    ``i, j = graph.neighbour_list[:, k]`` is the ``k``'th directed edge in
    the graph, linking atom ``i`` to atom ``j``.

    .. code-block:: python

        >>> water_graph.neighbour_list
        tensor([[0, 0, 1, 2],
                [1, 2, 0, 0]])

    .. code-block:: python

        >>> sodium_graph.neighbour_list
        tensor([[0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0]])
    """

    neighbour_cell_offsets: torch.Tensor
    """
    The offsets of the neighbours of each atom in units of the unit cell
    vectors, of shape ``(E, 3)``, such that:

    .. code-block:: python

        # k-th edge
        i, j = graph.neighbour_list[:, k]
        kth_displacement_vector = (
            graph.R[j]
            + graph.neighbour_cell_offsets[k] @ graph.cell
            - graph.R[i]
        )

    In the case of an isolated, non-periodic structure, these will be all
    zeros.

    .. code-block:: python

        >>> water_graph.neighbour_cell_offsets
        tensor([[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]])

    .. code-block:: python

        >>> sodium_graph.neighbour_cell_offsets
        tensor([[ 0.,  0.,  1.],
                [ 0.,  1.,  0.],
                [ 1.,  1.,  1.],
                [ 1.,  0.,  0.],
                [-1.,  0.,  0.],
                [-1., -1., -1.],
                [ 0., -1.,  0.],
                [ 0.,  0., -1.]])
    """

    properties: Dict[PropertyKey, torch.Tensor]
    """
    A dictionary containing potential energy surface (PES) related properties
    of the graph.

    .. list-table::
        :header-rows: 1

        * - Key
          - Shape
          - Description
        * - :code:`"local_energies"`
          - :code:`(N,)`
          - contribution to the total energy from each atom
        * - :code:`"energy"`
          - ``()`` (``(S,)`` if batched)
          - total energy of the structure
        * - :code:`"forces"`
          - :code:`(N, 3)`
          - force on each atom
        * - :code:`"stress"`
          - ``(3, 3)`` (``(S, 3, 3)`` if batched)
          - stress tensor (see :doc:`../theory`)
        * - :code:`"virial"`
          - ``(3, 3)`` (``(S, 3, 3)`` if batched)
          - virial stress tensor (see :doc:`../theory`)
    """

    cutoff: float
    """
    The cutoff distance used to create the neighbour list for this graph.
    """

    other: Dict[str, torch.Tensor]
    """
    A dictionary containing any other additional information about the graph.
    Feel free to populate this as you wish.
    """

    batch: Union[torch.Tensor, None] = None
    """
    A tensor of shape ``(N,)`` indicating the index ``i`` of the structure
    within the batch that each atom belongs to. Not present for a single
    structure.
    """

    ptr: Union[torch.Tensor, None] = None
    """
    A tensor of shape ``(S + 1,)`` indicating the index ``i`` of the first
    atom in each structure within the batch. Not present for a single
    structure.
    """

    @classmethod
    def from_ase(
        cls,
        structure: ase.Atoms,
        cutoff: float = DEFAULT_CUTOFF,
        property_mapping: Union[Mapping[str, PropertyKey], None] = None,
        others_to_include: Union[Sequence[str], None] = None,
    ) -> "AtomicGraph":
        r"""
        Convert an :class:`ase.Atoms` object to an :class:`AtomicGraph`.

        Parameters
        ----------
        structure
            The :class:`ase.Atoms` object to convert. It is not modified.
        cutoff
            The cutoff distance for neighbour finding.
        property_mapping
            An optional mapping of the form
            ``{key_on_structure: key_for_graph}`` defining how relevant
            properties are labelled on the :class:`ase.Atoms` object. If not
            provided, this function will extract all of ``"energy"``,
            ``"forces"``, ``"stress"``, or ``"virial"`` from the ``.info``
            and ``.arrays`` dicts if they are present. ``"stress"`` and
            ``"virial"`` values with 6 entries (Voigt notation) are expanded
            to full ``(3, 3)`` tensors.
        others_to_include
            An optional list of other ``.info``/``.arrays`` keys to include
            in the graph's ``other`` dict. The corresponding values will be
            converted to :class:`torch.Tensor`\ s.

        Raises
        ------
        ValueError
            If any key in ``property_mapping`` cannot be found on the
            structure.

        Example
        -------
        .. code-block:: python

            >>> from ase.build import molecule
            >>> from graph_pes import AtomicGraph

            >>> # create a structure with some extra info
            >>> atoms = molecule("H2O")
            >>> atoms.info["DFT_energy"] = -10.0
            >>> atoms.info["unique_id"] = 1234

            >>> # default behaviour:
            >>> AtomicGraph.from_ase(atoms)
            AtomicGraph(atoms=3, edges=6, has_cell=False, cutoff=5.0)

            >>> # specify how to map properties, and other things to include
            >>> AtomicGraph.from_ase(
            ...     atoms,
            ...     property_mapping={
            ...         "DFT_energy": "energy",
            ...     },
            ...     others_to_include=["unique_id"],
            ... )
            AtomicGraph(
                atoms=3,
                edges=6,
                has_cell=False,
                cutoff=5.0,
                properties=['energy'],
                other=['unique_id']
            )
        """
        # account for strange behaviour in ase 3.23.0+ whereby
        # properties are sometimes removed from the atoms.info/arrays dicts
        # to ensure we don't change the atoms object that the user passes in
        # we first make a copy, ensuring that the calculator is also copied
        # over
        calc = structure.calc
        structure = structure.copy()
        structure.calc = calc
        # then we remove the calculator and copy over the properties to the
        # relevant info/arrays dicts (with help from load_atoms.utils)
        remove_calculator(structure)

        _float = torch.get_default_dtype()

        # structure
        Z = torch.tensor(structure.numbers, dtype=torch.long)
        R = torch.tensor(structure.positions, dtype=_float)
        cell = torch.tensor(structure.cell.array, dtype=_float)

        # neighbour list
        i, j, offsets = vesin.ase_neighbor_list(
            "ijS", structure, float(cutoff)
        )
        i = i.astype(np.int64)
        j = j.astype(np.int64)
        neighbour_list = torch.tensor(np.vstack([i, j]), dtype=torch.long)
        neighbour_cell_offsets = torch.tensor(offsets, dtype=_float)

        # properties
        properties: Dict[PropertyKey, torch.Tensor] = {}
        other: Dict[str, torch.Tensor] = {}

        if property_mapping is None:
            all_keys = set(structure.info) | set(structure.arrays)
            property_mapping = {
                k: cast(PropertyKey, k)
                for k in ["energy", "forces", "stress", "virial"]
                if k in all_keys
            }
        if others_to_include is None:
            others_to_include = []

        def to_tensor(value):
            # floating point values follow torch's default dtype
            t = torch.tensor(value)
            if t.is_floating_point():
                t = t.to(_float)
            return t

        for key, value in list(structure.info.items()) + list(
            structure.arrays.items()
        ):
            if key in property_mapping:
                graph_key = property_mapping[key]
                # ensure stress/virial are always 3x3, not voigt notation.
                # np.asarray also handles values stored as plain lists, and
                # flattening to (6,) gives the documented (3, 3) result even
                # for e.g. (1, 6) inputs
                if graph_key in ["stress", "virial"]:
                    array = np.asarray(value)
                    if array.size == 6:
                        value = voigt_6_to_full_3x3_stress(array.reshape(6))
                properties[graph_key] = to_tensor(value)
            elif key in others_to_include:
                other[key] = to_tensor(value)

        missing = {
            structure_key
            for structure_key, graph_key in property_mapping.items()
            if graph_key not in properties
        }
        if missing:
            raise ValueError(f"Unable to find properties: {missing}")

        return cls.create_with_defaults(
            Z=Z,
            R=R,
            cell=cell,
            neighbour_list=neighbour_list,
            neighbour_cell_offsets=neighbour_cell_offsets,
            properties=properties,
            other=other,
            cutoff=cutoff,
        )

    @classmethod
    def create_with_defaults(
        cls,
        Z: torch.Tensor,
        R: torch.Tensor,
        cell: Union[torch.Tensor, None] = None,
        neighbour_list: Union[torch.Tensor, None] = None,
        neighbour_cell_offsets: Union[torch.Tensor, None] = None,
        properties: Union[Dict[PropertyKey, torch.Tensor], None] = None,
        other: Union[Dict[str, torch.Tensor], None] = None,
        cutoff: float = 0.0,
    ) -> "AtomicGraph":
        """
        Create an :class:`AtomicGraph`, populating missing values with
        defaults. All default tensors are created on ``R``'s device.

        Parameters
        ----------
        Z
            The atomic numbers.
        R
            The cartesian positions.
        cell
            The unit cell. Defaults to a ``(3, 3)`` tensor of zeros with
            ``R``'s (floating point) dtype.
        neighbour_list
            The neighbour list. Defaults to an empty ``(2, 0)`` long tensor.
        neighbour_cell_offsets
            The neighbour cell offsets. Defaults to an empty ``(0, 3)``
            tensor with ``R``'s (floating point) dtype.
        properties
            The properties. Defaults to ``{}``.
        other
            The other information. Defaults to ``{}``.
        cutoff
            The cutoff used to build the neighbour list. Defaults to ``0.0``.
        """
        # default float tensors follow R's dtype (falling back to torch's
        # default dtype for a non-float R), so that e.g. a float64 structure
        # doesn't end up with a float32 cell and offsets
        _float = (
            R.dtype if R.is_floating_point() else torch.get_default_dtype()
        )
        if cell is None:
            cell = torch.zeros(3, 3, dtype=_float, device=R.device)
        if neighbour_list is None:
            neighbour_list = torch.zeros(
                2, 0, dtype=torch.long, device=R.device
            )
        if neighbour_cell_offsets is None:
            neighbour_cell_offsets = torch.zeros(
                0, 3, dtype=_float, device=R.device
            )
        if properties is None:
            properties = {}
        if other is None:
            other = {}
        return cls(
            Z=Z,
            R=R,
            cell=cell,
            neighbour_list=neighbour_list,
            neighbour_cell_offsets=neighbour_cell_offsets,
            properties=properties,
            other=other,
            cutoff=cutoff,
        )

    def to(self, device: Union[torch.device, str]) -> "AtomicGraph":
        """
        Move this graph to the specified device.

        Every tensor (including the ``properties`` and ``other`` values, and
        ``batch``/``ptr`` when present) is moved; ``cutoff`` is unchanged.
        """
        properties: Dict[PropertyKey, torch.Tensor] = {
            k: v.to(device) for k, v in self.properties.items()
        }
        return AtomicGraph(
            Z=self.Z.to(device),
            R=self.R.to(device),
            cell=self.cell.to(device),
            neighbour_list=self.neighbour_list.to(device),
            neighbour_cell_offsets=self.neighbour_cell_offsets.to(device),
            properties=properties,
            other={k: v.to(device) for k, v in self.other.items()},
            cutoff=self.cutoff,
            batch=self.batch.to(device) if self.batch is not None else None,
            ptr=self.ptr.to(device) if self.ptr is not None else None,
        )

    def __repr__(self):
        info = {}

        if self.batch is not None:
            name = "AtomicGraphBatch"
            info["structures"] = self.batch.max().item() + 1
        else:
            name = "AtomicGraph"

        info["atoms"] = number_of_atoms(self)
        info["edges"] = number_of_edges(self)
        info["has_cell"] = has_cell(self)
        info["cutoff"] = to_significant_figures(self.cutoff, 3)

        if self.properties:
            info["properties"] = available_properties(self)
        if self.other:
            info["other"] = list(self.other.keys())

        return uniform_repr(name, **info, indent_width=4)
def replace(
    graph: AtomicGraph,
    Z: Optional[torch.Tensor] = None,
    R: Optional[torch.Tensor] = None,
    cell: Optional[torch.Tensor] = None,
    neighbour_list: Optional[torch.Tensor] = None,
    neighbour_cell_offsets: Optional[torch.Tensor] = None,
    properties: Optional[dict[PropertyKey, torch.Tensor]] = None,
    other: Optional[dict[str, torch.Tensor]] = None,
    cutoff: Optional[float] = None,
) -> AtomicGraph:
    """
    Return a copy of ``graph`` with the given fields swapped out.

    Any argument left as ``None`` keeps the corresponding value from
    ``graph``; ``batch`` and ``ptr`` are always carried over unchanged.
    Unlike the built-in ``._replace`` namedtuple method, this function is
    ``TorchScript`` compatible.
    """
    # resolve each field: the override if given, else the existing value
    new_Z = Z if Z is not None else graph.Z
    new_R = R if R is not None else graph.R
    new_cell = cell if cell is not None else graph.cell
    new_neighbour_list = (
        neighbour_list if neighbour_list is not None else graph.neighbour_list
    )
    new_offsets = (
        neighbour_cell_offsets
        if neighbour_cell_offsets is not None
        else graph.neighbour_cell_offsets
    )
    new_properties = properties if properties is not None else graph.properties
    new_other = other if other is not None else graph.other
    new_cutoff = cutoff if cutoff is not None else graph.cutoff

    return AtomicGraph(
        Z=new_Z,
        R=new_R,
        cell=new_cell,
        neighbour_list=new_neighbour_list,
        neighbour_cell_offsets=new_offsets,
        properties=new_properties,
        other=new_other,
        cutoff=new_cutoff,
        batch=graph.batch,
        ptr=graph.ptr,
    )
############################### BATCHING ###############################
class CustomPropertyBatcher(Protocol):
    """
    Callable protocol for custom batchers of values in ``AtomicGraph.other``.

    Implementations are registered with ``register_custom_batcher``;
    ``to_batch`` calls them as ``batcher(batched_graph, values)`` and stores
    the returned tensor under the registered key of the batch's ``other``
    dict.
    """

    def __call__(
        self, batch: AtomicGraph, values: list[torch.Tensor]
    ) -> torch.Tensor:
        """
        Batch the given values.

        Parameters
        ----------
        batch
            The batch of graphs.
        values
            The list of values to batch (one per graph in the batch).
        """
        ...
# Registry mapping keys of ``AtomicGraph.other`` to custom batchers; filled
# by ``register_custom_batcher`` and consulted by ``to_batch``.
_custom_batchers: dict[str, CustomPropertyBatcher] = {}

# NB this is essential, otherwise all data loader workers will
# have an empty _custom_batchers dict, and hence fail to perform
# any custom batching.
# The "fork" start method does not exist on every platform (e.g. Windows),
# where requesting it raises a ValueError at import time: only set it
# where it is supported.
if "fork" in torch.multiprocessing.get_all_start_methods():
    torch.multiprocessing.set_start_method("fork", force=True)
def register_custom_batcher(key: str):
    """
    Decorator factory that registers a custom batcher for the property
    stored under ``key`` in a graph's ``other`` field.

    The decorated function is stored in a module-level registry (and
    returned unchanged), so that :func:`to_batch` uses it for ``key``. It
    should conform to the following protocol:

    .. autoclass:: graph_pes.atomic_graph.CustomPropertyBatcher()
        :members: __call__

    Parameters
    ----------
    key
        The key of the property to register a custom batcher for.

    Examples
    --------
    >>> from graph_pes.atomic_graph import register_custom_batcher
    >>> @register_custom_batcher("foo")
    ... def foo_batcher(batch, values):
    ...     return torch.max(torch.vstack(values), dim=0).values
    >>> ... # create graphs
    >>> graphs[0].other["foo"], graphs[1].other["foo"]
    (tensor([1]), tensor([2]))
    >>> ... # batch the graphs
    >>> batch = to_batch(graphs)
    >>> batch.other["foo"]
    tensor([2])
    """

    def _register(batcher_fn: CustomPropertyBatcher):
        # later registrations for the same key overwrite earlier ones
        _custom_batchers[key] = batcher_fn
        return batcher_fn

    return _register
def to_batch(
    graphs: Sequence[AtomicGraph],
) -> AtomicGraph:
    """
    Collate a sequence of atomic graphs into a single batch object.

    The ``Z``, ``R`` and ``neighbour_cell_offsets`` properties are
    concatenated along the first axis, and the ``neighbour_list`` is
    concatenated along its edge axis (with the indices of each graph shifted
    by the number of atoms in the preceding graphs), while the ``cell``
    property is stacked along a new batch dimension.

    Values in the ``"other"`` dictionary are batched with a custom batcher if
    one has been registered for their key (see
    :func:`register_custom_batcher`). Otherwise, they are concatenated along
    the first axis if they appear to be a per-atoms property (i.e. their
    first dimension matches the number of atoms in the structure), and are
    stacked along a new batch dimension if not.

    Parameters
    ----------
    graphs
        The graphs to collate.

    Returns
    -------
    AtomicGraph
        The batched graph, with ``batch`` and ``ptr`` populated.

    Raises
    ------
    ValueError
        If any of the ``graphs`` is already a batch.

    Examples
    --------
    A basic example:

    >>> from ase.build import molecule
    >>> from graph_pes import AtomicGraph, to_batch
    >>> graphs = [
    ...     AtomicGraph.from_ase(molecule("H2O")),
    ...     AtomicGraph.from_ase(molecule("CH4")),
    ... ]
    >>> batch = to_batch(graphs)
    >>> batch.batch  # H2O has 3 atoms, CH4 has 5
    tensor([0, 0, 0, 1, 1, 1, 1, 1])
    >>> batch.ptr  # offset of first atom of each graph
    tensor([0, 3, 8])
    >>> batch.Z.shape
    torch.Size([8])
    >>> batch.R.shape
    torch.Size([8, 3])
    >>> batch.cell.shape
    torch.Size([2, 3, 3])
    """
    if any(is_batch(g) for g in graphs):
        raise ValueError("Cannot recursively batch graphs")

    # easy properties: just cat these together
    Z = torch.cat([g.Z for g in graphs])
    R = torch.cat([g.R for g in graphs])
    neighbour_offsets = torch.cat([g.neighbour_cell_offsets for g in graphs])

    # stack cells along a new batch dimension
    if not all_equal([has_cell(g) for g in graphs]):
        warnings.warn(
            "Attempting to batch a collection of graphs where only some "
            "have a defined unit cell. This may lead to unexpected results.",
            stacklevel=2,
        )
    cells = torch.stack([g.cell for g in graphs])

    # standard way to calculate the batch and ptr properties
    batch = torch.cat(
        [torch.full_like(g.Z, fill_value=i) for i, g in enumerate(graphs)]
    )
    # create ptr on the same device as the atoms, so that the batched graph
    # is not split across devices
    ptr = torch.tensor(
        [0] + [g.Z.shape[0] for g in graphs], device=Z.device
    ).cumsum(dim=0)

    # use the ptr to increment the neighbour index appropriately
    neighbour_list = torch.cat(
        [g.neighbour_list + ptr[i] for i, g in enumerate(graphs)], dim=1
    )

    # handle cutoff
    cutoffs = [g.cutoff for g in graphs]
    if not all_equal(cutoffs):
        warnings.warn(
            "Attempting to batch graphs with different cutoffs: "
            f"{cutoffs}. Setting graph.cutoff to the maximum.",
            stacklevel=2,
        )
    cutoff = max(cutoffs)

    properties: Dict[PropertyKey, torch.Tensor] = {}

    # - per structure labels are concatenated along a new batch axis (0)
    for key in ["energy", "stress", "virial"]:
        key = cast(PropertyKey, key)
        if all(key in g.properties for g in graphs):
            properties[key] = torch.stack([g.properties[key] for g in graphs])

    # - per atom labels are concatenated along the first axis
    for key in ["forces", "local_energies"]:
        key = cast(PropertyKey, key)
        if all(key in g.properties for g in graphs):
            properties[key] = torch.cat([g.properties[key] for g in graphs])

    batched_graph = AtomicGraph(
        Z=Z,
        R=R,
        cell=cells,
        neighbour_list=neighbour_list,
        neighbour_cell_offsets=neighbour_offsets,
        properties=properties,
        other={},
        cutoff=cutoff,
        batch=batch,
        ptr=ptr,
    )

    # - finally, add in the other stuff: this is a bit tricky
    #   since we need to try and infer whether these are per-atom
    #   or per-structure. Keys are taken from the first graph.
    for key in graphs[0].other:
        values = [g.other[key] for g in graphs]

        if key in _custom_batchers:
            batcher = _custom_batchers[key]
            batched_graph.other[key] = batcher(batched_graph, values)
        elif all(is_local_property(g.other[key], g) for g in graphs):
            batched_graph.other[key] = torch.cat(values)
        else:
            batched_graph.other[key] = torch.stack(values)

    return batched_graph
def is_batch(graph: AtomicGraph) -> bool:
    """
    Report whether ``graph`` is a batch of several atomic graphs.

    Batches are exactly those graphs whose ``batch`` field is populated.

    Parameters
    ----------
    graph
        The graph to inspect.
    """
    batch_index = graph.batch
    return batch_index is not None
############################### PROPERTIES ###############################


def get_cell_volume(graph: AtomicGraph) -> float:
    """
    Get the volume of the unit cell, i.e. the absolute value of the
    determinant of ``graph.cell``.

    The result is converted with ``.item()``, which requires the determinant
    to be a single value (i.e. a single, non-batched structure).
    """
    volume = torch.det(graph.cell).abs()
    return volume.item()
def number_of_atoms(graph: AtomicGraph) -> int:
    """
    Return ``N``, the number of atoms (nodes) in the ``graph``, read from the
    leading dimension of ``graph.Z``.
    """
    n_atoms = graph.Z.shape[0]
    return n_atoms
def number_of_edges(graph: AtomicGraph) -> int:
    """
    Return ``E``, the number of directed edges in the ``graph``, read from
    the second dimension of ``graph.neighbour_list``.
    """
    n_edges = graph.neighbour_list.shape[1]
    return n_edges
def has_cell(graph: AtomicGraph) -> bool:
    """
    Does ``graph`` represent a structure with a defined unit cell?

    The cell counts as undefined when it is (within :func:`torch.allclose`
    tolerances) entirely zero.
    """
    empty_cell = torch.zeros_like(graph.cell)
    cell_is_empty = torch.allclose(graph.cell, empty_cell)
    return not cell_is_empty
def neighbour_vectors(graph: AtomicGraph) -> torch.Tensor:
    """
    Get the vector between each pair of atoms specified in the ``graph``'s
    ``"neighbour_list"`` property, respecting periodic boundary conditions
    where present.

    Returns
    -------
    torch.Tensor
        A tensor of shape ``(E, 3)`` whose ``k``'th row is
        ``R[j] + neighbour_cell_offsets[k] @ cell - R[i]`` for the ``k``'th
        edge ``(i, j)``. With no edges, an empty ``(0, 3)`` tensor with
        ``R``'s dtype and device is returned.
    """
    # to simplify the logic below, we'll expand
    # a single graph into a batch of one
    batch: torch.Tensor = torch.zeros_like(graph.Z)
    cell: torch.Tensor = graph.cell.unsqueeze(0)

    # torchscript annoying-ness:
    graph_batch = graph.batch
    if graph_batch is not None:
        cell = graph.cell
        batch = graph_batch

    # avoid tuple de-structuring to keep torchscript happy
    i, j = graph.neighbour_list[0], graph.neighbour_list[1]  # (E,)
    if i.shape[0] == 0:
        # match R's dtype (not torch's default) for consistency with the
        # non-empty case, where R enters the result directly
        return torch.zeros(0, 3, dtype=graph.R.dtype, device=graph.R.device)

    cell_per_edge = cell[batch[i]]  # (E, 3, 3)
    distance_offsets = torch.einsum(
        "kl,klm->km",
        graph.neighbour_cell_offsets.to(cell_per_edge.dtype),
        cell_per_edge,
    )  # (E, 3)
    neighbour_positions = graph.R[j] + distance_offsets  # (E, 3)
    return neighbour_positions - graph.R[i]  # (E, 3)
def neighbour_distances(graph: AtomicGraph) -> torch.Tensor:
    """
    Get the Euclidean length of every edge vector in the ``graph``, i.e.
    the distance between each pair of atoms in its ``neighbour_list``,
    respecting periodic boundary conditions where present.
    """
    edge_vectors = neighbour_vectors(graph)
    return torch.linalg.norm(edge_vectors, dim=-1)
def number_of_structures(graph: AtomicGraph) -> int:
    """
    Get the number of structures, ``S``, represented by the ``graph``.

    A graph without a ``ptr`` counts as a single structure; otherwise ``S``
    is one less than the length of ``ptr``.
    """
    # bind to a local so TorchScript can refine the Optional type
    ptr = graph.ptr
    if ptr is not None:
        return ptr.shape[0] - 1
    return 1
def structure_sizes(batch: AtomicGraph) -> torch.Tensor:
    """
    Get the number of atoms in each structure in the ``batch``, of shape
    ``(S,)`` where ``S`` is the number of structures. For a single
    (non-batched) graph, a 0-d tensor holding its number of atoms is
    returned instead.

    The result is always an integer tensor on the same device as the atoms.

    Parameters
    ----------
    batch
        The batch to get the structure sizes for.

    Examples
    --------
    >>> len(graphs)
    3
    >>> [number_of_atoms(g) for g in graphs]
    [3, 4, 5]
    >>> structure_sizes(to_batch(graphs))
    tensor([3, 4, 5])
    """
    # torchscript annoying-ness:
    graph_ptr = batch.ptr
    if graph_ptr is None:
        # match the dtype/device of the batched branch (ptr differences),
        # rather than producing a float tensor on the CPU
        return torch.tensor(
            number_of_atoms(batch), dtype=torch.long, device=batch.Z.device
        )
    return graph_ptr[1:] - graph_ptr[:-1]
def number_of_neighbours(
    graph: AtomicGraph,
    include_central_atom: bool = True,
) -> torch.Tensor:
    """
    Count the neighbours of every atom in the ``graph``.

    Returns an integer tensor ``T`` of shape ``(N,)`` in which ``T[i]`` is
    the number of edges whose central atom is ``i``, plus one for the atom
    itself when ``include_central_atom`` is ``True``.

    Parameters
    ----------
    graph
        The graph to get the number of neighbours for.
    include_central_atom
        Whether to include the central atom in the count.
    """
    # one "vote" per edge, summed onto each edge's central atom
    edge_ones = torch.ones_like(graph.neighbour_list[0])
    counts = sum_over_neighbours(edge_ones, graph)
    if include_central_atom:
        counts = counts + 1
    return counts
def available_properties(graph: AtomicGraph) -> List[PropertyKey]:
    """Return the keys of the ``graph``'s ``properties`` dict, as a list."""
    labels: List[PropertyKey] = []
    for label in graph.properties:
        labels.append(cast(PropertyKey, label))
    return labels
############################### ACTIONS ###############################
def is_local_property(x: torch.Tensor, graph: AtomicGraph) -> bool:
    """
    Decide whether ``x`` looks like a per-atom property of ``graph``.

    A property is treated as local when it has at least one dimension and
    its leading dimension equals the number of atoms in the ``graph``.

    Parameters
    ----------
    x
        The property to check.
    graph
        The graph to check the property for.
    """
    if x.dim() == 0:
        # scalars can never be per-atom
        return False
    return x.shape[0] == number_of_atoms(graph)
def trim_edges(graph: AtomicGraph, cutoff: float) -> AtomicGraph:
    """
    Return a new graph with edges trimmed to be no longer than the
    ``cutoff``. Leaves the original graph unchanged.

    If ``cutoff`` is not smaller than the graph's existing cutoff, no edge
    can be removed and ``graph`` itself is returned (with a warning if
    ``cutoff`` exceeds the existing cutoff by more than ``1e-5``).

    Parameters
    ----------
    graph
        The graph to trim the edges of.
    cutoff
        The maximum distance between atoms to keep the edge.
    """
    existing_cutoff = graph.cutoff
    if existing_cutoff + 1e-5 < cutoff:
        warnings.warn(
            f"Graph already has a cutoff of {existing_cutoff} which is "
            f"less than the requested cutoff of {cutoff}.",
            stacklevel=2,
        )
        return graph
    elif existing_cutoff <= cutoff:
        # within the tolerance window (or exactly equal): every edge is
        # already within ``cutoff``, so skip the distance computation and
        # keep the graph's accurate, existing cutoff
        return graph

    distances = neighbour_distances(graph)
    mask = distances <= cutoff
    neighbour_list = graph.neighbour_list[:, mask]
    neighbour_cell_offsets = graph.neighbour_cell_offsets[mask, :]

    # can't use _replace here due to TorchScript
    return AtomicGraph(
        Z=graph.Z,
        R=graph.R,
        cell=graph.cell,
        neighbour_list=neighbour_list,
        neighbour_cell_offsets=neighbour_cell_offsets,
        properties=graph.properties,
        other=graph.other,
        cutoff=cutoff,
        batch=graph.batch,
        ptr=graph.ptr,
    )
def sum_over_central_atom_index(
    p: torch.Tensor,
    central_atom_index: torch.Tensor,
    graph: AtomicGraph,
) -> torch.Tensor:
    r"""
    Efficient, shape-preserving sum of a property, :math:`p`, defined over a
    ``central_atom_index``, to get a per-atom property, :math:`P`, such that:

    .. code-block:: python

        # i in central_atom_index
        P[i] == torch.sum(p[central_atom_index == i], dim=0)
        # i not in central_atom_index
        P[i] == torch.zeros_like(p[0])

    .. seealso::

        :func:`sum_over_neighbours` for the explicit case where ``p`` is a
        per-edge property.

    Parameters
    ----------
    p
        The property to sum, of shape ``(Y, ...)``.
    central_atom_index
        The central atoms relevant to each element of ``p``, of shape
        ``(Y,)``.
    graph
        The graph to sum the property for.

    Returns
    -------
    P: torch.Tensor
        The summed property, of shape ``(N, ...)``.
    """
    N = number_of_atoms(graph)

    # optimised implementations for common cases
    if p.dim() == 1:
        zeros = torch.zeros(N, dtype=p.dtype, device=p.device)
        return zeros.scatter_add(0, central_atom_index, p)
    elif p.dim() == 2:
        C = p.shape[1]
        zeros = torch.zeros(N, C, dtype=p.dtype, device=p.device)
        return zeros.scatter_add(
            0,
            central_atom_index.unsqueeze(1).expand(-1, C),
            p,
        )

    shape = (N,) + p.shape[1:]
    zeros = torch.zeros(shape, dtype=p.dtype, device=p.device)

    if p.shape[0] == 0:
        # return all zeros if there are no atoms
        return zeros

    # create `index`, where index.shape = p.shape
    # and (index[e] == central_atoms[e]).all().
    # Broadcast the integer index directly rather than multiplying it by a
    # tensor in p's dtype: that round-trip loses precision for large atom
    # indices (e.g. > 2048 in float16, > 2**24 in float32).
    view_shape = [-1] + [1] * (p.dim() - 1)
    index = central_atom_index.long().reshape(view_shape).expand_as(p)
    return zeros.scatter_add(0, index, p)
def sum_over_neighbours(p: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
    r"""
    Shape-preserving sum over neighbours of a per-edge property,
    :math:`p_{ij}`, to get a per-atom property, :math:`P_i`:

    .. math::

        P_i = \sum_{j \in \mathcal{N}_i} p_{ij}

    where:

    * :math:`\mathcal{N}_i` is the set of neighbours of atom :math:`i`.
    * :math:`p_{ij}` is the property of the edge between atoms :math:`i`
      and :math:`j`.
    * :math:`p` is of shape :code:`(E, ...)` and :math:`P` is of shape
      :code:`(N, ...)` where :math:`E` is the number of edges and :math:`N`
      is the number of atoms. :code:`...` denotes any number of additional
      dimensions, including none.
    * :math:`P_i = 0` if :math:`|\mathcal{N}_i| = 0`.

    Parameters
    ----------
    p
        The per-edge property to sum.
    graph
        The graph to sum the property for.
    """
    # the first row of the neighbour list holds each edge's central atom
    central_atoms = graph.neighbour_list[0]
    return sum_over_central_atom_index(p, central_atoms, graph)
def sum_per_structure(x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
    r"""
    Shape-preserving sum of a per-atom property, :math:`p`, to get a
    per-structure property, :math:`P`:

    If a single structure, containing ``N`` atoms, is used, then
    :math:`P = \sum_i p_i`, where:

    * :math:`p` is of shape ``(N, ...)``
    * :math:`P` is of shape ``(...)``
    * ``...`` denotes any number of additional dimensions, including none.

    If a batch of ``S`` structures, containing a total of ``N`` atoms, is
    used, then :math:`P_k = \sum_{i \in K_k} p_i`, where:

    * :math:`K_k` is the collection of all atoms in structure :math:`k`
    * :math:`p` is of shape ``(N, ...)``
    * :math:`P` is of shape ``(S, ...)``
    * ``...`` denotes any number of additional dimensions, including none.

    Parameters
    ----------
    x
        The per-atom property to sum.
    graph
        The graph to sum the property for.

    Examples
    --------
    Single graph case:

    >>> import torch
    >>> from ase.build import molecule
    >>> from graph_pes.atomic_graph import sum_per_structure, AtomicGraph
    >>> water = molecule("H2O")
    >>> graph = AtomicGraph.from_ase(water, cutoff=1.5)
    >>> # summing over a vector gives a scalar
    >>> sum_per_structure(torch.ones(3), graph)
    tensor(3.)
    >>> # summing over higher order tensors gives a tensor
    >>> sum_per_structure(torch.ones(3, 2, 3), graph).shape
    torch.Size([2, 3])

    Batch case:

    >>> import torch
    >>> from ase.build import molecule
    >>> from graph_pes.atomic_graph import sum_per_structure, AtomicGraph, to_batch
    >>> water = molecule("H2O")
    >>> graph = AtomicGraph.from_ase(water, cutoff=1.5)
    >>> batch = to_batch([graph, graph])
    >>> batch.ptr
    tensor([0, 3, 6])
    >>> # summing over a vector gives a tensor
    >>> sum_per_structure(torch.ones(6), batch)
    tensor([3., 3.])
    >>> # summing over higher order tensors gives a tensor
    >>> sum_per_structure(torch.ones(6, 3, 4), batch).shape
    torch.Size([2, 3, 4])
    """  # noqa: E501
    # torchscript annoying-ness:
    graph_batch = graph.batch
    if graph_batch is not None:
        shape = (number_of_structures(graph),) + x.shape[1:]
        zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
        # index_add accepts a 1-D index for any rank of x; scatter_add
        # would require the index to have the same number of dims as x,
        # and so fail for e.g. (N, 3) inputs
        return zeros.index_add(0, graph_batch, x)
    else:
        return x.sum(dim=0)
def index_over_neighbours(x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
    """
    Index a per-atom property, :math:`x`, over the neighbours of each atom
    in the ``graph``: row ``k`` of the result is ``x[j]``, where ``j`` is the
    neighbour (second row of ``neighbour_list``) of the ``k``'th edge.
    """
    neighbour_index = graph.neighbour_list[1]
    return x[neighbour_index]
def divide_per_atom(x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
    r"""
    Divide a per-structure property, :math:`X`, by the number of atoms in
    each structure to get a per-atom (i.e. size-normalised) quantity,
    :math:`x`:

    .. math::

        x_k = \frac{X_k}{N_k}

    where:

    * :math:`X` is of shape ``(S, ...)`` for a batch of :math:`S` structures
    * :math:`N_k` is the number of atoms in structure :math:`k`, as given by
      :func:`structure_sizes`
    * :math:`x` holds one value per *structure* (indexed by :math:`k`, not
      by atom), i.e. it has the same leading dimension as :math:`X`

    Parameters
    ----------
    x
        The per-structure property to divide.
    graph
        The graph (or batch) providing the number of atoms per structure.
    """
    sizes = structure_sizes(graph)
    return left_aligned_div(x, sizes)