e3nn Helpers

class graph_pes.models.e3nn.utils.LinearReadOut(input_irreps, output_irrep='0e')[source]

Map a set of features with arbitrary irreps to a single irrep with a single feature using a linear layer.

Parameters:
  • input_irreps (str or o3.Irreps) – The irreps of the input features.

  • output_irrep (str, optional) – The irrep of the output feature. Defaults to “0e”.

Examples

Map an embedding to a scalar output:

>>> LinearReadOut("16x0e+16x1o+16x2e")
LinearReadOut(16x0e+16x1o+16x2e -> 1x0e | 16 weights)

Map an embedding to a vector output:

>>> LinearReadOut("16x0e+16x1o+16x2e", "1o")
LinearReadOut(16x0e+16x1o+16x2e -> 1x1o | 16 weights)
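
Both examples above report 16 weights, which is what you would expect if the linear layer only connects input and output channels of the same irrep, with one weight per (input channel, output channel) pair. The following is a minimal pure-Python sketch of that counting rule, not the library's implementation. The helper names `parse_irreps` and `count_linear_weights` are hypothetical, and the sketch assumes no bias terms and irreps strings of the form shown in the examples.

```python
from __future__ import annotations


def parse_irreps(irreps: str) -> dict[str, int]:
    """Parse a string such as "16x0e+16x1o" into {"0e": 16, "1o": 16}."""
    counts: dict[str, int] = {}
    for term in irreps.replace(" ", "").split("+"):
        # A bare irrep such as "1o" is treated as multiplicity 1.
        mul, _, label = term.rpartition("x")
        counts[label] = counts.get(label, 0) + (int(mul) if mul else 1)
    return counts


def count_linear_weights(input_irreps: str, output_irreps: str) -> int:
    """Count weights, assuming only channels of the same irrep are connected."""
    inputs = parse_irreps(input_irreps)
    outputs = parse_irreps(output_irreps)
    return sum(
        mul_out * inputs.get(label, 0) for label, mul_out in outputs.items()
    )


embedding = "16x0e+16x1o+16x2e"
scalar_weights = count_linear_weights(embedding, "0e")
vector_weights = count_linear_weights(embedding, "1o")
```

Under this rule only the 16 channels of the matching irrep contribute to the output, so the 1o and 2e channels are ignored by a scalar readout, and the 0e and 2e channels are ignored by a vector readout.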

class graph_pes.models.e3nn.utils.NonLinearReadOut(input_irreps, output_irrep='0e', hidden_dim=None, activation=None)[source]

Non-linear readout layer for equivariant neural networks.

This class implements a non-linear readout layer that takes input features with arbitrary irreps and produces a single output feature of the requested irrep (a scalar by default). It applies a linear layer, then an activation function, then a second linear layer that produces the final output.

Parameters:
  • input_irreps (str) – The irreps of the input features.

  • output_irrep (str, optional) – The irrep of the output feature. Defaults to “0e”.

  • hidden_dim (int, optional) – The number of channels in the hidden layer. If None, it defaults to the number of scalar (0e) irreps in the input.

  • activation (str or torch.nn.Module, optional) – The activation function to use. Can be specified as a string (e.g., ‘ReLU’, ‘SiLU’) or as a torch.nn.Module. Defaults to SiLU for even output irreps and Tanh for odd output irreps. Care must be taken to ensure that the activation is suitable for the target irreps!

Examples

Map an embedding to a scalar output:

>>> NonLinearReadOut("16x0e+16x1o+16x2e")
NonLinearReadOut(
  (0): Linear(16x0e+16x1o+16x2e -> 16x0e | 256 weights)
  (1): SiLU()
  (2): Linear(16x0e -> 1x0e | 16 weights)
)

Map an embedding to a vector output:

>>> NonLinearReadOut("16x0e+16x1o+16x2e", "1o")
NonLinearReadOut(
  (0): Linear(16x0e+16x1o+16x2e -> 16x1o | 256 weights)
  (1): Tanh()
  (2): Linear(16x1o -> 1x1o | 16 weights)
)
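
The defaults described in the parameter list can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the library's code: the function name `readout_layout` is hypothetical, the `hidden_dim` default is read as the multiplicity of 0e in the input (consistent with the 16 hidden channels in both examples above), and the activation is represented by its name only.

```python
from __future__ import annotations


def readout_layout(
    input_irreps: str,
    output_irrep: str = "0e",
    hidden_dim: int | None = None,
    activation: str | None = None,
) -> tuple[str, str, str]:
    """Return (hidden irreps, activation name, output irreps) for a readout."""
    if hidden_dim is None:
        # Assumption: default to the number of scalar (0e) channels in the input.
        hidden_dim = 0
        for term in input_irreps.replace(" ", "").split("+"):
            mul, _, label = term.rpartition("x")
            if label == "0e":
                hidden_dim += int(mul) if mul else 1
    if activation is None:
        # Documented default: SiLU for even output irreps, Tanh for odd ones.
        activation = "SiLU" if output_irrep.endswith("e") else "Tanh"
    return f"{hidden_dim}x{output_irrep}", activation, f"1x{output_irrep}"


scalar_layout = readout_layout("16x0e+16x1o+16x2e")
vector_layout = readout_layout("16x0e+16x1o+16x2e", "1o")
```

The returned triples correspond to the hidden irreps, activation, and output irreps printed in the two examples above. Passing `hidden_dim` or `activation` explicitly bypasses the corresponding default.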