TensorNet

class graph_pes.models.TensorNet(
cutoff=5.0,
radial_features=32,
radial_expansion=<class 'graph_pes.models.components.distances.ExponentialRBF'>,
channels=32,
layers=2,
direct_force_predictions=False,
)

Bases: GraphPESModel

The TensorNet architecture.

Citation:

@misc{Simeon-23-06,
    title = {
        TensorNet: Cartesian Tensor Representations for
        Efficient Learning of Molecular Potentials
    },
    author = {Simeon, Guillem and {de Fabritiis}, Gianni},
    year = {2023},
    number = {arXiv:2306.06482},
}

Parameters:
  • cutoff (float) – The cutoff radius to use for the model.

  • radial_features (int) – The number of radial features to use for the model.

  • radial_expansion (str | type[DistanceExpansion]) – The type of radial basis function to use for the model. For more examples, see DistanceExpansion.

  • channels (int) – The size of the embedding for each atom.

  • layers (int) – The number of interaction layers to use for the model.

  • direct_force_predictions (bool) – Whether to predict forces directly. If True, the model generates force predictions by passing the final layer’s node embeddings through a VectorOutput read-out. Otherwise, graph-pes automatically infers the forces as the negative gradient of the energy with respect to the atomic positions.
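
To illustrate what "forces as a derivative of the energy" means when direct_force_predictions is False, here is a minimal, library-free sketch. It is not how graph-pes computes forces (that relies on automatic differentiation of the model's energy); the toy Lennard-Jones energy and the helper names pair_energy and numerical_forces are assumptions made purely for illustration.

```python
import math


def pair_energy(positions):
    """Toy Lennard-Jones energy (epsilon = sigma = 1), summed over all pairs."""
    energy = 0.0
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            r = math.dist(positions[i], positions[j])
            energy += 4.0 * (r ** -12 - r ** -6)
    return energy


def numerical_forces(positions, h=1e-6):
    """Central-difference estimate of F = -dE/dr for every atom and axis."""
    forces = []
    for i in range(len(positions)):
        force = []
        for axis in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][axis] += h
            minus[i][axis] -= h
            force.append(-(pair_energy(plus) - pair_energy(minus)) / (2 * h))
        forces.append(force)
    return forces


# Two atoms separated by 1.5 along x: beyond the LJ minimum, so they attract.
positions = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
forces = numerical_forces(positions)
```

Because these forces are derived from a single energy function, they are conservative and sum to zero over the system. Directly predicted forces do not come with that guarantee, which is the trade-off to weigh when setting direct_force_predictions to True.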

Examples

Configure a TensorNet model for use with graph-pes-train:

model:
  +TensorNet:
    radial_features: 8
    radial_expansion: Bessel
    channels: 32
    cutoff: 5.0
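
As a rough picture of what a radial expansion does with the radial_features and cutoff values above, the sketch below expands one interatomic distance into 8 Bessel-type features. It uses the common form sqrt(2 / r_c) · sin(nπr / r_c) / r popularised by DimeNet; whether the Bessel expansion in graph-pes matches this exactly (for example in its frequencies or any envelope function) is an assumption, and bessel_expansion is a hypothetical helper, not part of the library.

```python
import math


def bessel_expansion(r, n_features=8, cutoff=5.0):
    """Expand a distance r (0 < r <= cutoff) into n_features radial values."""
    prefactor = math.sqrt(2.0 / cutoff)
    return [
        prefactor * math.sin(n * math.pi * r / cutoff) / r
        for n in range(1, n_features + 1)
    ]


# One distance becomes a vector of 8 features, one per basis function.
features = bessel_expansion(2.5)

# Every basis function in this form vanishes at the cutoff radius.
features_at_cutoff = bessel_expansion(5.0)
```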

Components

Below, we use the notation of the TensorNet paper.

class graph_pes.models.tensornet.ScalarOutput(channels)

A non-linear read-out function:

X with shape (N, C, 3, 3) is decomposed into its isotropic (I), antisymmetric (A), and symmetric traceless (S) components. The Frobenius norms of these components are concatenated and passed through an MLP to generate a scalar.
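
The decomposition can be sketched for a single 3×3 matrix in plain Python. This is an illustration of the I, A, S split and the resulting norms, not the library's batched implementation; decompose and frobenius_norm are hypothetical helper names, and the MLP step is omitted.

```python
import math


def decompose(X):
    """Split a 3x3 matrix into isotropic (I), antisymmetric (A) and
    symmetric traceless (S) parts, such that X = I + A + S."""
    trace = X[0][0] + X[1][1] + X[2][2]
    I = [[trace / 3.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
    A = [[0.5 * (X[i][j] - X[j][i]) for j in range(3)] for i in range(3)]
    S = [[0.5 * (X[i][j] + X[j][i]) - I[i][j] for j in range(3)] for i in range(3)]
    return I, A, S


def frobenius_norm(M):
    """Square root of the sum of squared matrix entries."""
    return math.sqrt(sum(v * v for row in M for v in row))


X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]]
I, A, S = decompose(X)

# Three rotation-invariant numbers for this one matrix.
invariants = [frobenius_norm(part) for part in (I, A, S)]
```

Applied to the full (N, C, 3, 3) tensor, the same procedure yields three invariants per atom and channel. Since each norm is unchanged when X is rotated, a scalar computed from them is rotation-invariant as well.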

class graph_pes.models.tensornet.VectorOutput(channels)

A non-linear read-out function:

The antisymmetric component A of X, with shape (N, C, 3, 3), is passed through a linear layer, before extracting the x, y, and z components of the resulting vector.
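
An antisymmetric 3×3 matrix has only three independent entries, which is why a vector can be read off from it. The sketch below shows one such mapping for a single matrix. The sign and index convention used here is an assumption and may differ from the one in graph-pes; the linear layer is omitted, and both function names are hypothetical.

```python
def vector_to_antisymmetric(v):
    """Build the antisymmetric matrix associated with a vector (x, y, z)."""
    x, y, z = v
    return [
        [0.0, z, -y],
        [-z, 0.0, x],
        [y, -x, 0.0],
    ]


def antisymmetric_to_vector(A):
    """Read the (x, y, z) components back out of an antisymmetric matrix."""
    return [A[1][2], A[2][0], A[0][1]]


v = [1.0, -2.0, 0.5]
A = vector_to_antisymmetric(v)
recovered = antisymmetric_to_vector(A)
```

Under this convention the round trip is exact, so the three off-diagonal entries carry all the information in A.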