from __future__ import annotations
import torch
from torch import Tensor, nn
from graph_pes.atomic_graph import (
DEFAULT_CUTOFF,
AtomicGraph,
PropertyKey,
index_over_neighbours,
neighbour_distances,
neighbour_vectors,
number_of_edges,
sum_over_neighbours,
)
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.models.components.scaling import LocalEnergiesScaler
from graph_pes.utils.nn import (
MLP,
HaddamardProduct,
PerElementEmbedding,
UniformModuleList,
)
from .components.distances import (
CosineEnvelope,
DistanceExpansion,
ExponentialRBF,
get_distance_expansion,
)
class TensorNet(GraphPESModel):
r"""
The `TensorNet <http://arxiv.org/abs/2306.06482>`_ architecture.
Citation:
.. code:: bibtex
@misc{Simeon-23-06,
title = {
TensorNet: Cartesian Tensor Representations for
Efficient Learning of Molecular Potentials
},
author = {Simeon, Guillem and {de Fabritiis}, Gianni},
year = {2023},
number = {arXiv:2306.06482},
}
Parameters
----------
cutoff
The cutoff radius to use for the model.
radial_features
The number of radial features to use for the model.
radial_expansion
The type of radial basis function to use for the model.
For more examples, see
:class:`~graph_pes.models.components.distances.DistanceExpansion`.
channels
The size of the embedding for each atom.
layers
The number of interaction layers to use for the model.
direct_force_predictions
Whether to predict forces directly. If ``True``, the model will
generate force predictions by passing the final
layer's node embeddings through a
:class:`~graph_pes.models.tensornet.VectorOutput` read out.
        Otherwise, ``graph-pes`` automatically infers the forces as the
        negative derivative of the energy with respect to the atomic
        positions.
Examples
--------
Configure a TensorNet model for use with ``graph-pes-train``:
.. code:: yaml
model:
+TensorNet:
radial_features: 8
radial_expansion: Bessel
channels: 32
cutoff: 5.0
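
    or, equivalently, in Python (a minimal sketch: ``AtomicGraph.from_ase``
    and ``predict_energy`` are assumed to be available in your installed
    version of ``graph-pes``):

    .. code:: python

        from ase.build import molecule
        from graph_pes.atomic_graph import AtomicGraph
        from graph_pes.models import TensorNet

        model = TensorNet(cutoff=5.0, channels=32)
        graph = AtomicGraph.from_ase(molecule("H2O"), cutoff=5.0)
        energy = model.predict_energy(graph)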
"""
def __init__(
self,
cutoff: float = DEFAULT_CUTOFF,
radial_features: int = 32,
radial_expansion: str | type[DistanceExpansion] = ExponentialRBF,
channels: int = 32,
layers: int = 2,
direct_force_predictions: bool = False,
):
properties: list[PropertyKey] = ["local_energies"]
if direct_force_predictions:
properties.append("forces")
super().__init__(
cutoff=cutoff,
implemented_properties=properties,
)
self.embedding = Embedding(
radial_features, radial_expansion, channels, cutoff
)
self.interactions = UniformModuleList(
Interaction(radial_features, channels, cutoff)
for _ in range(layers)
)
self.energy_read_out = ScalarOutput(channels)
if direct_force_predictions:
self.force_read_out = VectorOutput(channels)
else:
self.force_read_out = None
self.scaler = LocalEnergiesScaler()
def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
X = self.embedding(graph) # (N, C, 3, 3)
for interaction in self.interactions:
# normalise -> interaction -> residual connection
X = X / (frobenius_norm(X)[..., None, None] + 1)
dX = interaction(X, graph)
X = X + dX
        # squeeze only the feature dimension, so that single-atom
        # graphs keep their atom dimension
        local_energies = self.energy_read_out(X).squeeze(-1)
local_energies = self.scaler(local_energies, graph)
results: dict[PropertyKey, torch.Tensor] = {
"local_energies": local_energies
}
if self.force_read_out is not None:
results["forces"] = self.force_read_out(X)
return results
### components ###
class EdgeEmbedding(nn.Module):
r"""
Generates embeddings for each (directed) edge in the graph, incorporating
the species of each atom, and the vector between them.
1. generate initial edge embedding components:
* :math:`I_0^{(ij)} = \text{Id}`
* :math:`A_0^{(ij)} = \begin{pmatrix}
0 & \hat{r}_{ij}^z & - \hat{r}_{ij}^y \\
- \hat{r}_{ij}^z & 0 & \hat{r}_{ij}^x \\
\hat{r}_{ij}^y & - \hat{r}_{ij}^x & 0
\end{pmatrix}`
    * :math:`S_0^{(ij)} = \hat{r}_{ij} \cdot \hat{r}_{ij}^T -
      \frac{1}{3} \text{Tr}(\hat{r}_{ij} \cdot \hat{r}_{ij}^T) \, \text{Id}`
2. generate an embedding of the ordered pair of species:
.. math::
h_Z^{(ij)} = f(z_i, z_j) = \text{Linear}(\text{embed}(z_i)
|| \text{embed}(z_j))
    3. expand the edge distances in an exponential radial basis:

    .. math::
        h_{R,n}^{(ij)} = \exp(-\beta_n \cdot (\exp(- r_{ij}) - \mu_n)^2)
4. combine all edge embeddings:
.. math::
        X^{ij} = \phi(r_{ij}) \cdot h_Z^{(ij)} \cdot \left (
        \text{Dense}_I(h_R^{(ij)}) \cdot I_0^{(ij)} +
        \text{Dense}_A(h_R^{(ij)}) \cdot A_0^{(ij)} +
        \text{Dense}_S(h_R^{(ij)}) \cdot S_0^{(ij)}
        \right)
where :math:`\phi(r_{ij})` is the cosine envelope function.
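
    For example, the three initial edge tensors for a single unit vector
    can be built with the helper functions defined in this module:

    .. code:: python

        import torch

        r_hat = torch.tensor([0.0, 0.0, 1.0])  # unit vector along z
        I_0 = torch.eye(3)                                 # identity
        A_0 = vector_to_skew_symmetric_matrix(r_hat)       # antisymmetric
        S_0 = vector_to_symmetric_traceless_matrix(r_hat)  # symmetric, traceless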
"""
def __init__(
self,
radial_features: int,
radial_expansion: str | type[DistanceExpansion],
channels: int,
cutoff: float,
):
super().__init__()
self.z_embedding = PerElementEmbedding(channels)
self.z_map = nn.Linear(2 * channels, channels, bias=False)
expansion_klass = get_distance_expansion(radial_expansion)
self.distance_embedding = HaddamardProduct(
nn.Sequential(
expansion_klass(radial_features, cutoff),
nn.Linear(radial_features, 3 * channels),
),
CosineEnvelope(cutoff),
left_aligned=True,
)
def forward(self, graph: AtomicGraph) -> tuple[Tensor, Tensor, Tensor]:
C = self.z_embedding.dim()
E = number_of_edges(graph)
# 1. generate initial edge embedding components:
I_0, A_0, S_0 = self._initial_edge_embeddings(graph) # (E, 1, 3, 3)
# 2. encode atomic species of ordered neighbour pairs:
h_z_atom = self.z_embedding(graph.Z) # (N, C)
h_z_edge = h_z_atom[graph.neighbour_list] # (2, E, C)
h_z_edge = h_z_edge.permute(1, 0, 2).reshape(E, 2 * C)
h_z_edge = self.z_map(h_z_edge) # (E, C)
# 3. embed edge distances
h_r = self.distance_embedding(neighbour_distances(graph)) # (E, 3C)
# 4. combine information into coefficients
c = (h_r * h_z_edge.repeat(1, 3))[..., None, None] # (E, 3C, 1, 1)
c_I, c_A, c_S = torch.chunk(c, 3, dim=1) # (E, C, 1, 1)
return c_I * I_0, c_A * A_0, c_S * S_0 # 3x (E, C, 3, 3)
def _initial_edge_embeddings(
self, graph: AtomicGraph
) -> tuple[Tensor, Tensor, Tensor]:
E = number_of_edges(graph)
r_hat = neighbour_vectors(graph) / neighbour_distances(graph)[..., None]
eye = torch.eye(3, device=graph.Z.device)
I_ij = torch.repeat_interleave(eye[None, ...], E, dim=0) # (E, 3, 3)
A_ij = vector_to_skew_symmetric_matrix(r_hat)
S_ij = vector_to_symmetric_traceless_matrix(r_hat)
return (
I_ij.view(E, 1, 3, 3),
A_ij.view(E, 1, 3, 3),
S_ij.view(E, 1, 3, 3),
)
class Embedding(nn.Module):
"""
Embed the local environment of each atom into a ``(C, 3, 3)`` tensor.
"""
def __init__(
self,
radial_features: int,
radial_expansion: str | type[DistanceExpansion],
channels: int,
cutoff: float,
):
super().__init__()
self.edge_embedding = EdgeEmbedding(
radial_features, radial_expansion, channels, cutoff
)
self.layer_norm = nn.LayerNorm(channels)
self.mlp = MLP(
layers=[
channels,
2 * channels,
3 * channels,
],
activation=nn.SiLU(),
activate_last=True,
)
self.W_I = TensorLinear(channels, channels)
self.W_A = TensorLinear(channels, channels)
self.W_S = TensorLinear(channels, channels)
def forward(self, graph: AtomicGraph) -> Tensor:
# embed edges
I_edge, A_edge, S_edge = self.edge_embedding(graph) # (E, C, 3, 3)
# sum over neighbours to get atom embeddings
I_atom = sum_over_neighbours(I_edge, graph) # (N, C, 3, 3)
A_atom = sum_over_neighbours(A_edge, graph) # (N, C, 3, 3)
S_atom = sum_over_neighbours(S_edge, graph) # (N, C, 3, 3)
# generate coefficients from tensor representations
# (mixes irreps)
norms = frobenius_norm(I_atom + A_atom + S_atom) # (N, C)
norms = self.layer_norm(norms)
coefficients = self.mlp(norms)[..., None, None] # (N, 3C, 1, 1)
c_I, c_A, c_S = torch.chunk(coefficients, 3, dim=1) # (N, C, 1, 1)
# ...and combine with mixed coefficients with linear
# mixing of features
return (
c_I * self.W_I(I_atom)
+ c_A * self.W_A(A_atom)
+ c_S * self.W_S(S_atom)
) # (N, C, 3, 3)
class Interaction(nn.Module):
    """
    A single TensorNet interaction layer: per-atom tensor embeddings ``X``
    of shape ``(N, C, 3, 3)`` are decomposed into irreducible components,
    mixed with distance-weighted messages from neighbouring atoms, and
    recombined into updated embeddings of the same shape.
    """
def __init__(
self,
radial_features: int,
channels: int,
cutoff: float,
):
super().__init__()
# unfortunately, we need to be explicit to satisfy torchscript
self.W_I_pre = TensorLinear(channels, channels)
self.W_A_pre = TensorLinear(channels, channels)
self.W_S_pre = TensorLinear(channels, channels)
self.W_I_post = TensorLinear(channels, channels)
self.W_A_post = TensorLinear(channels, channels)
self.W_S_post = TensorLinear(channels, channels)
self.distance_embedding = HaddamardProduct(
nn.Sequential(
ExponentialRBF(radial_features, cutoff),
MLP(
layers=[radial_features, 2 * channels, 3 * channels],
activation=nn.SiLU(),
activate_last=True,
),
),
CosineEnvelope(cutoff),
left_aligned=True,
)
def forward(self, X: Tensor, graph: AtomicGraph) -> Tensor:
# decompose matrix representations
I, A, S = decompose_matrix(X) # (N, C, 3, 3)
# update I, A, S
Y = self.W_I_pre(I) + self.W_A_pre(A) + self.W_S_pre(S) # (N, C, 3, 3)
# get coefficients from neighbour distances
c = self.distance_embedding(neighbour_distances(graph)) # (E, 3C)
f_I, f_A, f_S = torch.chunk(c[..., None, None], 3, dim=1)
# message passing
m_ij = (
f_I * index_over_neighbours(I, graph)
+ f_A * index_over_neighbours(A, graph)
+ f_S * index_over_neighbours(S, graph)
)
# total message
M_i = sum_over_neighbours(m_ij, graph) # (N, C, 3, 3)
# scalar/vector/tensor mixing
Y = Y @ M_i + M_i @ Y # (N, C, 3, 3)
# renormalise and decompose
norm = (frobenius_norm(Y) + 1)[..., None, None] # (N, C, 1, 1)
I, A, S = decompose_matrix(Y / norm) # (N, C, 3, 3)
# mix features again
Y = self.W_I_post(I) + self.W_A_post(A) + self.W_S_post(S)
return Y + torch.matrix_power(Y, 2) # (N, C, 3, 3)
class ScalarOutput(nn.Module):
"""
A non-linear read-out function:
    ``X`` with shape ``(N, C, 3, 3)`` is decomposed into its ``I``, ``A``,
    and ``S`` components. The concatenation of the (squared) Frobenius norms
    of these components is passed through an MLP to generate a scalar.
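
    For example::

        read_out = ScalarOutput(channels=8)
        X = torch.randn(5, 8, 3, 3)
        assert read_out(X).shape == (5, 1)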
"""
def __init__(self, channels: int):
super().__init__()
self.layer_norm = nn.LayerNorm(3 * channels)
self.mlp = MLP(
layers=[3 * channels, 2 * channels, 1],
activation=nn.SiLU(),
)
def forward(self, X: Tensor) -> Tensor:
"""X: (N, C, 3, 3) --> (N, 1)"""
I, A, S = decompose_matrix(X) # (N, C, 3, 3)
norm_I = frobenius_norm(I) # (N, C)
norm_A = frobenius_norm(A)
norm_S = frobenius_norm(S)
X = torch.cat((norm_I, norm_A, norm_S), dim=-1) # (N, 3C)
X = self.layer_norm(X)
return self.mlp(X) # (N, 1)
class VectorOutput(nn.Module):
"""
A non-linear read-out function:
    The ``A`` (skew-symmetric) component of ``X`` with shape ``(N, C, 3, 3)``
    is passed through a linear layer, and the ``x``, ``y``, and ``z``
    components of the underlying vector are then read off from the result.
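
    For example::

        read_out = VectorOutput(channels=8)
        X = torch.randn(5, 8, 3, 3)
        assert read_out(X).shape == (5, 3)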
"""
def __init__(self, channels: int):
super().__init__()
self.linear = TensorLinear(channels, 1)
def forward(self, X: Tensor) -> Tensor:
"""X: (N, C, 3, 3) --> (N, 3)"""
_, A, _ = decompose_matrix(X) # (N, C, 3, 3)
        # squeeze only the channel dimension, so that single-atom
        # graphs keep their atom dimension
        A_final = self.linear(A).squeeze(-3)  # (N, 3, 3)
# A_final[b, c] = [
# [ 0, -vz, vy],
# [ vz, 0, -vx],
# [-vy, vx, 0],
# ]
x = A_final[..., 2, 1]
y = A_final[..., 0, 2]
z = A_final[..., 1, 0]
return torch.stack((x, y, z), dim=-1) # (N, 3)
### utils ###
def one_third_trace(X: Tensor) -> Tensor:
"""
    Calculate one third of the trace of an (optionally batched) matrix,
    ``X``, of shape ``(...B, 3, 3)``.
"""
return X.diagonal(offset=0, dim1=-1, dim2=-2).mean(-1)
def frobenius_norm(X: Tensor) -> Tensor:
"""
    Calculate the squared Frobenius norm (the sum of squared elements; no
    square root is taken) of an (optionally batched) matrix, ``X``, of
    shape ``(...B, 3, 3)``.
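
    For example::

        assert frobenius_norm(torch.eye(3)) == 3.0  # not sqrt(3)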
"""
return (X**2).sum((-2, -1))
def decompose_matrix(X: Tensor) -> tuple[Tensor, Tensor, Tensor]:
"""
    Take an (optionally batched) matrix, ``X``, of shape ``(...B, 3, 3)``,
    and decompose it into its irreducible components, ``I``, ``A``, and ``S``:

    * ``I[b] = 1/3 * trace(X[b]) * Id``
    * ``A[b] = 0.5 * (X[b] - X[b].T)``
    * ``S[b] = 0.5 * (X[b] + X[b].T) - I[b]``

    where ``Id`` is the ``3x3`` identity matrix and ``b`` indexes any
    batch dimension(s).
Parameters
----------
    X
The matrix to decompose, of shape ``(...B, 3, 3)``.
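
    Examples
    --------
    The decomposition is exact:

    .. code:: python

        X = torch.randn(3, 3)
        I, A, S = decompose_matrix(X)
        assert torch.allclose(I + A + S, X)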
"""
trace = one_third_trace(X)[..., None, None] # (...B, 1, 1)
I = trace * torch.eye(3, 3, device=X.device, dtype=X.dtype) # (...B, 3, 3)
A = 0.5 * (X - X.transpose(-2, -1))
S = 0.5 * (X + X.transpose(-2, -1)) - I
return I, A, S # 3x (...B, 3, 3)
def vector_to_skew_symmetric_matrix(v: Tensor) -> Tensor:
"""
    Creates a skew-symmetric matrix from an (optionally batched) vector.
v: ([B], 3) --> sst: ([B], 3, 3)
sst[b] = [
[ 0, -v[b].z, v[b].y],
[ v[b].z, 0, -v[b].x],
[-v[b].y, v[b].x, 0],
]
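
    Extracting the vector back out (as done in ``VectorOutput``)::

        v = torch.tensor([1.0, 2.0, 3.0])
        M = vector_to_skew_symmetric_matrix(v)
        assert torch.allclose(torch.stack((M[2, 1], M[0, 2], M[1, 0])), v)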
"""
x, y, z = v.unbind(dim=-1)
zero = torch.zeros_like(x)
tensor = torch.stack(
(
zero,
-z,
y,
z,
zero,
-x,
-y,
x,
zero,
),
        dim=-1,  # stack along the last dim to support arbitrary batch dims
    )  # ([B], 9)
return tensor.reshape(tensor.shape[:-1] + (3, 3))
def vector_to_symmetric_traceless_matrix(v: Tensor) -> Tensor:
"""
    Creates a symmetric, traceless matrix from an (optionally batched)
    vector.

    v: (..., 3) --> stm: (..., 3, 3)

    stm[...b] = v[...b] v[...b].T - 1/3 * trace(v[...b] v[...b].T) * Id
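
    The result is symmetric and traceless::

        v = torch.randn(3)
        S = vector_to_symmetric_traceless_matrix(v)
        assert torch.allclose(S, S.T)
        assert torch.isclose(S.trace(), torch.tensor(0.0), atol=1e-6)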
"""
v_vT = torch.matmul(v.unsqueeze(-1), v.unsqueeze(-2))
Id = torch.eye(3, 3, device=v_vT.device, dtype=v_vT.dtype)
I = one_third_trace(v_vT)[..., None, None] * Id
return 0.5 * (v_vT + v_vT.transpose(-2, -1)) - I
class TensorLinear(nn.Module):
    """
    A bias-free linear layer acting over the channel dimension of
    ``(N, C, 3, 3)`` tensor features.
    """
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.linear = nn.Linear(in_features, out_features, bias=False)
def forward(self, x: Tensor) -> Tensor:
"""x: (N, in, 3, 3) --> (N, out, 3, 3)"""
return self.linear(x.transpose(-1, -3)).transpose(-1, -3)
def __repr__(self):
_in, _out = self.linear.in_features, self.linear.out_features
return (
f"{self.__class__.__name__}("
f"[N, {_in}, 3, 3] --> [N, {_out}, 3, 3])"
)