Source code for graph_pes.models.eddp

from __future__ import annotations

import torch
from ase.data import atomic_numbers

from ..atomic_graph import (
    AtomicGraph,
    PropertyKey,
    neighbour_distances,
    sum_over_central_atom_index,
)
from ..graph_pes_model import GraphPESModel
from ..utils.misc import uniform_repr
from ..utils.nn import MLP, AtomicOneHot, UniformModuleList
from ..utils.threebody import triplet_edges


class RadialExpansion(torch.nn.Module):
    def __init__(
        self,
        cutoff: float = 5.0,
        features: int = 8,
        max_power: float = 8,
        learnable_powers: bool = False,
    ):
        super().__init__()
        self.cutoff = cutoff
        beta = (max_power / 2) ** (1.0 / (features - 1))
        powers = torch.tensor(
            [2 * (beta**i) for i in range(features)], dtype=torch.float
        )
        if learnable_powers:
            self.exponents = torch.nn.Parameter(powers)
        else:
            self.register_buffer("exponents", powers)

    def forward(
        self,
        r: torch.Tensor,  # of shape (E,)
    ) -> torch.Tensor:  # of shape (E, F)
        # apply the linear function
        f = torch.clamp(2 * (1 - r / self.cutoff), min=0)  # (E,)
        # repeat f for each exponent
        f = f.view(-1, 1).repeat(1, self.exponents.shape[0])
        # raise to the power of the exponents
        f = f ** torch.clamp(self.exponents, min=2)

        return f

    @torch.no_grad()
    def __repr__(self):
        return uniform_repr(self.__class__.__name__, cutoff=self.cutoff)


class TwoBodyDescriptor(torch.nn.Module):
    """
    Two-body term of the EDDP potential.
    """

    def __init__(
        self,
        Z1: int,
        Z2: int,
        cutoff: float = 5.0,
        features: int = 8,
        max_power: float = 8,
        learnable_powers: bool = False,
    ):
        super().__init__()
        self.Z1 = Z1
        self.Z2 = Z2

        self.expansion = RadialExpansion(
            cutoff, features, max_power, learnable_powers
        )

    def forward(self, r: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
        # select the relevant edges
        i = graph.neighbour_list[0]
        j = graph.neighbour_list[1]
        mask = (graph.Z[i] == self.Z1) & (graph.Z[j] == self.Z2)
        r = r[mask]

        # expand each edge
        f_edge = self.expansion(r)

        # sum over central atom
        f_atom = sum_over_central_atom_index(f_edge, i[mask], graph)
        return f_atom


class ThreeBodyDescriptor(torch.nn.Module):
    """
    Three-body term of the EDDP potential.
    """

    def __init__(
        self,
        Z1: int,
        Z2: int,
        Z3: int,
        cutoff: float = 5.0,
        features: int = 8,
        max_power: float = 8,
        learnable_powers: bool = False,
    ):
        super().__init__()
        self.Z1 = Z1
        self.Z2 = Z2
        self.Z3 = Z3
        self.cutoff = cutoff

        self.central_atom_expansion = RadialExpansion(
            cutoff, features, max_power, learnable_powers
        )
        self.neighbour_atom_expansion = RadialExpansion(
            cutoff, features, max_power, learnable_powers
        )

    def forward(
        self, i, j, k, r_ij, r_ik, r_jk, graph: AtomicGraph
    ) -> torch.Tensor:
        mask = (
            (graph.Z[i] == self.Z1)
            & (graph.Z[j] == self.Z2)
            & (graph.Z[k] == self.Z3)
        )
        r_ij = r_ij[mask]
        r_ik = r_ik[mask]
        r_jk = r_jk[mask]

        e_ij = self.central_atom_expansion(r_ij)
        e_ik = self.central_atom_expansion(r_ik)
        e_jk = self.neighbour_atom_expansion(r_jk)

        E = r_ij.shape[0]
        F = e_ij.shape[1]
        prod = (e_ij * e_ik)[:, None, :] * e_jk[:, :, None]
        prod = prod.view(E, F * F)

        # sum over central atoms
        return sum_over_central_atom_index(prod, i[mask], graph)


[docs] class EDDP(GraphPESModel): r""" The Ephemeral Data Derived Potential (EDDP) architecture, from `Ephemeral data derived potentials for random structure search. Phys. Rev. B 106, 014102 <https://doi.org/10.1103/PhysRevB.106.014102>`_. This model makes local energy predictions by featurising the local atomic environment and passing the result through an MLP: .. math:: \varepsilon_i = \text{MLP}(\mathbf{F}_i) where :math:`\mathbf{F}_i` is a a concatenation of one-, two- and three-body features: .. math:: \mathbf{F}_i = \mathbf{F}_i^{(1)} \oplus \mathbf{F}_i^{(2)} \oplus \mathbf{F}_i^{(3)} as per Eq. 14. The m'th component of the two-body feature is given by: .. math:: \mathbf{F}_i^{(2)}[m] = \sum_{j \in \mathcal{N}(i)} f(r_{ij})^{p_m} where :math:`r_{ij}` is the distance between atoms :math:`i` and :math:`j`, :math:`\mathcal{N}(i)` is the set of neighbours of atom :math:`i`, :math:`p_m \geq 2` is some power, and :math:`f(r_{ij})` is defined as: .. math:: f(r_{ij}) = \begin{cases} 2 (1 - r_{ij} / r_{\text{cutoff}}) & \text{if } r_{ij} \leq r_{\text{cutoff}} \\ 0 & \text{otherwise} \end{cases} as per Eq. 7. Finally, the :math:`(m, o)`'th component of the three-body feature is given by: .. math:: \mathbf{F}_i^{(3)}[m, o] = \sum_{j \in \mathcal{N}(i)} \sum_{k > j \in \mathcal{N}(i)} f(r_{ij})^{p_m} f(r_{ik})^{p_m} f(r_{jk})^{p_o} If you use this model, please cite the original paper: .. code-block:: bibtex @article{Pickard-22-07, title = { Ephemeral Data Derived Potentials for Random Structure Search }, author = {Pickard, Chris J.}, year = {2022}, journal = {Physical Review B}, volume = {106}, number = {1}, pages = {014102}, doi = {10.1103/PhysRevB.106.014102}, } Parameters ---------- elements: The elements the model will be applied to (e.g. ``["C", "H"]``). cutoff: The maximum distance between pairs of atoms to consider their interaction. features: The number of features used to describe the two body interactions. max_power: The maximum power of the interaction expansion for two body terms. mlp_width: The width of the MLP (i.e. number of neurons in each hidden layer). mlp_layers: The number of hidden layers in the MLP. activation: The activation function to use in the MLP. three_body_cutoff: The radial cutoff for three body interactions. If ``None``, the same as ``cutoff``. three_body_features: The number of features used to describe the three body interactions. If ``None``, the same as for the two body terms. three_body_max_power: The maximum power of the interaction expansion for three body terms. If ``None``, the same as for the two body terms. learnable_powers: If ``True``, the powers of the interaction expansion are learnable. If ``False``, the powers are fixed to the values given by ``max_power``. """ # noqa: E501 def __init__( self, elements: list[str], cutoff: float = 5.0, features: int = 8, max_power: float = 8, mlp_width: int = 16, mlp_layers: int = 1, activation: str = "CELU", three_body_cutoff: float | None = None, three_body_features: int | None = None, three_body_max_power: float | None = None, learnable_powers: bool = False, ): if three_body_cutoff is None: three_body_cutoff = cutoff if three_body_features is None: three_body_features = features if three_body_max_power is None: three_body_max_power = max_power super().__init__( cutoff=max(cutoff, three_body_cutoff), implemented_properties=["local_energies"], three_body_cutoff=three_body_cutoff, ) self.elements = elements Zs = [atomic_numbers[Z] for Z in elements] # one body terms self.one_hot = AtomicOneHot(elements) # two body terms self.Z_pairs = [(Z1, Z2) for Z1 in Zs for Z2 in Zs] self.two_body_descriptors = UniformModuleList( [ TwoBodyDescriptor( Z1, Z2, cutoff=cutoff, features=features, max_power=max_power, learnable_powers=learnable_powers, ) for Z1, Z2 in self.Z_pairs ] ) # three body terms self.Z_triplets = [ (Z1, Z2, Z3) for Z1 in Zs for Z2 in Zs for Z3 in Zs if Z2 <= Z3 # skip identical triplets (A,B,C) and (A,C,B) ] self.three_body_descriptors = UniformModuleList( [ ThreeBodyDescriptor( Z1, Z2, Z3, cutoff=three_body_cutoff, features=three_body_features, max_power=three_body_max_power, learnable_powers=learnable_powers, ) for Z1, Z2, Z3 in self.Z_triplets ] ) # read out head input_features = ( len(elements) + len(self.two_body_descriptors) * features + len(self.three_body_descriptors) * three_body_features**2 ) layers = [input_features] + [mlp_width] * mlp_layers + [1] self.mlp = MLP(layers, activation=activation) self.shifts = torch.nn.Parameter(torch.zeros(input_features)) self.scales = torch.nn.Parameter(torch.ones(input_features)) def featurise(self, graph: AtomicGraph) -> torch.Tensor: # one body terms central_atom_features = [self.one_hot(graph.Z)] # two body terms rs = neighbour_distances(graph) for descriptor in self.two_body_descriptors: central_atom_features.append(descriptor(rs, graph)) # three body terms i, j, k, r_ij, r_ik, r_jk = triplet_edges( graph, self.three_body_cutoff.item() ) for descriptor in self.three_body_descriptors: central_atom_features.append( descriptor(i, j, k, r_ij, r_ik, r_jk, graph) ) # concatenate all features f = torch.cat(central_atom_features, dim=1) f = (f - self.shifts) / self.scales return f def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]: features = self.featurise(graph) return {"local_energies": self.mlp(features).view(-1)} def pre_fit(self, graphs: AtomicGraph) -> None: X = self.featurise(graphs) _mean = X.mean(dim=0) _std = X.std(dim=0) # replace nans _std[torch.isnan(_std) | (_std < 1e-6)] = 1 self.shifts.data = _mean self.scales.data = _std