# Source code for graph_pes.utils.nn

from __future__ import annotations

from functools import reduce
from typing import (
    Any,
    Generic,
    Iterable,
    Iterator,
    Sequence,
    TypedDict,
    TypeVar,
)

import torch
import torch.nn
from ase.data import atomic_numbers, chemical_symbols, covalent_radii
from torch import Tensor

from graph_pes.utils.misc import left_aligned_mul

from .misc import (
    MAX_Z,
    pairs,
    to_significant_figures,
    uniform_repr,
)

V = TypeVar("V", bound=torch.nn.Module)


class UniformModuleDict(torch.nn.ModuleDict, Generic[V]):
    """
    A :class:`torch.nn.ModuleDict` sub-class for cases where
    the values are all of the same type.

    Examples
    --------
    >>> from graph_pes.utils.nn import UniformModuleDict
    >>> from torch.nn import Linear
    >>> linear_dict = UniformModuleDict(a=Linear(10, 5), b=Linear(5, 1))
    """

    def __init__(self, **modules: V):
        super().__init__(modules)

    def values(self) -> Iterable[V]:
        return super().values()  # type: ignore

    def items(self) -> Iterable[tuple[str, V]]:
        return super().items()  # type: ignore

    def __getitem__(self, key: str) -> V:
        return super().__getitem__(key)  # type: ignore

    def __setitem__(self, key: str, module: V) -> None:  # type: ignore
        super().__setitem__(key, module)

    def pop(self, key: str) -> V:
        return super().pop(key)  # type: ignore


class UniformModuleList(torch.nn.ModuleList, Sequence[V]):
    """
    A :class:`torch.nn.ModuleList` sub-class for cases where
    the values are all of the same type.

    Examples
    --------
    >>> from graph_pes.utils.nn import UniformModuleList
    >>> from torch.nn import Linear
    >>> linear_list = UniformModuleList(Linear(10, 5), Linear(5, 1))
    """

    def __init__(self, modules: Iterable[V]):
        super().__init__(modules)

    def __getitem__(self, idx: int) -> V:  # type: ignore
        return super().__getitem__(idx)  # type: ignore

    def __setitem__(self, idx: int, value: V) -> None:  # type: ignore
        super().__setitem__(idx, value)

    def append(self, module: V) -> None:  # type: ignore
        super().append(module)

    def extend(self, modules: Iterable[V]) -> None:  # type: ignore
        super().extend(modules)

    def insert(self, idx: int, module: V) -> None:  # type: ignore
        super().insert(idx, module)

    def pop(self, idx: int) -> V:  # type: ignore
        return super().pop(idx)  # type: ignore

    def __iter__(self) -> Iterator[V]:
        return super().__iter__()  # type: ignore


class MLPConfig(TypedDict):
    """
    A TypedDict helper class for configuring an :class:`MLP`.

    Examples
    --------
    Specify this in a config file:

    .. code-block:: yaml

        mlp:
            hidden_depth: 3
            hidden_features: 64
            activation: SiLU
    """

    hidden_depth: int
    """The number of hidden layers in the MLP."""
    hidden_features: int
    """The number of features in the hidden layers."""
    activation: str
    """The activation function to use."""


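# Illustrative sketch: at runtime an :class:`MLPConfig` is just a plain dict,
# so it can be written inline as well as loaded from YAML, and consumed by
# :meth:`MLP.from_config` below (values are examples only):
#
#   >>> config: MLPConfig = {
#   ...     "hidden_depth": 3,
#   ...     "hidden_features": 64,
#   ...     "activation": "SiLU",
#   ... }

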
class MLP(torch.nn.Module):
    """
    A multi-layer perceptron model, alternating linear layers and
    activations.

    Parameters
    ----------
    layers
        The number of nodes in each layer.
    activation
        The activation function to use: either a named activation function
        from `torch.nn`, or a `torch.nn.Module` instance.
    activate_last
        Whether to apply the activation function after the last linear layer.
    bias
        Whether to include bias terms in the linear layers.

    Examples
    --------
    >>> import torch
    >>> from graph_pes.utils.nn import MLP
    >>> model = MLP([10, 5, 1])
    >>> model
    MLP(10 → 5 → 1, activation=CELU())
    >>> MLP([10, 5, 1], activation=torch.nn.ReLU())
    MLP(10 → 5 → 1, activation=ReLU())
    >>> MLP([10, 5, 1], activation="Tanh")
    MLP(10 → 5 → 1, activation=Tanh())
    """

    def __init__(
        self,
        layers: list[int],
        activation: str | torch.nn.Module = "CELU",
        activate_last: bool = False,
        bias: bool = True,
    ):
        super().__init__()

        self.activation = (
            parse_activation(activation)
            if isinstance(activation, str)
            else activation
        )
        self.activate_last = activate_last

        self.linear_layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(_in, _out, bias=bias)
                for _in, _out in pairs(layers)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Perform a forward pass through the network.

        Parameters
        ----------
        x
            The input to the network.
        """
        for i, linear in enumerate(self.linear_layers):
            x = linear(x)
            last_layer = i == len(self.linear_layers) - 1
            if not last_layer or self.activate_last:
                x = self.activation(x)

        return x

    @property
    def input_size(self):
        """The size of the input to the network."""
        return self.linear_layers[0].in_features

    @property
    def output_size(self):
        """The size of the output of the network."""
        return self.linear_layers[-1].out_features

    @property
    def layer_widths(self):
        """The widths of the layers in the network."""
        inputs = [layer.in_features for layer in self.linear_layers]
        return inputs + [self.output_size]

    def __repr__(self):
        layers = " → ".join(map(str, self.layer_widths))
        return uniform_repr(
            self.__class__.__name__,
            layers,
            activation=self.activation,
            stringify=False,
        )

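    # Illustrative sketch of the shape bookkeeping above (example sizes only):
    #
    #   >>> mlp = MLP([10, 5, 1])
    #   >>> mlp.input_size, mlp.output_size, mlp.layer_widths
    #   (10, 1, [10, 5, 1])
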
    @classmethod
    def from_config(
        cls,
        config: MLPConfig,
        input_features: int,
        output_features: int,
        bias: bool = True,
    ) -> MLP:
        """
        Create an :class:`MLP` from a configuration.
        """
        return cls(
            layers=[input_features]
            + [config["hidden_features"]] * config["hidden_depth"]
            + [output_features],
            activation=config["activation"],
            bias=bias,
        )


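# Illustrative sketch combining :class:`MLPConfig` and :meth:`MLP.from_config`
# (sizes and activation are examples only):
#
#   >>> config: MLPConfig = {
#   ...     "hidden_depth": 2,
#   ...     "hidden_features": 32,
#   ...     "activation": "Tanh",
#   ... }
#   >>> MLP.from_config(config, input_features=16, output_features=1)
#   MLP(16 → 32 → 32 → 1, activation=Tanh())

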
class ShiftedSoftplus(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.softplus(x) - self.shift

    def __repr__(self):
        return uniform_repr(self.__class__.__name__)


def parse_activation(act: str) -> torch.nn.Module:
    """
    Parse a string into a PyTorch activation function.

    Parameters
    ----------
    act
        The activation function to parse.

    Returns
    -------
    torch.nn.Module
        The parsed activation function.
    """
    activation = getattr(torch.nn, act, None)
    if activation is None:
        raise ValueError(f"Activation function {act} not found in `torch.nn`.")
    return activation()


def prod(iterable):
    return reduce(lambda x, y: x * y, iterable, 1)


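# Illustrative sketch: ``softplus(0) = log(2)``, so the shifted variant passes
# (approximately) through the origin, and ``parse_activation`` simply
# instantiates the named ``torch.nn`` class:
#
#   >>> ShiftedSoftplus()(torch.zeros(1)).abs() < 1e-6
#   tensor([True])
#   >>> parse_activation("SiLU")
#   SiLU()

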
class PerElementParameter(torch.nn.Parameter):
    """
    A subclass of :class:`torch.nn.Parameter` that is indexed by atomic
    number/s. Crucially, this subclass overrides the :meth:`numel` method,
    for accurately counting the number of relevant and learnable parameters.

    Examples
    --------
    Imagine the case where you have a model parameter with a value for each
    element in the periodic table. If you only train the model on a dataset
    containing a few elements, you don't want to count the total number of
    parameters, as this will be unnecessarily large.

    >>> # don't do this!
    >>> per_element_parameter = torch.nn.Parameter(torch.randn(119))
    >>> per_element_parameter.numel()
    119
    >>> per_element_parameter
    Parameter containing:
    tensor([ 1.2838e-01, -1.4507e+00,  1.3465e-01, -9.5786e-01,  ...,
            -1.3329e+00, -1.5515e+00,  2.1106e+00, -9.7268e-01],
           requires_grad=True)

    >>> # do this instead
    >>> per_element_parameter = PerElementParameter.of_shape((1,))
    >>> per_element_parameter.register_elements([1, 6, 8])
    >>> per_element_parameter.numel()
    3
    >>> per_element_parameter
    PerElementParameter({'O': -0.278, 'H': 0.157, 'C': -0.0379}, trainable=True)

    ``graph-pes-train`` automatically registers all elements that a model
    encounters during training, so you rarely need to call
    :meth:`register_elements` yourself.
    """

    def __new__(
        cls, data: Tensor, requires_grad: bool = True
    ) -> PerElementParameter:
        pep = super().__new__(cls, data, requires_grad=requires_grad)
        pep._is_per_element_param = True  # type: ignore
        return pep  # type: ignore

    def __init__(self, data: Tensor, requires_grad: bool = True):
        super().__init__()
        # set extra state
        self._accessed_Zs: set[int] = set()
        # set this to an arbitrary value: this gets updated post-init
        self._index_dims: int = 1

    def register_elements(self, Zs: Iterable[int]) -> None:
        """
        Register the elements that are relevant for the parameter.

        This is typically only used internally - you shouldn't call this
        yourself in any of your model definitions.
        """
        self._accessed_Zs.update(sorted(Zs))

    @classmethod
    def of_shape(
        cls,
        shape: tuple[int, ...] = (),
        index_dims: int = 1,
        default_value: float | None = None,
        requires_grad: bool = True,
    ) -> PerElementParameter:
        """
        Create a :class:`PerElementParameter` with a given shape for each
        element in the periodic table.

        Parameters
        ----------
        shape
            The shape of the parameter for each element.
        index_dims
            The number of dimensions to index by.
        default_value
            The value to initialise the parameter with. If ``None``, the
            parameter is initialised with random values.
        requires_grad
            Whether the parameter should be learnable.

        Returns
        -------
        PerElementParameter
            The created parameter.

        Examples
        --------
        Create a parameter intended to be indexed by a single atomic number,
        i.e. ``pep[Z]``:

        >>> PerElementParameter.of_shape((3,)).shape
        torch.Size([119, 3])
        >>> PerElementParameter.of_shape((3, 4)).shape
        torch.Size([119, 3, 4])

        Create a parameter intended to be indexed by two atomic numbers,
        i.e. ``pep[Z1, Z2]``:

        >>> PerElementParameter.of_shape((3,), index_dims=2).shape
        torch.Size([119, 119, 3])
        """
        actual_shape = tuple([MAX_Z + 1] * index_dims) + shape
        if default_value is not None:
            data = torch.full(actual_shape, float(default_value))
        else:
            data = torch.randn(actual_shape)
        psp = PerElementParameter(data, requires_grad=requires_grad)
        psp._index_dims = index_dims
        return psp

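    # Illustrative sketch (example shapes only): indexing with a tensor of
    # atomic numbers gathers one entry per atom, so a ``(119, 16)`` parameter
    # indexed by 5 atoms yields a ``(5, 16)`` tensor.
    #
    #   >>> pep = PerElementParameter.of_shape((16,))
    #   >>> Z = torch.tensor([1, 1, 8, 6, 6])
    #   >>> pep[Z].shape
    #   torch.Size([5, 16])
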
    @classmethod
    @torch.no_grad()
    def from_dict(
        cls,
        requires_grad: bool = True,
        default_value: float = 0.0,
        **values: float,
    ) -> PerElementParameter:
        """
        Create a :class:`PerElementParameter` containing a single value for
        each element in the periodic table from a dictionary of values.

        Parameters
        ----------
        requires_grad
            Whether the parameter should be learnable.
        default_value
            The value to use for elements not present in ``values``.
        values
            A dictionary of values, indexed by element symbol.

        Returns
        -------
        PerElementParameter
            The created parameter.

        Examples
        --------
        >>> from graph_pes.utils.nn import PerElementParameter
        >>> pep = PerElementParameter.from_dict(H=1.0, O=2.0)
        >>> pep.register_elements([1, 6, 8])
        >>> pep
        PerElementParameter({'H': 1.0, 'C': 0.0, 'O': 2.0}, trainable=True)
        """
        pep = PerElementParameter.of_length(
            1, requires_grad=requires_grad, default_value=default_value
        )
        for element_symbol, value in values.items():
            if element_symbol not in chemical_symbols:
                raise ValueError(f"Unknown element: {element_symbol}")
            Z = chemical_symbols.index(element_symbol)
            pep[Z] = value

        pep.register_elements(atomic_numbers[v] for v in values)
        return pep

    @classmethod
    def of_length(
        cls,
        length: int,
        index_dims: int = 1,
        default_value: float | None = None,
        requires_grad: bool = True,
    ) -> PerElementParameter:
        """
        Alias for ``PerElementParameter.of_shape((length,), **kwargs)``.
        """
        return PerElementParameter.of_shape(
            (length,), index_dims, default_value, requires_grad
        )

    @classmethod
    @torch.no_grad()
    def covalent_radii(
        cls,
        scaling_factor: float = 1.0,
    ) -> PerElementParameter:
        """
        Create a :class:`PerElementParameter` containing the covalent radius
        of each element in the periodic table.
        """
        pep = PerElementParameter.of_length(1, default_value=1.0)
        for Z in range(1, MAX_Z + 1):
            pep[Z] = torch.tensor(covalent_radii[Z]) * scaling_factor
        return pep

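    # Illustrative sketch: the values are taken from ``ase.data.covalent_radii``
    # (in Å), e.g. carbon is expected to be roughly 0.76 Å:
    #
    #   >>> radii = PerElementParameter.covalent_radii()
    #   >>> round(float(radii[6]), 2)
    #   0.76
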
    def numel(self) -> int:
        n_elements = len(self._accessed_Zs)
        accessed_parameters = n_elements**self._index_dims
        per_element_size = prod(self.shape[self._index_dims :])
        return accessed_parameters * per_element_size

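    # Worked example of the count above: 3 registered elements indexed by
    # pairs (``index_dims=2``) give 3**2 = 9 accessed index combinations,
    # each holding 4 values, so 36 parameters in total.
    #
    #   >>> pep = PerElementParameter.of_shape((4,), index_dims=2)
    #   >>> pep.register_elements([1, 6, 8])
    #   >>> pep.numel()
    #   36
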
    # needed for de/serialization
    def __reduce_ex__(self, proto):
        return (
            _rebuild_per_element_parameter,
            (self.data, self.requires_grad, torch._utils._get_obj_state(self)),
        )

    def __instancecheck__(self, instance) -> bool:
        return super().__instancecheck__(instance) or (  # type: ignore[no-untyped-call]
            isinstance(instance, torch.Tensor)
            and getattr(instance, "_is_per_element_param", False)
        )

    @torch.no_grad()
    def _repr(
        self,
        alias: str | None = None,
        more_info: dict[str, Any] | None = None,
    ) -> str:
        alias = alias or self.__class__.__name__
        more_info = more_info or {}
        if "trainable" not in more_info:
            more_info["trainable"] = self.requires_grad

        if len(self._accessed_Zs) == 0:
            if self._index_dims == 1 and self.shape[1] == 1:
                return uniform_repr(alias, **more_info)
            return uniform_repr(
                alias,
                index_dims=self._index_dims,
                shape=tuple(self.shape[self._index_dims :]),
                **more_info,
            )

        if self._index_dims == 1:
            if self.shape[1] == 1:
                d = {
                    chemical_symbols[Z]: to_significant_figures(self[Z].item())
                    for Z in self._accessed_Zs
                }
                string = f"{alias}({str(d)}, "
                for k, v in more_info.items():
                    string += f"{k}={v}, "
                return string[:-2] + ")"

            elif len(self.shape) == 2:
                d = {
                    chemical_symbols[Z]: self[Z].tolist()
                    for Z in self._accessed_Zs
                }
                string = f"{alias}({str(d)}, "
                for k, v in more_info.items():
                    string += f"{k}={v}, "
                return string[:-2] + ")"

        if self._index_dims == 2 and self.shape[2] == 1:
            columns = []
            columns.append(
                ["Z"] + [chemical_symbols[Z] for Z in self._accessed_Zs]
            )
            for col_Z in self._accessed_Zs:
                row: list[str | float] = [chemical_symbols[col_Z]]
                for row_Z in self._accessed_Zs:
                    row.append(
                        to_significant_figures(self[col_Z, row_Z].item())
                    )
                columns.append(row)
            widths = [max(len(str(x)) for x in col) for col in zip(*columns)]
            lines = []
            for row in columns:
                line = ""
                for x, w in zip(row, widths):
                    # right align
                    line += f"{x:>{w}} "
                lines.append(line)
            table = "\n" + "\n".join(lines)
            return uniform_repr(
                alias,
                table,
                **more_info,
            )

        return uniform_repr(
            alias,
            index_dims=self._index_dims,
            accessed_Zs=sorted(self._accessed_Zs),
            shape=tuple(self.shape[self._index_dims :]),
            **more_info,
        )

    def __repr__(self, *, tensor_contents=None):
        return self._repr()


def _rebuild_per_element_parameter(data, requires_grad, state):
    psp = PerElementParameter(data, requires_grad)
    psp._accessed_Zs = state["_accessed_Zs"]
    psp._index_dims = state["_index_dims"]
    return psp


class PerElementEmbedding(torch.nn.Module):
    """
    A per-element equivalent of :class:`torch.nn.Embedding`.

    Parameters
    ----------
    dim
        The length of each embedding vector.

    Examples
    --------
    >>> embedding = PerElementEmbedding(10)
    >>> len(graph["atomic_numbers"])  # number of atoms in the graph
    24
    >>> embedding(graph["atomic_numbers"])
    <tensor of shape (24, 10)>
    """

    def __init__(self, dim: int):
        super().__init__()
        self._embeddings = PerElementParameter.of_length(dim)

    def forward(self, Z: Tensor) -> Tensor:
        return self._embeddings[Z]

    def dim(self) -> int:
        return self._embeddings.shape[1]

    def __repr__(self) -> str:
        Zs = sorted(self._embeddings._accessed_Zs)
        return uniform_repr(
            self.__class__.__name__,
            dim=self._embeddings.shape[1],
            elements=[chemical_symbols[Z] for Z in Zs],
        )

    def __call__(self, Z: Tensor) -> Tensor:
        return super().__call__(Z)


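# Illustrative sketch with an explicit atomic-number tensor (the ``graph``
# object in the docstring above is assumed, not defined here):
#
#   >>> embedding = PerElementEmbedding(10)
#   >>> Z = torch.tensor([1, 1, 8])  # e.g. a water molecule
#   >>> embedding(Z).shape
#   torch.Size([3, 10])

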
class HaddamardProduct(torch.nn.Module):
    def __init__(
        self, *components: torch.nn.Module, left_aligned: bool = False
    ):
        super().__init__()
        self.components: list[torch.nn.Module] = torch.nn.ModuleList(components)  # type: ignore
        self.left_aligned = left_aligned

    def forward(self, x):
        out = torch.scalar_tensor(1)
        for component in self.components:
            if self.left_aligned:
                out = left_aligned_mul(out, component(x))
            else:
                out = out * component(x)
        return out


def learnable_parameters(module: torch.nn.Module) -> int:
    """Count the number of **learnable** parameters a module has."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


class AtomicOneHot(torch.nn.Module):
    """
    Takes a tensor of atomic numbers Z, and returns a one-hot encoding
    of the atomic numbers.

    Parameters
    ----------
    elements
        The chemical symbols of the expected elements.
    """

    def __init__(self, elements: list[str]):
        super().__init__()

        self.elements = elements
        self.n_elements = len(elements)

        self.Z_to_idx: Tensor
        self.register_buffer(
            "Z_to_idx",
            # deliberately crazy value to catch errors
            torch.full((MAX_Z + 1,), fill_value=1234),
        )
        for i, symbol in enumerate(elements):
            Z = atomic_numbers[symbol]
            self.Z_to_idx[Z] = i

    def forward(self, Z: Tensor) -> Tensor:
        internal_idx = self.Z_to_idx[Z]
        with torch.no_grad():
            if (internal_idx == 1234).any():
                unknown_Z = torch.unique(Z[internal_idx == 1234])
                raise ValueError(
                    f"Unknown elements: {unknown_Z}. "
                    f"Expected one of {self.elements}"
                )

        return torch.nn.functional.one_hot(
            internal_idx, self.n_elements
        ).float()

    def __repr__(self):
        return uniform_repr(
            self.__class__.__name__,
            elements=self.elements,
        )
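

# Illustrative usage sketches for the helpers above (example values only):
#
#   >>> one_hot = AtomicOneHot(["H", "C", "O"])
#   >>> one_hot(torch.tensor([1, 8]))
#   tensor([[1., 0., 0.],
#           [0., 0., 1.]])
#
#   >>> learnable_parameters(MLP([10, 5, 1]))  # 10*5 + 5 + 5*1 + 1
#   61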