Source code for graph_pes.interfaces._mattersim

import torch

from graph_pes import AtomicGraph, GraphPESModel
from graph_pes.atomic_graph import (
    PropertyKey,
    neighbour_distances,
    neighbour_vectors,
    number_of_edges,
    sum_per_structure,
)
from graph_pes.utils.threebody import (
    angle_spanned_by,
    triplet_edge_pairs,
)


class MatterSim_M3Gnet_Wrapper(GraphPESModel):
    """
    Adapter exposing a MatterSim M3GNet network as a
    :class:`~graph_pes.GraphPESModel`.

    The wrapped ``model`` is expected to provide a ``model_args`` mapping
    with ``"cutoff"`` and ``"threebody_cutoff"`` entries, together with the
    sub-modules used in :meth:`forward` (``one_hot_atoms``,
    ``atom_embedding``, ``rbf``, ``edge_encoder``, ``sbf``, ``graph_conv``,
    ``final`` and ``normalizer``).

    Only ``"local_energies"`` is declared as an implemented property.

    Parameters
    ----------
    model: torch.nn.Module
        The M3GNet network to wrap.
    """

    def __init__(self, model: torch.nn.Module):
        two_body_cutoff = model.model_args["cutoff"]  # type: ignore
        super().__init__(
            cutoff=two_body_cutoff,
            implemented_properties=["local_energies"],
        )
        self.model = model
        three_body_cutoff = model.model_args["threebody_cutoff"]  # type: ignore
        # kept as a tensor attribute; read back with ``.item()`` in forward
        self.threebody_cutoff = torch.tensor(three_body_cutoff)

    def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
        """
        Run the wrapped M3GNet network on ``graph``.

        Parameters
        ----------
        graph: AtomicGraph
            The input graph.

        Returns
        -------
        dict[PropertyKey, torch.Tensor]
            A single-entry mapping, ``"local_energies"`` to the flattened,
            normalised output of the network's final layer.
        """
        # ---- two-body geometry ------------------------------------------
        distances = neighbour_distances(graph)  # (E)
        vectors = neighbour_vectors(graph)

        # ---- three-body geometry ----------------------------------------
        # each row of ``pairs`` holds two edge indices forming a triplet
        pairs = triplet_edge_pairs(graph, self.threebody_cutoff.item())
        leading_edge = pairs[:, 0]
        trailing_edge = pairs[:, 1]
        trailing_lengths = distances[trailing_edge]
        angles = angle_spanned_by(vectors[leading_edge], vectors[trailing_edge])
        triplet_counts = count_number_of_triplets_per_leading_edge(
            pairs, graph
        ).unsqueeze(-1)

        # ---- per-structure counts ---------------------------------------
        # every atom contributes 1, summed with ``sum_per_structure``
        atom_counts = sum_per_structure(
            torch.ones_like(graph.Z), graph
        ).unsqueeze(-1)

        # edges_per_atom[i] = sum(graph.neighbour_list[0] == i);
        # these per-atom counts are then summed with ``sum_per_structure``
        first_atom = graph.neighbour_list[0]
        edges_per_atom = torch.zeros_like(graph.Z).scatter_add(
            dim=0,
            index=first_atom,
            src=torch.ones_like(first_atom),
        )
        bond_counts = sum_per_structure(edges_per_atom, graph).unsqueeze(-1)

        # ---- M3GNet forward pass ----------------------------------------
        net = self.model
        node_features = net.atom_embedding(net.one_hot_atoms(graph.Z))
        # the un-encoded radial basis is handed to every conv layer unchanged
        radial_basis = net.rbf(distances)
        edge_features = net.edge_encoder(radial_basis)
        angular_basis = net.sbf(trailing_lengths, angles)
        distances_column = distances.unsqueeze(-1)

        for layer in net.graph_conv:
            node_features, edge_features = layer(
                node_features,
                edge_features,
                radial_basis,
                graph.neighbour_list,
                angular_basis,
                pairs,
                distances_column,
                bond_counts,
                triplet_counts,
                atom_counts,
            )

        per_atom = net.final(node_features).view(-1)
        per_atom = net.normalizer(per_atom, graph.Z)

        return {"local_energies": per_atom}


def mattersim(load_path: str = "mattersim-v1.0.0-1m") -> GraphPESModel:
    """
    Load a ``mattersim`` checkpoint on the CPU and wrap the resulting
    network as a :class:`~graph_pes.GraphPESModel`.

    Parameters
    ----------
    load_path: str
        The path to the ``mattersim`` checkpoint file. Expected to be one of
        ``mattersim-v1.0.0-1m`` or ``mattersim-v1.0.0-5m`` currently.

    Returns
    -------
    GraphPESModel
        A :class:`MatterSim_M3Gnet_Wrapper` around the loaded network.
    """
    # imported here so that ``mattersim`` is only needed when this is called
    from mattersim.forcefield.potential import Potential

    potential = Potential.from_checkpoint(  # type: ignore
        load_path,
        load_training_state=False,  # only load the model
        device="cpu",  # manage the device ourself later
    )
    return MatterSim_M3Gnet_Wrapper(potential.model)
@torch.no_grad()
def count_number_of_triplets_per_leading_edge(
    edge_pairs: torch.Tensor,
    graph: AtomicGraph,
):
    """
    Count, for every edge in ``graph``, how many edge pairs list that edge
    first.

    Parameters
    ----------
    edge_pairs: torch.Tensor
        A ``(T, 2)`` shaped tensor of edge indices, one row per pair of
        edges forming a triplet ``(i, j, k)``
        (see :func:`triplet_edge_pairs`).
    graph: AtomicGraph
        The graph from which the edge pairs were derived.

    Returns
    -------
    triplets_per_edge: torch.Tensor
        A ``(E,)`` shaped ``torch.long`` tensor, with ``E`` the number of
        edges in ``graph``, where ``triplets_per_edge[e]`` is the number of
        rows of ``edge_pairs`` whose first entry is ``e``. Edges that never
        appear first get a count of zero.
    """
    leading_edge = edge_pairs[:, 0]
    increments = torch.ones_like(leading_edge)

    # one zero-initialised slot per edge, on the same device as the positions
    n_edges = number_of_edges(graph)
    counts = torch.zeros(n_edges, device=graph.R.device, dtype=torch.long)

    # add 1 to the slot of the leading edge of every pair
    return counts.scatter_add(dim=0, index=leading_edge, src=increments)