# Source code for graph_pes.interfaces._mattersim

import torch

from graph_pes import AtomicGraph, GraphPESModel
from graph_pes.atomic_graph import (
    PropertyKey,
    neighbour_distances,
    neighbour_vectors,
    number_of_atoms,
    number_of_edges,
    sum_per_structure,
)
from graph_pes.utils.threebody import angle_spanned_by


def threebody_edge_pairs(
    graph: AtomicGraph, three_body_cutoff: float
) -> torch.Tensor:
    """
    Find all ordered pairs of distinct edges that share a central atom and
    are both strictly shorter than ``three_body_cutoff``.

    Each returned row ``(e1, e2)`` holds two edge indexes into
    ``graph.neighbour_list`` such that
    ``graph.neighbour_list[0][e1] == graph.neighbour_list[0][e2]`` and
    ``e1 != e2``. Rows are grouped by central atom (in increasing atom
    index), and within each atom they follow :func:`torch.cartesian_prod`
    order over that atom's short edges.

    Parameters
    ----------
    graph: AtomicGraph
        The graph to find edge pairs in.
    three_body_cutoff: float
        Only edges whose length is strictly less than this value are used.

    Returns
    -------
    edge_pairs: torch.Tensor
        An integer tensor of shape ``(T, 2)``, where ``T`` is the number of
        such edge pairs. ``T`` may be ``0``, e.g. for a graph with no atoms.
    """
    device = graph.R.device
    edge_indexes = torch.arange(number_of_edges(graph), device=device)
    three_body_mask = neighbour_distances(graph) < three_body_cutoff
    relevant_edge_indexes = edge_indexes[three_body_mask]
    relevant_central_atoms = graph.neighbour_list[0][relevant_edge_indexes]

    # Seed with an empty (0, 2) tensor: a graph with no atoms then yields an
    # empty result instead of torch.cat failing on an empty list.
    edge_pairs = [
        torch.zeros((0, 2), dtype=edge_indexes.dtype, device=device)
    ]
    # NOTE: this loops over atoms in Python and re-scans all short edges
    # for every atom.
    for i in range(number_of_atoms(graph)):
        mask = relevant_central_atoms == i
        masked_edge_indexes = relevant_edge_indexes[mask]
        # number of edges of length < three_body_cutoff
        # that have i as a central atom
        N = masked_edge_indexes.shape[0]
        # build local indices on the graph's device to avoid host->device
        # index transfers when the graph is not on the CPU
        local = torch.arange(N, device=device)
        _idx = torch.cartesian_prod(local, local)  # (N**2, 2)
        _idx = _idx[_idx[:, 0] != _idx[:, 1]]  # (N**2 - N, 2)
        edge_pairs.append(masked_edge_indexes[_idx])
    return torch.cat(edge_pairs)


class MatterSim_M3Gnet_Wrapper(GraphPESModel):
    """
    Wrap a ``mattersim`` M3GNet model as a :class:`~graph_pes.GraphPESModel`.

    The wrapper prepares the M3GNet inputs (edge lengths, three-body edge
    pairs, angles and per-structure counts) from an :class:`AtomicGraph`.
    It then runs them through the wrapped model's ``atom_embedding``,
    ``rbf``, ``edge_encoder``, ``sbf``, ``graph_conv``, ``final`` and
    ``normalizer`` components to predict per-atom local energies.

    Parameters
    ----------
    model
        The underlying ``mattersim`` model. Its ``model_args`` must contain
        ``"cutoff"`` and ``"threebody_cutoff"`` entries.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__(
            cutoff=model.model_args["cutoff"],  # type: ignore
            implemented_properties=["local_energies"],
            three_body_cutoff=model.model_args["threebody_cutoff"],  # type: ignore
        )
        self.model = model

    def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
        """
        Predict local energies for every atom in ``graph``.

        Returns a dictionary with the single key ``"local_energies"``,
        mapping to a flattened tensor of per-atom energies.
        """
        # --- geometric pre-computation ---
        edge_lengths = neighbour_distances(graph)  # (E,)
        # (T, 2) pairs of edges (ij, ik) that share central atom i and are
        # both shorter than the three-body cutoff
        _3b_edge_pairs = threebody_edge_pairs(
            graph, self.three_body_cutoff.item()
        )
        triplets_per_leading_edge = count_number_of_triplets_per_leading_edge(
            _3b_edge_pairs, graph
        )
        r_ik = edge_lengths[_3b_edge_pairs[:, 1]]  # (T,)
        v = neighbour_vectors(graph)
        v_ij = v[_3b_edge_pairs[:, 0]]
        v_ik = v[_3b_edge_pairs[:, 1]]
        # computed once (previously duplicated for no effect)
        angle = angle_spanned_by(v_ij, v_ik)

        # per-structure atom counts, with a trailing singleton dimension
        num_atoms = sum_per_structure(
            torch.ones_like(graph.Z), graph
        ).unsqueeze(-1)

        # bonds_per_atom[i] = sum(graph.neighbour_list[0] == i), i.e. the
        # number of edges with atom i as the central atom. num_bonds is its
        # per-structure sum, with a trailing singleton dimension.
        bonds_per_atom = torch.zeros_like(graph.Z)
        bonds_per_atom = bonds_per_atom.scatter_add(
            dim=0,
            index=graph.neighbour_list[0],
            src=torch.ones_like(graph.neighbour_list[0]),
        )
        num_bonds = sum_per_structure(bonds_per_atom, graph).unsqueeze(-1)

        num_triple_ij = triplets_per_leading_edge.unsqueeze(-1)

        # --- use the forward pass of M3Gnet ---
        atom_attr = self.model.atom_embedding(self.model.one_hot_atoms(graph.Z))
        edge_attr = self.model.rbf(edge_lengths)
        # keep the un-encoded radial features: every conv layer receives them
        edge_attr_zero = edge_attr
        edge_attr = self.model.edge_encoder(edge_attr)
        three_basis = self.model.sbf(r_ik, angle)

        for conv in self.model.graph_conv:
            atom_attr, edge_attr = conv(
                atom_attr,
                edge_attr,
                edge_attr_zero,
                graph.neighbour_list,
                three_basis,
                _3b_edge_pairs,
                edge_lengths.unsqueeze(-1),
                num_bonds,
                num_triple_ij,
                num_atoms,
            )

        local_energies = self.model.final(atom_attr).view(-1)
        local_energies = self.model.normalizer(local_energies, graph.Z)

        return {"local_energies": local_energies}


def mattersim(load_path: str = "mattersim-v1.0.0-1m") -> GraphPESModel:
    """
    Load a pre-trained ``mattersim`` checkpoint and wrap it as a
    :class:`~graph_pes.GraphPESModel`, with all weights placed on the CPU.

    Parameters
    ----------
    load_path: str
        Location of the ``mattersim`` checkpoint. Currently this is expected
        to be either ``mattersim-v1.0.0-1m`` or ``mattersim-v1.0.0-5m``.

    Returns
    -------
    GraphPESModel
        The loaded model, wrapped for use with ``graph-pes``.
    """
    # Imported lazily: ``mattersim`` is only needed when this is called.
    from mattersim.forcefield.potential import Potential

    potential = Potential.from_checkpoint(  # type: ignore
        load_path,
        load_training_state=False,  # skip training state; model weights only
        device="cpu",  # device placement is managed by the caller later
    )
    return MatterSim_M3Gnet_Wrapper(potential.model)
@torch.no_grad()
def count_number_of_triplets_per_leading_edge(
    edge_pairs: torch.Tensor,
    graph: AtomicGraph,
) -> torch.Tensor:
    """
    Return ``T`` of shape ``(E,)``, where ``E`` is the number of edges in
    ``graph`` and ``T[e]`` is the number of edge pairs that have edge number
    ``e`` as the first edge in the pair.

    Parameters
    ----------
    edge_pairs: torch.Tensor
        A ``(P, 2)`` shaped integer tensor of ``P`` edge pairs
        ``(e1, e2)`` that form triplets (see :func:`threebody_edge_pairs`).
        Each entry is an index into ``graph.neighbour_list``.
    graph: AtomicGraph
        The graph from which the edge pairs were derived.

    Returns
    -------
    triplets_per_edge: torch.Tensor
        A ``(E,)`` shaped ``torch.long`` tensor where
        ``triplets_per_edge[e]`` is the number of edge pairs that have edge
        ``e`` as the first edge in the pair. Edges that lead no pair get 0.
    """
    # Pure integer counting: no gradients are needed (hence ``no_grad``).
    triplets_per_edge = torch.zeros(
        number_of_edges(graph), device=graph.R.device, dtype=torch.long
    )
    return triplets_per_edge.scatter_add(
        dim=0,
        index=edge_pairs[:, 0],
        src=torch.ones_like(edge_pairs[:, 0]),
    )