Source code for graph_pes.models.e3nn.mace

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Final, cast

import torch
import torch.fx
from e3nn import o3

from graph_pes.atomic_graph import (
    DEFAULT_CUTOFF,
    AtomicGraph,
    PropertyKey,
    index_over_neighbours,
    neighbour_distances,
    neighbour_vectors,
)
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.models.components.aggregation import (
    NeighbourAggregation,
    NeighbourAggregationMode,
)
from graph_pes.models.components.distances import (
    DistanceExpansion,
    PolynomialEnvelope,
    get_distance_expansion,
)
from graph_pes.models.components.scaling import LocalEnergiesScaler
from graph_pes.models.e3nn.mace_utils import (
    Contraction,
    ContractionConfig,
    UnflattenIrreps,
    parse_irreps,
)
from graph_pes.models.e3nn.utils import (
    LinearReadOut,
    NonLinearReadOut,
    ReadOut,
    SphericalHarmonics,
    as_irreps,
    build_limited_tensor_product,
    to_full_irreps,
)
from graph_pes.utils.nn import (
    MLP,
    AtomicOneHot,
    HaddamardProduct,
    MLPConfig,
    PerElementEmbedding,
    UniformModuleList,
)


class MACEInteraction(torch.nn.Module):
    """
    The MACE interaction block.

    Generates new node embeddings from the old node embeddings and the
    spherical harmonic expansion and magnitudes of the neighbour vectors.
    A shape-flow sketch follows the class definition.
    """

    def __init__(
        self,
        # input nodes
        irreps_in: list[o3.Irrep],
        nodes: NodeDescription,
        # input edges
        sph_harmonics: o3.Irreps,
        radial_basis_features: int,
        mlp: MLPConfig,
        # other
        aggregation: NeighbourAggregationMode,
        mix_attributes: bool,
    ):
        super().__init__()

        irreps_out = [ir for _, ir in sph_harmonics]

        features_in = as_irreps([(nodes.channels, ir) for ir in irreps_in])
        self.pre_linear = o3.Linear(
            features_in,
            features_in,
            internal_weights=True,
            shared_weights=True,
        )

        self.tp = build_limited_tensor_product(
            features_in,
            sph_harmonics,
            irreps_out,
        )
        mid_features = self.tp.irreps_out.simplify()
        assert all(ir in mid_features for ir in irreps_out)

        self.weight_generator = MLP.from_config(
            mlp,
            input_features=radial_basis_features,
            output_features=self.tp.weight_numel,
            bias=False,
        )

        features_out = as_irreps(
            [(nodes.channels, ir) for (_, ir) in sph_harmonics]
        )
        self.post_linear = o3.Linear(
            mid_features,
            features_out,
            internal_weights=True,
            shared_weights=True,
        )

        self.aggregator = NeighbourAggregation.parse(aggregation)

        if mix_attributes:
            self.attribute_mixer = o3.FullyConnectedTensorProduct(
                irreps_in1=features_out,
                irreps_in2=o3.Irreps(f"{nodes.attributes}x0e"),
                irreps_out=features_out,
            )
        else:
            self.attribute_mixer = None

        self.reshape = UnflattenIrreps(irreps_out, nodes.channels)

        # book-keeping
        self.irreps_in = features_in
        self.irreps_out = features_out

    def forward(
        self,
        node_features: torch.Tensor,
        node_attributes: torch.Tensor,
        sph_harmonics: torch.Tensor,
        radial_basis: torch.Tensor,
        graph: AtomicGraph,
    ) -> torch.Tensor:
        # pre-linear
        node_features = self.pre_linear(node_features)  # (N, a)

        # tensor product: mix node and edge features
        neighbour_features = index_over_neighbours(
            node_features, graph
        )  # (E, a)
        weights = self.weight_generator(radial_basis)  # (E, b)
        messages = self.tp(
            neighbour_features,
            sph_harmonics,
            weights,
        )  # (E, c)

        # aggregate
        total_message = self.aggregator(messages, graph)  # (N, c)

        # post-linear
        node_features = self.post_linear(total_message)  # (N, d)

        if self.attribute_mixer is not None:
            node_features = self.attribute_mixer(node_features, node_attributes)

        return self.reshape(node_features)  # (N, channels, d')

    # type hints for mypy
    def __call__(
        self,
        node_features: torch.Tensor,
        node_attributes: torch.Tensor,
        sph_harmonics: torch.Tensor,
        radial_basis: torch.Tensor,
        graph: AtomicGraph,
    ) -> torch.Tensor:
        return super().__call__(
            node_features,
            node_attributes,
            sph_harmonics,
            radial_basis,
            graph,
        )
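
# Shape flow through the interaction block (N atoms, E edges), summarising
# the per-line annotations in ``forward`` above:
#
#     node features (N, features_in)
#       -> pre_linear, then gathered onto edges          (E, features_in)
#       -> tensor product with the spherical harmonics,
#          weighted per-edge from the radial basis       (E, mid_features)
#       -> aggregated over each atom's neighbours        (N, mid_features)
#       -> post_linear (+ optional attribute mixing)     (N, features_out)
#       -> reshaped to                                   (N, channels, irreps)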


@dataclass
class NodeDescription:
    channels: int
    attributes: int
    hidden_features: list[o3.Irrep]

    def hidden_irreps(self) -> o3.Irreps:
        return to_full_irreps(self.channels, self.hidden_features)
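
# An illustrative example (hypothetical values): with 16 channels and
# hidden features [0e, 1o], the node features between message passing
# layers are "16x0e + 16x1o":
#
#     nodes = NodeDescription(
#         channels=16,
#         attributes=5,  # e.g. a one-hot over 5 elements
#         hidden_features=parse_irreps("0e + 1o"),
#     )
#     nodes.hidden_irreps()  # -> 16x0e+16x1o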


class MACELayer(torch.nn.Module):
    def __init__(
        self,
        irreps_in: list[o3.Irrep],
        nodes: NodeDescription,
        correlation: int,
        sph_harmonics: o3.Irreps,
        radial_basis_features: int,
        mlp: MLPConfig,
        use_sc: bool,
        aggregation: NeighbourAggregationMode,
        residual: bool,
        final_layer: bool,
    ):
        super().__init__()

        self.interaction = MACEInteraction(
            irreps_in=irreps_in,
            nodes=nodes,
            sph_harmonics=sph_harmonics,
            radial_basis_features=radial_basis_features,
            mlp=mlp,
            aggregation=aggregation,
            # only mix attributes in the interaction block
            # if we **aren't** using a residual connection
            mix_attributes=not residual,
        )
        actual_mid_features = [ir for _, ir in self.interaction.irreps_out]

        output_features = o3.Irreps(
            nodes.hidden_irreps()
            if not final_layer
            else o3.Irreps(f"{nodes.channels}x0e")
        )
        self.contractions = UniformModuleList(
            [
                Contraction(
                    config=ContractionConfig(
                        num_features=nodes.channels,
                        n_node_attributes=nodes.attributes,
                        irrep_s_in=actual_mid_features,
                        irrep_out=target_irrep,
                    ),
                    correlation=correlation,
                )
                for target_irrep in [o.ir for o in output_features]
            ]
        )

        if use_sc and residual:
            # links input features to output features via a tensor product
            self.residual_update = o3.FullyConnectedTensorProduct(
                irreps_in1=[(nodes.channels, ir) for ir in irreps_in],
                irreps_in2=o3.Irreps(f"{nodes.attributes}x0e"),
                irreps_out=output_features,
            )
        else:
            self.residual_update = None

        # update the hidden features from the interaction block
        # and target the output features
        self.post_linear = o3.Linear(
            output_features,
            output_features,
            internal_weights=True,
            shared_weights=True,
        )

        # book-keeping
        self.irreps_in = irreps_in
        self.irreps_out: o3.Irreps = output_features  # type: ignore

    def forward(
        self,
        node_features: torch.Tensor,
        node_attributes: torch.Tensor,
        sph_harmonics: torch.Tensor,
        radial_basis: torch.Tensor,
        graph: AtomicGraph,
    ) -> torch.Tensor:
        # A MACE layer operates on:
        # - node features with multiplicity M, e.g. M=16: 16x0e + 16x1o
        # - node attributes with multiplicity A e.g. A=5: 5x0e
        # - spherical harmonics up to l_max, e.g. l_max=2: 1x0e + 1x1o + 1x2e

        # interact
        internal_node_features = self.interaction(
            node_features,
            node_attributes,
            sph_harmonics,
            radial_basis,
            graph,
        )  # (N, M, irreps)

        # contract to the target output irreps (one contraction per irrep)
        contracted_features = torch.cat(
            [
                contraction(internal_node_features, node_attributes)
                for contraction in self.contractions
            ],
            dim=-1,
        )  # (N, irreps_out)

        # residual update
        if self.residual_update is not None:
            update = self.residual_update(
                node_features,
                node_attributes,
            )  # (N, irreps_out)
            contracted_features = contracted_features + update

        # linear update
        node_features = self.post_linear(contracted_features)  # (N, irreps_out)

        return node_features
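
# A construction sketch mirroring how ``_BaseMACE`` below builds its first
# layer (illustrative values taken from the ``MACE`` defaults; not executed
# on import):
#
#     layer = MACELayer(
#         irreps_in=[o3.Irrep("0e")],  # the initial embedding is scalar
#         nodes=nodes,                 # a NodeDescription, as above
#         correlation=3,
#         sph_harmonics=o3.Irreps.spherical_harmonics(3),
#         radial_basis_features=8,
#         mlp=DEFAULT_MLP_CONFIG,      # defined at the bottom of this module
#         use_sc=True,
#         aggregation="constant_fixed",
#         residual=False,              # no residual skip on the first layer
#         final_layer=False,
#     )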


class _BaseMACE(GraphPESModel):
    def __init__(
        self,
        # radial things
        cutoff: float,
        n_radial: int,
        radial_expansion: type[DistanceExpansion] | str,
        weights_mlp: MLPConfig,
        # node things
        nodes: NodeDescription,
        node_attribute_generator: Callable[[torch.Tensor], torch.Tensor],
        # message passing
        layers: int,
        l_max: int,
        correlation: int,
        neighbour_aggregation: NeighbourAggregationMode,
        use_self_connection: bool,
        # readout
        readout_width: int,
    ):
        super().__init__(
            cutoff=cutoff,
            implemented_properties=["local_energies"],
        )

        if o3.Irrep("0e") not in nodes.hidden_features:
            raise ValueError("MACE requires a `0e` hidden feature")

        # radial things
        sph_harmonics = cast(o3.Irreps, o3.Irreps.spherical_harmonics(l_max))
        self.spherical_harmonics = SphericalHarmonics(
            sph_harmonics,
            normalize=True,
            normalization="component",
        )
        if isinstance(radial_expansion, str):
            radial_expansion = get_distance_expansion(radial_expansion)
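        # expand each neighbour distance into a learnable radial basis,
        # smoothly damped to zero at the cutoff by a polynomial envelope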
        self.radial_expansion = HaddamardProduct(
            radial_expansion(
                n_features=n_radial, cutoff=cutoff, trainable=True
            ),
            PolynomialEnvelope(cutoff=cutoff, p=5),
        )

        # node things
        self.node_attribute_generator = node_attribute_generator
        self.initial_node_embedding = PerElementEmbedding(nodes.channels)

        # message passing
        current_node_irreps = [o3.Irrep("0e")]
        self.layers: UniformModuleList[MACELayer] = UniformModuleList([])

        for i in range(layers):
            # only use residual skip after the first layer
            use_residual = i != 0
            final_layer = i == layers - 1
            layer = MACELayer(
                irreps_in=current_node_irreps,
                nodes=nodes,
                correlation=correlation,
                sph_harmonics=sph_harmonics,
                radial_basis_features=n_radial,
                mlp=weights_mlp,
                use_sc=use_self_connection,
                aggregation=neighbour_aggregation,
                residual=use_residual,
                final_layer=final_layer,
            )
            self.layers.append(layer)
            current_node_irreps = [ir for _, ir in layer.irreps_out]

        self.readouts: UniformModuleList[ReadOut] = UniformModuleList(
            [LinearReadOut(nodes.hidden_irreps()) for _ in range(layers - 1)]
            + [
                NonLinearReadOut(
                    self.layers[-1].irreps_out, hidden_dim=readout_width
                )
            ],
        )

        self.scaler = LocalEnergiesScaler()

    def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
        # pre-compute some things
        vectors = neighbour_vectors(graph)
        sph_harmonics = self.spherical_harmonics(vectors)
        edge_features = self.radial_expansion(
            neighbour_distances(graph).view(-1, 1)
        )
        node_attributes = self.node_attribute_generator(graph.Z)

        # generate initial node features
        node_features = self.initial_node_embedding(graph.Z)

        # update node features through message passing layers
        per_atom_energies = []
        for layer, readout in zip(self.layers, self.readouts):
            node_features = layer(
                node_features,
                node_attributes,
                sph_harmonics,
                edge_features,
                graph,
            )
            per_atom_energies.append(readout(node_features))

        # sum up the per-atom energies
        local_energies = torch.sum(
            torch.stack(per_atom_energies), dim=0
        ).squeeze()

        # return scaled local energy predictions
        return {"local_energies": self.scaler(local_energies, graph)}
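
# In equation form: each atom's local energy is the sum, over layers t, of
# that layer's readout applied to its node features,
#
#     E_i = sum_t R_t(h_i^(t)),
#
# subsequently scaled by the LocalEnergiesScaler.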


DEFAULT_MLP_CONFIG: Final[MLPConfig] = {
    "hidden_depth": 3,
    "hidden_features": 64,
    "activation": "SiLU",
}
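
# Partial overrides are merged into these defaults by ``MACE`` and
# ``ZEmbeddingMACE`` via ``{**DEFAULT_MLP_CONFIG, **weights_mlp}``. For
# example (illustrative):
#
#     MACE(elements=["H"], weights_mlp={"hidden_depth": 2})
#     # -> hidden_depth=2, hidden_features=64, activation="SiLU"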


class MACE(_BaseMACE):
    r"""
    The `MACE <https://arxiv.org/abs/2206.07697>`__ architecture.

    One-hot encodings of the atomic numbers are used to condition the
    ``TensorProduct`` update in the residual connection of the message
    passing layers, as well as the contractions in the message passing
    layers.

    Following the notation used in
    `ACEsuit/mace <https://github.com/ACEsuit/mace>`__, the first layer in
    this model is a ``RealAgnosticInteractionBlock``. Subsequent layers are
    then ``RealAgnosticResidualInteractionBlock``\ s.

    Please cite the following if you use this model in your research:

    .. code-block:: bibtex

        @misc{Batatia2022MACE,
            title = {
                MACE: Higher Order Equivariant Message Passing Neural
                Networks for Fast and Accurate Force Fields
            },
            author = {
                Batatia, Ilyes and Kov{\'a}cs, D{\'a}vid P{\'e}ter and
                Simm, Gregor N. C. and Ortner, Christoph and
                Cs{\'a}nyi, G{\'a}bor
            },
            year = {2022},
            doi = {10.48550/arXiv.2206.07697},
        }

    Parameters
    ----------
    elements
        list of elements that this MACE model will be able to handle
    cutoff
        radial cutoff (in Å) for the radial expansion (and message passing)
    n_radial
        number of bases to expand the radial distances into
    radial_expansion
        type of radial expansion to use. See
        :class:`~graph_pes.models.components.distances.DistanceExpansion`
        for available options
    weights_mlp
        configuration for the MLPs that map the radial basis functions
        to the weights of the interactions' tensor products
    channels
        the multiplicity of the node features corresponding to each irrep
        specified in ``hidden_irreps``
    hidden_irreps
        string representations of the :class:`e3nn.o3.Irrep`\ s to use for
        representing the node features between each message passing layer
    l_max
        the highest order to consider in:

        * the spherical harmonics expansion of the neighbour vectors
        * the irreps of node features used within each message passing
          layer
    layers
        number of message passing layers
    correlation
        maximum correlation (body-order) of the messages
    aggregation
        the type of aggregation to use when creating total messages from
        neighbour messages :math:`m_{j \rightarrow i}`
    self_connection
        whether to use self-connections in the message passing layers
    readout_width
        the width of the MLP used to read out the per-atom energies after
        the final message passing layer

    Examples
    --------
    Basic usage:

    .. code-block:: python

        >>> from graph_pes.models import MACE
        >>> model = MACE(
        ...     elements=["H", "C", "N", "O"],
        ...     cutoff=5.0,
        ...     channels=16,
        ...     radial_expansion="Bessel",
        ... )

    Specification in a YAML file:

    .. code-block:: yaml

        model:
            +MACE:
                elements: [H, C, N, O]
                cutoff: 5.0
                radial_expansion: Bessel
                # change from the default MLP config:
                weights_mlp:
                    hidden_depth: 2
                    hidden_features: 16
                    activation: SiLU
    """  # noqa: E501

    def __init__(
        self,
        elements: list[str],
        # radial things
        cutoff: float = DEFAULT_CUTOFF,
        n_radial: int = 8,
        radial_expansion: type[DistanceExpansion] | str = "Bessel",
        weights_mlp: MLPConfig = DEFAULT_MLP_CONFIG,
        # node things
        channels: int = 128,
        hidden_irreps: str | list[str] = "0e + 1o",
        # message passing things
        l_max: int = 3,
        layers: int = 2,
        correlation: int = 3,
        aggregation: NeighbourAggregationMode = "constant_fixed",
        self_connection: bool = True,
        # readout
        readout_width: int = 16,
    ):
        Z_embedding = AtomicOneHot(elements)
        Z_dim = len(elements)

        hidden_irrep_s = parse_irreps(hidden_irreps)
        nodes = NodeDescription(
            channels=channels,
            attributes=Z_dim,
            hidden_features=hidden_irrep_s,
        )

        super().__init__(
            cutoff=cutoff,
            n_radial=n_radial,
            radial_expansion=radial_expansion,
            weights_mlp={**DEFAULT_MLP_CONFIG, **weights_mlp},
            nodes=nodes,
            node_attribute_generator=Z_embedding,
            l_max=l_max,
            layers=layers,
            correlation=correlation,
            neighbour_aggregation=aggregation,
            use_self_connection=self_connection,
            readout_width=readout_width,
        )


class ZEmbeddingMACE(_BaseMACE):
    r"""
    A variant of MACE that uses a fixed-size (``z_embed_dim``) per-element
    embedding of the atomic numbers to condition the ``TensorProduct``
    update in the residual connection of the message passing layers, as
    well as the contractions in the message passing layers.

    Please cite the following if you use this model in your research:

    .. code-block:: bibtex

        @misc{Batatia2022MACE,
            title = {
                MACE: Higher Order Equivariant Message Passing Neural
                Networks for Fast and Accurate Force Fields
            },
            author = {
                Batatia, Ilyes and Kov{\'a}cs, D{\'a}vid P{\'e}ter and
                Simm, Gregor N. C. and Ortner, Christoph and
                Cs{\'a}nyi, G{\'a}bor
            },
            year = {2022},
            doi = {10.48550/arXiv.2206.07697},
        }

    All parameters are identical to :class:`~graph_pes.models.MACE`, except
    for the following:

    - ``elements`` is not required or used here
    - ``z_embed_dim`` controls the size of the per-element embedding
    """  # noqa: E501

    def __init__(
        self,
        z_embed_dim: int = 4,
        # radial things
        cutoff: float = DEFAULT_CUTOFF,
        n_radial: int = 8,
        radial_expansion: type[DistanceExpansion] | str = "Bessel",
        weights_mlp: MLPConfig = DEFAULT_MLP_CONFIG,
        # node things
        channels: int = 128,
        hidden_irreps: str | list[str] = "0e + 1o",
        # message passing things
        l_max: int = 3,
        layers: int = 2,
        correlation: int = 3,
        aggregation: NeighbourAggregationMode = "constant_fixed",
        self_connection: bool = True,
        # readout
        readout_width: int = 16,
    ):
        Z_embedding = PerElementEmbedding(z_embed_dim)

        hidden_irrep_s = parse_irreps(hidden_irreps)
        nodes = NodeDescription(
            channels=channels,
            attributes=z_embed_dim,
            hidden_features=hidden_irrep_s,
        )

        super().__init__(
            cutoff=cutoff,
            n_radial=n_radial,
            radial_expansion=radial_expansion,
            weights_mlp={**DEFAULT_MLP_CONFIG, **weights_mlp},
            nodes=nodes,
            node_attribute_generator=Z_embedding,
            l_max=l_max,
            layers=layers,
            correlation=correlation,
            neighbour_aggregation=aggregation,
            use_self_connection=self_connection,
            readout_width=readout_width,
        )
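
# A minimal usage sketch (illustrative values): unlike ``MACE``, no element
# list is required up front; ``PerElementEmbedding`` learns a
# ``z_embed_dim``-dimensional attribute vector per element:
#
#     model = ZEmbeddingMACE(z_embed_dim=4, cutoff=5.0, channels=64)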