# graph_pes.models.components.aggregation

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal

import torch

from graph_pes.atomic_graph import (
    AtomicGraph,
    number_of_atoms,
    number_of_edges,
    number_of_neighbours,
    sum_over_neighbours,
)
from graph_pes.utils.misc import (
    is_being_documented,
    left_aligned_div,
    uniform_repr,
)

# At type-checking / documentation-build time, expose the precise set of
# valid mode strings; at runtime, fall back to plain ``str`` so that
# TorchScript / isinstance-style uses don't choke on ``typing.Literal``.
if TYPE_CHECKING or is_being_documented():
    NeighbourAggregationMode = Literal[
        "sum", "mean", "constant_fixed", "constant_learnable", "sqrt"
    ]
else:
    NeighbourAggregationMode = str


class NeighbourAggregation(ABC, torch.nn.Module):
    r"""
    Abstract base class for operations that aggregate per-edge values
    into per-atom values:

    .. math::
        X_i^\prime = \text{Agg}_{j \in \mathcal{N}_i} \left[X_j\right]

    where :math:`\mathcal{N}_i` is the set of neighbours of atom :math:`i`,
    :math:`X` has shape ``(E, ...)``, :math:`X^\prime` has shape
    ``(N, ...)`` and ``E`` and ``N`` are the number of edges and atoms
    in the graph, respectively.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
        """Aggregate x over neighbours."""

    def pre_fit(self, graphs: AtomicGraph) -> None:
        """
        Calculate any quantities that are dependent on the graph structure
        that should be fixed before prediction.

        The base implementation is a no-op; subclasses override as needed.

        Parameters
        ----------
        graphs
            A batch of graphs to pre-fit to.
        """

    # explicit __call__ so that static type checkers (mypy etc.) see the
    # correct signature rather than torch.nn.Module's generic one
    def __call__(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
        return super().__call__(x, graph)

    @staticmethod
    def parse(
        mode: Literal[
            "sum", "mean", "constant_fixed", "constant_learnable", "sqrt"
        ],
    ) -> NeighbourAggregation:
        """
        Evaluates the following map:

        .. list-table::
            :widths: 30 70
            :header-rows: 1

            * - Mode
              - Aggregation
            * - ``"sum"``
              - :class:`SumNeighbours() <SumNeighbours>`
            * - ``"mean"``
              - :class:`MeanNeighbours() <MeanNeighbours>`
            * - ``"constant_fixed"``
              - :class:`ScaledSumNeighbours(learnable=False) <ScaledSumNeighbours>`
            * - ``"constant_learnable"``
              - :class:`ScaledSumNeighbours(learnable=True) <ScaledSumNeighbours>`
            * - ``"sqrt"``
              - :class:`VariancePreservingSumNeighbours() <VariancePreservingSumNeighbours>`

        Parameters
        ----------
        mode
            The neighbour aggregation mode to parse.

        Returns
        -------
        NeighbourAggregation
            The parsed neighbour aggregation mode.
        """  # noqa: E501
        # dispatch table of zero-argument factories, one per mode
        factories = {
            "sum": SumNeighbours,
            "mean": MeanNeighbours,
            "constant_fixed": lambda: ScaledSumNeighbours(learnable=False),
            "constant_learnable": lambda: ScaledSumNeighbours(learnable=True),
            "sqrt": VariancePreservingSumNeighbours,
        }
        factory = factories.get(mode)
        if factory is None:
            raise ValueError(f"Unknown neighbour aggregation mode: {mode}")
        return factory()
class SumNeighbours(NeighbourAggregation):
    r"""
    Plain summation over neighbours:

    .. math::
        X_i^\prime = \sum_{j \in \mathcal{N}_i} X_j
    """

    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
        # no scaling required: delegate straight to the graph utility
        return sum_over_neighbours(x, graph)
class MeanNeighbours(NeighbourAggregation):
    r"""
    Average over neighbours:

    .. math::
        X_i^\prime = \frac{1}{|\mathcal{N}_i|} \sum_{j \in \mathcal{N}_i} X_j

    where :math:`|\mathcal{N}_i|` is the number of neighbours of atom
    :math:`i` (including the central atom).

    .. note::

        This aggregation can lead to un-physical discontinuities in the
        PES as neighbours enter or leave the radial cutoff.
    """

    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
        neighbour_sum = sum_over_neighbours(x, graph)
        # counting the central atom guarantees a non-zero denominator
        neighbour_count = number_of_neighbours(
            graph, include_central_atom=True
        )
        return left_aligned_div(neighbour_sum, neighbour_count)
class ScaledSumNeighbours(NeighbourAggregation):
    r"""
    Scale the sum over neighbours by a learnable or fixed constant,
    :math:`s`:

    .. math::
        X_i^\prime = \frac{1}{s} \sum_{j \in \mathcal{N}_i} X_j

    :math:`s` defaults to 1.0, but is set to the average number of
    neighbours of each atom in the training set passed to
    :meth:`pre_fit`.

    Parameters
    ----------
    learnable
        If ``True``, the scale is a learnable parameter. If ``False``,
        the scale is a fixed constant.
    """

    def __init__(self, learnable: bool = False):
        super().__init__()
        # registered as a Parameter either way so it is saved in
        # state_dict; requires_grad controls whether it is trained
        self.scale = torch.nn.Parameter(
            torch.tensor(1.0), requires_grad=learnable
        )

    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
        return sum_over_neighbours(x, graph) / self.scale

    def pre_fit(self, graphs: AtomicGraph) -> None:
        """
        Set the scale equal to the average number of neighbours
        in the training set.
        """
        avg_neighbours = number_of_edges(graphs) / number_of_atoms(graphs)
        # fill in-place: this preserves the parameter's dtype and device.
        # (Assigning `self.scale.data = torch.tensor(avg_neighbours)`
        # would replace it with a fresh CPU float32 tensor, breaking
        # models that have been moved to another device or dtype.)
        self.scale.data.fill_(avg_neighbours)

    def __repr__(self) -> str:
        return uniform_repr(
            self.__class__.__name__,
            scale=f"{self.scale.item():.3f}",
            learnable=self.scale.requires_grad,
        )
class VariancePreservingSumNeighbours(NeighbourAggregation):
    r"""
    Scale the sum over neighbours by the square root of the number of
    neighbours:

    .. math::
        X_i^\prime = \frac{1}{\sqrt{|\mathcal{N}_i|}} \sum_{j \in \mathcal{N}_i} X_j

    where :math:`|\mathcal{N}_i|` is the number of neighbours of atom
    :math:`i` (including the central atom).

    .. note::

        This aggregation can lead to un-physical discontinuities in the
        PES as neighbours enter or leave the radial cutoff.
    """  # noqa: E501

    def forward(self, x: torch.Tensor, graph: AtomicGraph) -> torch.Tensor:
        neighbour_sum = sum_over_neighbours(x, graph)
        # include the central atom so the count is always >= 1
        neighbour_count = number_of_neighbours(
            graph, include_central_atom=True
        )
        return left_aligned_div(neighbour_sum, torch.sqrt(neighbour_count))