PyTorch Helpers

class graph_pes.utils.nn.PerElementParameter(data, requires_grad=True)[source]

Bases: Parameter

A subclass of torch.nn.Parameter that is indexed by atomic number(s). Crucially, this subclass overrides the numel() method so that only the relevant, learnable parameters are counted.

Examples

Imagine the case where you have a model parameter with a value for each element in the periodic table. If you only train the model on a dataset containing a few elements, you don't want the parameter count to include every element in the periodic table, as this total would be misleadingly large.

>>> # don't do this!
>>> per_element_parameter = torch.nn.Parameter(torch.randn(119))
>>> per_element_parameter.numel()
119
>>> per_element_parameter
Parameter containing:
tensor([ 1.2838e-01, -1.4507e+00,  1.3465e-01, -9.5786e-01, ...,
        -1.3329e+00, -1.5515e+00,  2.1106e+00, -9.7268e-01],
       requires_grad=True)
>>> # do this instead
>>> per_element_parameter = PerElementParameter.of_shape((1,))
>>> per_element_parameter.register_elements([1, 6, 8])
>>> per_element_parameter.numel()
3
>>> per_element_parameter
PerElementParameter({'O': -0.278, 'H': 0.157, 'C': -0.0379}, trainable=True)

graph-pes-train automatically registers all elements that a model encounters during training, so you rarely need to call register_elements() yourself.

Return type:

PerElementParameter

register_elements(Zs)[source]

Register the elements (by atomic number) that are relevant for this parameter.

This is typically only used internally - you shouldn't need to call it yourself in any of your model definitions.
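
In the rare case that you do need it, a minimal sketch (assuming that numel() counts one entry per registered element and per-element component):

>>> pep = PerElementParameter.of_shape((2,))
>>> pep.register_elements([1, 8])  # H and O
>>> pep.numel()
4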

classmethod of_shape(
shape=(),
index_dims=1,
default_value=None,
requires_grad=True,
)[source]

Create a PerElementParameter with a given shape for each element in the periodic table.

Parameters:
  • shape (tuple[int, ...]) – The shape of the parameter for each element.

  • index_dims (int) – The number of dimensions to index by.

  • default_value (float | None) – The value to initialise the parameter with. If None, the parameter is initialised with random values.

  • requires_grad (bool) – Whether the parameter should be learnable.

Returns:

The created parameter.

Return type:

PerElementParameter

Examples

Create a parameter intended to be indexed by a single atomic number, i.e. pep[Z]:

>>> PerElementParameter.of_shape((3,)).shape
torch.Size([119, 3])
>>> PerElementParameter.of_shape((3, 4)).shape
torch.Size([119, 3, 4])

Create a parameter intended to be indexed by two atomic numbers, i.e. pep[Z1, Z2]:

>>> PerElementParameter.of_shape((3,), index_dims=2).shape
torch.Size([119, 119, 3])
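
Such a parameter can then be indexed like any other tensor. A small sketch (the entry itself is randomly initialised here):

>>> pep = PerElementParameter.of_shape((3,), index_dims=2)
>>> pep[1, 8].shape  # the (H, O) entry
torch.Size([3])
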
classmethod from_dict(
requires_grad=True,
default_value=0.0,
**values,
)[source]

Create a PerElementParameter containing a single value for each element in the periodic table from a dictionary of values.

Parameters:
  • requires_grad (bool) – Whether the parameter should be learnable.

  • default_value (float) – The value used to initialise any element not present in values.

  • values (float) – The per-element values, passed as keyword arguments keyed by element symbol.

Returns:

The created parameter.

Return type:

PerElementParameter

Examples

>>> from graph_pes.utils.nn import PerElementParameter
>>> pep = PerElementParameter.from_dict(H=1.0, O=2.0)
>>> pep.register_elements([1, 6, 8])
>>> pep
PerElementParameter({'H': 1.0, 'C': 0.0, 'O': 2.0}, trainable=True)

classmethod of_length(
length,
index_dims=1,
default_value=None,
requires_grad=True,
)[source]

Alias for PerElementParameter.of_shape((length,), **kwargs).

Return type:

PerElementParameter
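
For example, following the alias above:

>>> PerElementParameter.of_length(5).shape
torch.Size([119, 5])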

classmethod covalent_radii(scaling_factor=1.0)[source]

Create a PerElementParameter containing the covalent radius of each element in the periodic table, multiplied by scaling_factor.

Return type:

PerElementParameter
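
For example (assuming, from the signature, that scaling_factor simply scales each radius):

>>> radii = PerElementParameter.covalent_radii(scaling_factor=1.1)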

numel() int[source]

See torch.numel(). This override counts only the parameters corresponding to registered elements.

Return type:

int

class graph_pes.utils.nn.PerElementEmbedding(dim)[source]

A per-element equivalent of torch.nn.Embedding.

Parameters:

dim (int) – The length of each embedding vector.

Examples

>>> embedding = PerElementEmbedding(10)
>>> len(graph["atomic_numbers"])  # number of atoms in the graph
24
>>> embedding(graph["atomic_numbers"])
<tensor of shape (24, 10)>
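
Since graph above is purely illustrative, here is a self-contained sketch with an explicit tensor of atomic numbers:

>>> import torch
>>> from graph_pes.utils.nn import PerElementEmbedding
>>> embedding = PerElementEmbedding(10)
>>> Z = torch.tensor([1, 1, 8])  # water: H, H, O
>>> embedding(Z).shape
torch.Size([3, 10])
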
class graph_pes.utils.nn.MLPConfig[source]

Bases: TypedDict

A TypedDict helper class for configuring an MLP.

Examples

Specify this in a config file:

mlp:
    hidden_depth: 3
    hidden_features: 64
    activation: SiLU

hidden_depth: int

The number of hidden layers in the MLP.

hidden_features: int

The number of features in the hidden layers.

activation: str

The activation function to use.

class graph_pes.utils.nn.MLP(layers, activation='CELU', activate_last=False, bias=True)[source]

A multi-layer perceptron model, alternating linear layers and activations.

Parameters:
  • layers (list[int]) – The number of nodes in each layer.

  • activation (str | torch.nn.Module) – The activation function to use: either a named activation function from torch.nn, or a torch.nn.Module instance.

  • activate_last (bool) – Whether to apply the activation function after the last linear layer.

  • bias (bool) – Whether to include bias terms in the linear layers.

Examples

>>> import torch
>>> from graph_pes.utils.nn import MLP
>>> model = MLP([10, 5, 1])
>>> model
MLP(10 → 5 → 1, activation=CELU())
>>> MLP([10, 5, 1], activation=torch.nn.ReLU())
MLP(10 → 5 → 1, activation=ReLU())
>>> MLP([10, 5, 1], activation="Tanh")
MLP(10 → 5 → 1, activation=Tanh())

forward(x)[source]

Perform a forward pass through the network.

Parameters:

x (Tensor) – The input to the network.

Return type:

Tensor
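
A minimal usage sketch (a standard batched forward pass, batch dimension first):

>>> import torch
>>> from graph_pes.utils.nn import MLP
>>> model = MLP([10, 5, 1])
>>> model(torch.rand(4, 10)).shape
torch.Size([4, 1])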

property input_size

The size of the input to the network.

property output_size

The size of the output of the network.

property layer_widths

The widths of the layers in the network.
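
For the model above, these properties (assuming they simply read off the layer specification) give:

>>> model.input_size
10
>>> model.output_size
1
>>> model.layer_widths
[10, 5, 1]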

classmethod from_config(config, input_features, output_features, bias=True)[source]

Create an MLP from an MLPConfig, with the given numbers of input and output features.

Return type:

MLP
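
A sketch of how this combines with MLPConfig (assuming, per the field documentation above, that hidden_depth counts hidden layers of width hidden_features):

>>> from graph_pes.utils.nn import MLP, MLPConfig
>>> config: MLPConfig = {
...     "hidden_depth": 2,
...     "hidden_features": 16,
...     "activation": "SiLU",
... }
>>> MLP.from_config(config, input_features=10, output_features=1)
MLP(10 → 16 → 16 → 1, activation=SiLU())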