from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Final, Sequence, final
import torch
from ase.data import chemical_symbols
from torch import nn
from graph_pes.atomic_graph import (
AtomicGraph,
PropertyKey,
has_cell,
is_batch,
sum_per_structure,
to_batch,
trim_edges,
)
from graph_pes.data.datasets import GraphDataset
from graph_pes.utils.logger import logger
from .utils.misc import differentiate, differentiate_all
from .utils.nn import PerElementParameter
if TYPE_CHECKING:
from graph_pes.utils.calculator import GraphPESCalculator
[docs]
class GraphPESModel(nn.Module, ABC):
r"""
All models implemented in ``graph-pes`` are subclasses of
:class:`~graph_pes.GraphPESModel`.
These models make predictions (via the
:meth:`~graph_pes.GraphPESModel.predict` method) of the
following properties:
.. list-table::
:header-rows: 1
* - Key
- Single graph
- Batch of graphs
* - :code:`"local_energies"`
- :code:`(N,)`
- :code:`(N,)`
* - :code:`"energy"`
- :code:`()`
- :code:`(M,)`
* - :code:`"forces"`
- :code:`(N, 3)`
- :code:`(N, 3)`
* - :code:`"stress"`
- :code:`(3, 3)`
- :code:`(M, 3, 3)`
* - :code:`"virial"`
- :code:`(3, 3)`
- :code:`(M, 3, 3)`
assuming an input of an :class:`~graph_pes.AtomicGraph` representing a
single structure composed of ``N`` atoms, or an
:class:`~graph_pes.AtomicGraph` composed of ``M`` structures and containing
a total of ``N`` atoms. (see :func:`~graph_pes.atomic_graph.is_batch` for
more information about batching).
Implementations must override the
:meth:`~graph_pes.GraphPESModel.forward` method to generate a
dictionary of predictions for the given graph. As a minimum, this must
include a per-atom energy contribution (``"local_energies"``).
For any other properties not returned by the forward pass,
the :meth:`~graph_pes.GraphPESModel.predict` method will automatically
infer these properties from the local energies as required:
* ``"energy"``: as the sum of the local energies per structure.
* ``"forces"``: as the negative gradient of the energy with respect to the
atomic positions.
* ``"stress"``: as the negative gradient of the energy with respect to a
symmetric expansion of the unit cell, normalised by the cell volume.
* ``"virial"``: as ``-stress * volume``.
For more details on how these are calculated, see :doc:`../theory`.
:class:`~graph_pes.GraphPESModel` objects save various pieces of extra
metadata to the ``state_dict`` via the
:meth:`~graph_pes.GraphPESModel.get_extra_state` and
:meth:`~graph_pes.GraphPESModel.set_extra_state` methods.
If you want to save additional extra state to the ``state_dict`` of your
model, please implement the :meth:`~graph_pes.GraphPESModel.extra_state`
property and corresponding setter to ensure that you do not overwrite
these extra metadata items.
Parameters
----------
cutoff
The cutoff radius for the model.
implemented_properties
The property predictions that the model implements in the forward pass.
Must include at least ``"local_energies"``.
"""
def __init__(
self,
cutoff: float,
implemented_properties: list[PropertyKey],
):
super().__init__()
self._GRAPH_PES_VERSION: Final[str] = "0.0.25"
self.cutoff: torch.Tensor
self.register_buffer("cutoff", torch.tensor(cutoff))
self._has_been_pre_fit: torch.Tensor
self.register_buffer("_has_been_pre_fit", torch.tensor(0))
self.implemented_properties = implemented_properties
if "local_energies" not in implemented_properties:
raise ValueError(
'All GraphPESModel\'s must implement a "local_energies" '
"prediction."
)
[docs]
@abstractmethod
def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
    """
    The forward pass of the model: sub-classes must override this to
    produce every property listed in their ``implemented_properties``
    for the given graph.

    Parameters
    ----------
    graph
        The graph representation of the structure/s.

    Returns
    -------
    dict[PropertyKey, torch.Tensor]
        One entry per implemented property, holding the corresponding
        tensor of predictions (the class-level documentation lists the
        expected shapes). Call
        :func:`~graph_pes.atomic_graph.is_batch` inside the forward
        pass to find out whether ``graph`` is batched.
    """
    ...
[docs]
def predict(
    self,
    graph: AtomicGraph,
    properties: list[PropertyKey],
) -> dict[PropertyKey, torch.Tensor]:
    """
    Generate (optionally batched) predictions for the given
    ``properties`` and ``graph``.

    This method returns a dictionary mapping each requested
    ``property`` to a tensor of predictions, relying on the model's
    :meth:`~graph_pes.GraphPESModel.forward` implementation
    together with :func:`torch.autograd.grad` to automatically infer any
    missing properties.

    Parameters
    ----------
    graph
        The graph representation of the structure/s.
    properties
        The properties to predict. Can be any combination of
        ``"energy"``, ``"forces"``, ``"stress"``, ``"virial"``, and
        ``"local_energies"``.

    Returns
    -------
    dict[PropertyKey, torch.Tensor]
        Exactly one entry per requested property. When the model is not
        in training mode, every returned tensor is detached.

    Raises
    ------
    ValueError
        If stress/virial must be inferred but the graph has no cell, or
        if the energy must be inferred but the forward pass did not
        return ``"local_energies"``.
    """
    # before anything, remove unnecessary edges:
    graph = trim_edges(graph, self.cutoff.item())

    # check to see if we need to infer any properties
    infer_forces = (
        "forces" in properties
        and "forces" not in self.implemented_properties
    )
    infer_stress_information = (
        # we need to infer stress information if we're asking for
        # the stress or the virial and the model doesn't
        # implement either of them directly
        ("stress" in properties or "virial" in properties)
        and "stress" not in self.implemented_properties
        and "virial" not in self.implemented_properties
    )
    if infer_stress_information and not has_cell(graph):
        raise ValueError("Can't predict stress without cell information.")

    # forces and stresses are derivatives of the summed local energies,
    # so inferring either of them requires inferring the energy too
    infer_energy = (
        infer_stress_information
        or infer_forces
        or "energy" not in self.implemented_properties
    )

    existing_cell = graph.cell

    # inference specific set up
    if infer_stress_information:
        # See About>Theory in the graph-pes docs for an explanation of
        # the maths behind this.
        #
        # The stress tensor is the gradient of the total energy wrt
        # a symmetric expansion of the structure (i.e. that acts on
        # both the cell and the atomic positions).
        #
        # F. Knuth et al. All-electron formalism for total energy strain
        # derivatives and stress tensor components for numeric atom-centered
        # orbitals. Computer Physics Communications 190, 33–50 (2015).
        change_to_cell = torch.zeros_like(existing_cell)
        change_to_cell.requires_grad_(True)
        symmetric_change = 0.5 * (
            change_to_cell + change_to_cell.transpose(-1, -2)
        )  # (n_structures, 3, 3) if batched, else (3, 3)
        scaling = (
            torch.eye(3, device=existing_cell.device) + symmetric_change
        )

        # torchscript annoying-ness:
        graph_batch = graph.batch
        if graph_batch is not None:
            scaling_per_atom = torch.index_select(
                scaling,
                dim=0,
                index=graph_batch,
            )  # (n_atoms, 3, 3)

            # to go from (N, 3) @ (N, 3, 3) -> (N, 3), we need un/squeeze:
            # (N, 1, 3) @ (N, 3, 3) -> (N, 1, 3) -> (N, 3)
            # NB: squeeze only the inserted axis. A bare ``.squeeze()``
            # would also drop the atom axis when N == 1, giving (3,).
            new_positions = (
                graph.R.unsqueeze(-2) @ scaling_per_atom
            ).squeeze(-2)
        else:
            # (N, 3) @ (3, 3) -> (N, 3)
            new_positions = graph.R @ scaling

        # (M, 3, 3) @ (M, 3, 3) -> (M, 3, 3), or (3, 3) @ (3, 3) -> (3, 3)
        new_cell = existing_cell @ scaling

        # change to positions will be a tensor of all 0's, but will allow
        # gradients to flow backwards through the energy calculation
        # and allow us to calculate the stress tensor as the gradient
        # of the energy wrt the change in cell.
        graph = AtomicGraph(  # can't use _replace here due to TorchScript
            Z=graph.Z,
            R=new_positions,  # <- new positions
            cell=new_cell,  # <- new cell
            neighbour_list=graph.neighbour_list,
            neighbour_cell_offsets=graph.neighbour_cell_offsets,
            properties=graph.properties,
            other=graph.other,
            cutoff=graph.cutoff,
            batch=graph.batch,
            ptr=graph.ptr,
        )
    else:
        # never differentiated against in this branch, but the name must
        # be defined on every code path
        change_to_cell = torch.zeros_like(graph.cell)

    if infer_forces:
        graph.R.requires_grad_(True)

    # get the implemented properties
    predictions = self(graph)

    if infer_energy:
        if "local_energies" not in predictions:
            raise ValueError("Can't infer energy without local energies.")
        predictions["energy"] = sum_per_structure(
            predictions["local_energies"],
            graph,
        )

    # use the autograd machinery to auto-magically
    # calculate forces and stress from the energy
    cell_volume = torch.det(graph.cell)
    if is_batch(graph):
        # (M,) -> (M, 1, 1) so that it broadcasts against (M, 3, 3)
        cell_volume = cell_volume.view(-1, 1, 1)

    # ugly triple if loops to be efficient with autograd
    # while also Torchscript compatible
    if infer_forces and infer_stress_information:
        assert "energy" in predictions
        dE_dR, dE_dC = differentiate_all(
            predictions["energy"],
            [graph.R, change_to_cell],
            keep_graph=self.training,
        )
        predictions["forces"] = -dE_dR
        predictions["virial"] = -dE_dC
        predictions["stress"] = dE_dC / cell_volume

    elif infer_forces:
        assert "energy" in predictions
        dE_dR = differentiate(
            predictions["energy"],
            graph.R,
            keep_graph=self.training,
        )
        predictions["forces"] = -dE_dR

    elif infer_stress_information:
        assert "energy" in predictions
        dE_dC = differentiate(
            predictions["energy"],
            change_to_cell,
            keep_graph=self.training,
        )
        predictions["virial"] = -dE_dC
        predictions["stress"] = dE_dC / cell_volume

    # finally, we might not have needed autograd to infer stress/virial
    # if the other was implemented on the base class:
    if not infer_stress_information:
        if (
            "stress" in properties
            and "stress" not in self.implemented_properties
        ):
            predictions["stress"] = -predictions["virial"] / cell_volume
        if (
            "virial" in properties
            and "virial" not in self.implemented_properties
        ):
            predictions["virial"] = -predictions["stress"] * cell_volume

    # NB: ``graph`` was only ever re-bound as a local name above and is
    # not used beyond this point, so there is nothing to "put back" here.

    # make sure we don't leave auxiliary predictions
    # e.g. local_energies if we only asked for energy
    # or energy if we only asked for forces
    predictions = {prop: predictions[prop] for prop in properties}

    # tidy up if in eval mode
    if not self.training:
        predictions = {k: v.detach() for k, v in predictions.items()}

    # return the output
    return predictions  # type: ignore
[docs]
@torch.no_grad()
def pre_fit_all_components(
    self,
    graphs: Sequence[AtomicGraph],
):
    """
    Pre-fit the model, and all its components, to the training data.

    Some models require pre-fitting to the training data to set certain
    parameters. For example, the
    :class:`~graph_pes.models.pairwise.LennardJones` model uses the
    distribution of interatomic distances in the training data
    to set the length-scale parameter.

    In the ``graph-pes-train`` routine, this method is called before
    "normal" training begins (you can turn this off with a config option).

    This method does two things:

    1. iterates over all the model's :class:`~torch.nn.Module` components
       (including itself) and calls their :meth:`pre_fit` method (if it
       exists - see for instance
       :class:`~graph_pes.models.LearnableOffset` for an example of a
       model-specific pre-fit method, and
       :class:`~graph_pes.models.components.scaling.LocalEnergiesScaler`
       for an example of a component-specific pre-fit method).
    2. registers all the unique atomic numbers in the training data with
       all of the model's :class:`~graph_pes.utils.nn.PerElementParameter`
       instances to ensure correct parameter counting.

    If the model has already been pre-fitted, subsequent calls to
    :meth:`pre_fit_all_components` will skip step 1 (and a warning will
    be logged); step 2 is always performed.

    Parameters
    ----------
    graphs
        The training data.
    """  # noqa: E501
    model_name = self.__class__.__name__
    logger.debug(f"Attempting to pre-fit {model_name}")

    # 1. get the graphs as a single batch
    if isinstance(graphs, GraphDataset):
        graphs = list(graphs)
    graph_batch = to_batch(graphs)

    # 2a. if the model has already been pre-fitted: warn
    if self._has_been_pre_fit:
        logger.warning(
            f"This model ({model_name}) has already been pre-fitted. "
            "This, and any subsequent, call to pre_fit_all_components will "
            "be ignored.",
            stacklevel=2,
        )

    # 2b. if the model has not been pre-fitted: pre-fit
    else:
        if len(graphs) > 10_000:
            logger.warning(
                f"Pre-fitting on a large dataset ({len(graphs):,} graphs). "
                "This may take some time. Consider using a smaller, "
                "representative collection of structures for pre-fitting. "
                "Set ``max_n_pre_fit`` in your config, or "
                "see GraphDataset.sample() for more information.",
                stacklevel=2,
            )

        self.pre_fit(graph_batch)

        # pre-fit any sub-module with a pre_fit method
        for module in self.modules():
            if module is not self and hasattr(module, "pre_fit"):
                module.pre_fit(graph_batch)

        # update the flag in place: assigning a brand new tensor would
        # replace the registered buffer with one living on the default
        # device, regardless of where the model currently lives
        self._has_been_pre_fit.fill_(1)

    # 3. finally, register all per-element parameters (no harm in doing this
    #    multiple times). The set of elements is the same for every
    #    parameter, so work it out once.
    unique_Zs = torch.unique(graph_batch.Z).tolist()
    for param in self.parameters():
        if isinstance(param, PerElementParameter):
            # hand each parameter its own copy of the list
            param.register_elements(list(unique_Zs))
[docs]
def pre_fit(self, graphs: AtomicGraph) -> None:
    """
    Hook for model-specific pre-fitting: the base implementation does
    nothing, so override it to perform additional pre-fitting steps.

    See :class:`~graph_pes.models.components.scaling.LocalEnergiesScaler` or
    :class:`~graph_pes.models.offsets.EnergyOffset` for examples of this.

    Parameters
    ----------
    graphs
        The training data.
    """
    return None
[docs]
def non_decayable_parameters(self) -> list[torch.nn.Parameter]:
    """
    Return a list of parameters that should not be decayed during training.

    By default, this method walks over all available sub-modules
    and calls their :meth:`non_decayable_parameters` (if it is defined).
    See :class:`~graph_pes.models.components.scaling.LocalEnergiesScaler`
    for an example of this.

    Returns
    -------
    list[torch.nn.Parameter]
        The collected parameters, in the order they were first found.
        Each parameter object appears at most once.
    """
    found: list[torch.nn.Parameter] = []
    seen_ids: set[int] = set()

    for module in self.modules():
        if module is self:
            continue
        if not hasattr(module, "non_decayable_parameters"):
            continue

        # ``self.modules()`` already visits every descendant, and a
        # descendant's own ``non_decayable_parameters`` may itself recurse
        # over its children. The same parameter can therefore be reported
        # several times. De-duplicate by identity, since ``==`` on tensors
        # is element-wise.
        for param in module.non_decayable_parameters():
            if id(param) in seen_ids:
                continue
            seen_ids.add(id(param))
            found.append(param)

    return found
# add type hints for mypy etc.
def __call__(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
return super().__call__(graph)
[docs]
def get_all_PES_predictions(
    self, graph: AtomicGraph
) -> dict[PropertyKey, torch.Tensor]:
    """
    Predict every property the model is able to generate for the
    given ``graph``: always the energy, forces and local energies,
    plus the stress and virial when the graph has a cell.
    """
    wanted: list[PropertyKey] = ["energy", "forces", "local_energies"]
    if has_cell(graph):
        wanted.append("stress")
        wanted.append("virial")
    return self.predict(graph, wanted)
[docs]
def predict_energy(self, graph: AtomicGraph) -> torch.Tensor:
    """Shortcut for requesting only ``"energy"`` from :meth:`predict`."""
    predictions = self.predict(graph, ["energy"])
    return predictions["energy"]
[docs]
def predict_forces(self, graph: AtomicGraph) -> torch.Tensor:
    """Shortcut for requesting only ``"forces"`` from :meth:`predict`."""
    predictions = self.predict(graph, ["forces"])
    return predictions["forces"]
[docs]
def predict_stress(self, graph: AtomicGraph) -> torch.Tensor:
    """Shortcut for requesting only ``"stress"`` from :meth:`predict`."""
    predictions = self.predict(graph, ["stress"])
    return predictions["stress"]
[docs]
def predict_virial(self, graph: AtomicGraph) -> torch.Tensor:
    """Shortcut for requesting only ``"virial"`` from :meth:`predict`."""
    predictions = self.predict(graph, ["virial"])
    return predictions["virial"]
[docs]
def predict_local_energies(self, graph: AtomicGraph) -> torch.Tensor:
    """Shortcut for requesting only ``"local_energies"`` from :meth:`predict`."""
    predictions = self.predict(graph, ["local_energies"])
    return predictions["local_energies"]
@torch.jit.unused
@property
def elements_seen(self) -> list[str]:
    """
    The elements that the model has seen during training, as chemical
    symbols ordered by their sorted ``Z`` values.
    """
    atomic_numbers = set()
    for parameter in self.parameters():
        if not isinstance(parameter, PerElementParameter):
            continue
        # gather the Zs recorded on each per-element parameter
        atomic_numbers.update(parameter._accessed_Zs)

    symbols = []
    for number in sorted(atomic_numbers):
        symbols.append(chemical_symbols[number])
    return symbols
@torch.jit.unused
@property
def device(self) -> torch.device:
return self.cutoff.device
@torch.jit.unused
@property
def extra_state(self) -> Any:
"""
Override this property to add extra state to the model's
``state_dict``.
"""
return {}
@torch.jit.unused
@extra_state.setter
def extra_state(self, state: Any) -> None:
"""
Set the extra state of this instance using a value returned by the
:meth:`~graph_pes.GraphPESModel.extra_state` property.
"""
pass
[docs]
@torch.jit.unused
def ase_calculator(
    self, device: torch.device | str | None = None, skin: float = 1.0
) -> "GraphPESCalculator":
    """
    Wrap this model in an ASE calculator. See
    :class:`~graph_pes.utils.calculator.GraphPESCalculator` for more
    information.

    Parameters
    ----------
    device
        The device to use for the calculator. If ``None``, the device of the
        model will be used.
    skin
        The skin to use for the neighbour list. If all atoms have moved less
        than half of this distance between calls to `calculate`, the
        neighbour list will be reused, saving (in some cases) significant
        computation time.
    """
    # imported here rather than at module level, where the same name is
    # only imported under ``TYPE_CHECKING``
    from graph_pes.utils.calculator import GraphPESCalculator

    calculator = GraphPESCalculator(self, device=device, skin=skin)
    return calculator