Source code for graph_pes.models.components.distances

from __future__ import annotations

import math
import sys
from abc import ABC, abstractmethod

import torch

from graph_pes.utils.misc import to_significant_figures

class DistanceExpansion(torch.nn.Module, ABC):
    r"""
    Abstract base class for an expansion function, :math:`\phi(r) :
    [0, r_{\text{cutoff}}] \rightarrow \mathbb{R}^{n_\text{features}}`.

    Subclasses should implement :meth:`expand`, which must also work over
    batches:

    .. math::

        \phi(r) : [0, r_{\text{cutoff}}]^{n_\text{batch} \times 1}
        \rightarrow \mathbb{R}^{n_\text{batch} \times n_\text{features}}

    Parameters
    ----------
    n_features
        The number of features to expand into.
    cutoff
        The cutoff radius.
    trainable
        Whether the expansion parameters are trainable.
    """

    def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
        super().__init__()
        self.n_features = n_features
        # registered as a buffer (not a plain attribute) so that it is part of
        # the module's state and follows it through ``.to(...)`` calls
        self.register_buffer("cutoff", torch.tensor(cutoff))
        self.trainable = trainable

    @abstractmethod
    def expand(self, r: torch.Tensor) -> torch.Tensor:
        r"""
        Perform the expansion.

        Parameters
        ----------
        r : torch.Tensor
            The distances to expand. Guaranteed to have shape :math:`(..., 1)`.
        """

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        """
        Call the expansion as normal in PyTorch.

        Parameters
        ----------
        r
            The distances to expand. A 0-d scalar or a 1D tensor is always
            treated as (a batch of) individual distances. For higher-rank
            input, a trailing axis of size 1 is added unless one is already
            present.

        Returns
        -------
        torch.Tensor
            The output of :meth:`expand`.
        """
        # Ensure ``expand`` always receives shape (..., 1). Looking only at
        # ``r.shape[-1]`` would crash on 0-d scalars, and would leave a single
        # distance of shape (1,) un-batched: (n_features,) instead of
        # (1, n_features), unlike every other 1D input.
        if r.dim() <= 1 or r.shape[-1] != 1:
            r = r.unsqueeze(-1)
        return self.expand(r)

    # re-declared only to give type checkers (mypy etc.) a precise signature
    def __call__(self, r: torch.Tensor) -> torch.Tensor:
        return super().__call__(r)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(n_features={self.n_features}, "
            f"cutoff={to_significant_figures(self.cutoff.item(), 3)}, "
            f"trainable={self.trainable})"
        )
def get_distance_expansion(
    thing: str | type[DistanceExpansion],
) -> type[DistanceExpansion]:
    """
    Get a distance expansion class by its name.

    Parameters
    ----------
    thing
        Either the name of a :class:`DistanceExpansion` subclass defined in
        this module (e.g. ``"Bessel"``), or a :class:`DistanceExpansion`
        subclass itself, which is returned unchanged.

    Returns
    -------
    type[DistanceExpansion]
        The requested distance expansion class.

    Raises
    ------
    ValueError
        If ``thing`` is neither a :class:`DistanceExpansion` subclass nor the
        name of one defined in this module.

    Example
    -------
    >>> get_distance_expansion("Bessel")
    <class 'graph_pes.models.components.distances.Bessel'>
    """
    if isinstance(thing, type) and issubclass(thing, DistanceExpansion):
        return thing

    # Anything other than a string (e.g. ``None`` or an unrelated class)
    # cannot name an attribute of this module: report it with the same
    # ValueError as below, rather than the confusing TypeError that
    # ``getattr`` would raise.
    if not isinstance(thing, str):
        raise ValueError(f"{thing} is not a DistanceExpansion")

    try:
        klass = getattr(sys.modules[__name__], thing)
    except AttributeError:
        raise ValueError(f"Unknown distance expansion type: {thing}") from None

    if not isinstance(klass, type) or not issubclass(klass, DistanceExpansion):
        raise ValueError(f"{thing} is not a DistanceExpansion")

    return klass
class Bessel(DistanceExpansion):
    r"""
    The Bessel expansion:

    .. math::

        \phi_{n}(r) = \sqrt{\frac{2}{r_{\text{cut}}}}
        \frac{\sin(n \pi \frac{r}{r_\text{cut}})}{r}
        \quad n \in [1, n_\text{features}]

    where :math:`r_\text{cut}` is the cutoff radius and :math:`n` is the order
    of the Bessel function, as introduced in `Directional Message Passing for
    Molecular Graphs <https://arxiv.org/abs/2003.03123>`_.

    To avoid a division by zero, the expansion evaluates to exactly zero at
    :math:`r = 0` (rather than its :math:`r \to 0` limit).

    .. code::

        import torch
        from graph_pes.models.components.distances import Bessel
        import matplotlib.pyplot as plt

        cutoff = 5.0
        bessel = Bessel(n_features=4, cutoff=cutoff)
        r = torch.linspace(0, cutoff, 101)  # (101,)

        with torch.no_grad():
            embedding = bessel(r)  # (101, 4)

        plt.plot(r / cutoff, embedding)
        plt.xlabel(r"$r / r_c$")

    .. image:: bessel.svg
        :align: center

    Parameters
    ----------
    n_features
        The number of features to expand into.
    cutoff
        The cutoff radius.
    trainable
        Whether the expansion parameters are trainable.

    Attributes
    ----------
    frequencies
        :math:`n \pi / r_\text{cut}`, the (angular) frequencies of the Bessel
        functions.
    pre_factor
        :math:`\sqrt{2 / r_\text{cut}}`, a fixed normalisation constant.
    """

    def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
        super().__init__(n_features, cutoff, trainable)
        self.frequencies = torch.nn.Parameter(
            torch.arange(1, n_features + 1) * math.pi / cutoff,
            requires_grad=trainable,
        )
        # A (non-persistent) buffer rather than a plain tensor attribute, so
        # that it is tracked by the module and follows it through
        # ``.to(device/dtype)`` calls, without adding a state-dict entry
        # (existing checkpoints still load).
        self.register_buffer(
            "pre_factor",
            torch.sqrt(torch.tensor(2 / cutoff)),
            persistent=False,
        )

    def expand(self, r: torch.Tensor) -> torch.Tensor:
        numerator = self.pre_factor * torch.sin(r * self.frequencies)
        # Avoid dividing by zero: where r == 0 the numerator is sin(0) = 0,
        # so dividing by 1 there yields 0.
        denominator = torch.where(r == 0, torch.ones_like(r), r)
        return numerator / denominator
class GaussianSmearing(DistanceExpansion):
    r"""
    A Gaussian smearing expansion:

    .. math::

        \phi_{n}(r) = \exp\left(-\frac{(r - \mu_n)^2}{2\sigma^2}\right)
        \quad n \in [1, n_\text{features}]

    where :math:`\mu_n` is the center of the :math:`n`'th Gaussian
    and :math:`\sigma` is a width shared across all the Gaussians.

    .. code::

        import torch
        from graph_pes.models.components.distances import GaussianSmearing
        import matplotlib.pyplot as plt

        cutoff = 5.0
        gaussian = GaussianSmearing(n_features=4, cutoff=cutoff)
        r = torch.linspace(0, cutoff, 101)  # (101,)

        with torch.no_grad():
            embedding = gaussian(r)  # (101, 4)

        plt.plot(r / cutoff, embedding)
        plt.xlabel(r"$r / r_c$")

    .. image:: gaussian.svg
        :align: center

    Parameters
    ----------
    n_features
        The number of features to expand into.
    cutoff
        The cutoff radius.
    trainable
        Whether the expansion parameters are trainable.

    Attributes
    ----------
    centers
        :math:`\mu_n`, the centers of the Gaussians, initialised evenly spaced
        over :math:`[0, r_\text{cut}]`.
    coef
        :math:`-\frac{1}{2\sigma^2}`, the (negative) coefficient of the
        exponent, initialised with :math:`\sigma = r_\text{cut} /
        n_\text{features}`.
    """

    def __init__(
        self,
        n_features: int,
        cutoff: float,
        trainable: bool = True,
    ):
        super().__init__(n_features, cutoff, trainable)
        sigma = cutoff / n_features
        # stored with its minus sign, so that ``expand`` is exp(coef * x^2)
        self.coef = torch.nn.Parameter(
            torch.tensor(-1 / (2 * sigma**2)),
            requires_grad=trainable,
        )
        self.centers = torch.nn.Parameter(
            torch.linspace(0, cutoff, n_features),
            requires_grad=trainable,
        )

    def expand(self, r: torch.Tensor) -> torch.Tensor:
        # (..., 1) - (n_features,) broadcasts to (..., n_features)
        offsets = r - self.centers
        return torch.exp(self.coef * offsets**2)
class SinExpansion(DistanceExpansion):
    r"""
    A sine expansion:

    .. math::

        \phi_{n}(r) = \sin\left(\frac{n \pi r}{r_\text{cut}}\right)
        \quad n \in [1, n_\text{features}]

    where :math:`r_\text{cut}` is the cutoff radius and :math:`n` is the
    frequency of the sine function.

    .. code::

        import torch
        from graph_pes.models.components.distances import SinExpansion
        import matplotlib.pyplot as plt

        cutoff = 5.0
        sine = SinExpansion(n_features=4, cutoff=cutoff)
        r = torch.linspace(0, cutoff, 101)  # (101,)

        with torch.no_grad():
            embedding = sine(r)  # (101, 4)

        plt.plot(r / cutoff, embedding)
        plt.xlabel(r"$r / r_c$")

    .. image:: sin.svg
        :align: center

    Parameters
    ----------
    n_features
        The number of features to expand into.
    cutoff
        The cutoff radius.
    trainable
        Whether the expansion parameters are trainable.

    Attributes
    ----------
    frequencies
        :math:`n \pi / r_\text{cut}`, the (angular) frequencies of the sine
        functions.
    """

    def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
        super().__init__(n_features, cutoff, trainable)
        self.frequencies = torch.nn.Parameter(
            torch.arange(1, n_features + 1) * math.pi / cutoff,
            requires_grad=trainable,
        )

    def expand(self, r: torch.Tensor) -> torch.Tensor:
        # (..., 1) * (n_features,) broadcasts to (..., n_features)
        return torch.sin(r * self.frequencies)
class ExponentialRBF(DistanceExpansion):
    r"""
    The exponential radial basis function expansion, as introduced in
    `PhysNet: A Neural Network for Predicting Energies, Forces, Dipole Moments
    and Partial Charges <https://arxiv.org/abs/1902.08408>`_:

    .. math::

        \phi_{n}(r) = \exp\left(-\beta_n \cdot(\exp(-r) - \mu_n)^2 \right)
        \quad n \in [1, n_\text{features}]

    where :math:`\beta_n` and :math:`\mu_n` are the (inverse) width and
    center of the :math:`n`'th expansion, respectively.

    Following `PhysNet <https://arxiv.org/abs/1902.08408>`_, the
    :math:`\mu_n` are initialised evenly spaced between
    :math:`\exp(-r_{\text{cut}})` and :math:`1`, and:

    .. math::

        \beta_n = \left( \frac{2}{n_\text{features}}
        \left(1 - \exp(-r_{\text{cut}})\right) \right)^{-2}

    .. code::

        import torch
        from graph_pes.models.components.distances import ExponentialRBF
        import matplotlib.pyplot as plt

        cutoff = 5.0
        rbf = ExponentialRBF(n_features=10, cutoff=cutoff)
        r = torch.linspace(0, cutoff, 101)  # (101,)

        with torch.no_grad():
            embedding = rbf(r)  # (101, 10)

        plt.plot(r / cutoff, embedding)
        plt.xlabel(r"$r / r_c$")

    .. image:: erbf.svg
        :align: center

    Parameters
    ----------
    n_features
        The number of features to expand into.
    cutoff
        The cutoff radius.
    trainable
        Whether the expansion parameters are trainable.

    Attributes
    ----------
    beta
        :math:`\beta_n`, the (inverse) widths of each basis.
    centers
        :math:`\mu_n`, the centers of each basis.
    """

    def __init__(self, n_features: int, cutoff: float, trainable: bool = True):
        super().__init__(n_features, cutoff, trainable)
        # c = exp(-r_cut): the smallest value exp(-r) takes within the cutoff
        c = torch.exp(-torch.tensor(cutoff))
        self.beta = torch.nn.Parameter(
            torch.ones(n_features) / (2 * (1 - c) / n_features) ** 2,
            requires_grad=trainable,
        )
        self.centers = torch.nn.Parameter(
            torch.linspace(c.item(), 1, n_features),
            requires_grad=trainable,
        )

    def expand(self, r: torch.Tensor) -> torch.Tensor:
        # (..., 1) - (n_features,) broadcasts to (..., n_features)
        offsets = torch.exp(-r) - self.centers
        return torch.exp(-self.beta * offsets**2)
class Envelope(torch.nn.Module):
    """
    Any envelope function, :math:`E(r)`, for smoothing potentials must
    implement a forward method that takes in a tensor of distances and returns
    a tensor of the same shape, where the values outside the cutoff are set to
    zero.
    """

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        """
        Perform the envelope function.

        Subclasses must override this method.

        Parameters
        ----------
        r : torch.Tensor
            The distances to envelope.

        Raises
        ------
        NotImplementedError
            Always, when called on the base class: an envelope has no sensible
            default, and silently returning ``None`` would only fail later,
            far from the cause.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement `forward`."
        )

    # re-declared only to give type checkers a precise signature
    def __call__(self, r: torch.Tensor) -> torch.Tensor:
        return super().__call__(r)
class PolynomialEnvelope(Envelope):
    r"""
    A smooth polynomial envelope function:

    .. math::

        E_p(r) = 1 - \frac{(p+1)(p+2)}{2}\cdot x^p + p(p+2) \cdot x^{p+1}
        - \frac{p(p+1)}{2}\cdot x^{p+2},
        \qquad x = \frac{r}{r_\text{cut}}

    for :math:`r \leq r_\text{cut}`, and :math:`E_p(r) = 0` otherwise, where
    :math:`r_\text{cut}` is the cutoff radius and :math:`p \in \mathbb{N}`.
    The coefficients are chosen such that the value and the first two
    derivatives of :math:`E_p` all vanish at :math:`r = r_\text{cut}`.

    Parameters
    ----------
    cutoff : float
        The cutoff radius.
    p: int
        The order of the envelope function.
    """

    def __init__(self, cutoff: float, p: int = 6):
        super().__init__()
        self.cutoff = cutoff
        self.p = p
        # coefficients of x^p, x^(p+1) and x^(p+2), respectively
        self.register_buffer(
            "coefficients",
            torch.tensor(
                [
                    -(p + 1) * (p + 2) / 2,
                    p * (p + 2),
                    -(p * (p + 1)) / 2,
                ]
            ),
        )
        self.register_buffer("powers", torch.arange(p, p + 3))

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        # (..., 1) ** (3,) -> (..., 3): x^p, x^(p+1), x^(p+2)
        powers = (r.unsqueeze(-1) / self.cutoff) ** self.powers
        envelope = 1 + (powers * self.coefficients).sum(dim=-1)
        # zeros_like keeps the fill value on the same device/dtype as the
        # envelope itself
        return torch.where(
            r <= self.cutoff, envelope, torch.zeros_like(envelope)
        )

    def __repr__(self):
        return f"PolynomialEnvelope(cutoff={self.cutoff}, p={self.p})"
class CosineEnvelope(Envelope):
    r"""
    A cosine envelope function.

    .. math::

        E_c(r) = \frac{1}{2}\left(1 + \cos\left(\frac{\pi r}{r_\text{cut}}
        \right)\right)

    for :math:`r \leq r_\text{cut}`, and :math:`E_c(r) = 0` otherwise, where
    :math:`r_\text{cut}` is the cutoff radius.

    Parameters
    ----------
    cutoff : float
        The cutoff radius.
    """

    def __init__(self, cutoff: float):
        super().__init__()
        self.cutoff = cutoff

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        cos = 0.5 * (1 + torch.cos(math.pi * r / self.cutoff))
        # The fill value must live on the same device (and have the same
        # dtype) as ``cos``: a bare ``torch.tensor(0.0)`` is always a float32
        # CPU tensor, unlike the sibling envelopes, which all match ``r``.
        return torch.where(r <= self.cutoff, cos, torch.zeros_like(cos))

    def __repr__(self):
        return f"CosineEnvelope(cutoff={self.cutoff})"
class SmoothOnsetEnvelope(Envelope):
    r"""
    A smooth cutoff function with an onset.

    .. math::

        f(r, r_o, r_c) = \begin{cases}
        \hfill 1 \hfill & \text{if } r < r_o \\
        \frac{(r_c - r)^2 (r_c + 2r - 3r_o)}{(r_c - r_o)^3} & \text{if } r_o \leq r < r_c \\
        \hfill 0 \hfill & \text{if } r \geq r_c
        \end{cases}

    where :math:`r_o` is the onset radius and :math:`r_c` is the cutoff radius.

    Parameters
    ----------
    cutoff : float
        The cutoff radius.
    onset : float
        The onset radius.

    Raises
    ------
    ValueError
        If ``onset`` is not strictly less than ``cutoff``.
    """  # noqa: E501

    def __init__(self, cutoff: float, onset: float):
        super().__init__()
        if onset >= cutoff:
            raise ValueError("Onset must be less than cutoff")
        self.register_buffer("cutoff", torch.tensor(cutoff))
        self.register_buffer("onset", torch.tensor(onset))

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        # the cubic switching polynomial, used only for onset <= r < cutoff;
        # the denominator is non-zero because __init__ enforces onset < cutoff
        switch = (
            (self.cutoff - r) ** 2
            * (self.cutoff + 2 * r - 3 * self.onset)
            / (self.cutoff - self.onset) ** 3
        )
        # ones_like / zeros_like keep the constant branches on r's device
        # and dtype
        return torch.where(
            r < self.onset,
            torch.ones_like(r),
            torch.where(r < self.cutoff, switch, torch.zeros_like(r)),
        )

    def __repr__(self):
        # the radii are tensor buffers: format their values rather than
        # printing e.g. "tensor(5.)"
        return (
            f"SmoothOnsetEnvelope("
            f"cutoff={to_significant_figures(self.cutoff.item(), 3)}, "
            f"onset={to_significant_figures(self.onset.item(), 3)})"
        )