Models
- class graph_pes.GraphPESModel(cutoff, implemented_properties)[source]
All models implemented in graph-pes are subclasses of GraphPESModel. These models make predictions (via the predict() method) of the following properties:

Key                 Single graph    Batch of graphs
"local_energies"    (N,)            (N,)
"energy"            ()              (M,)
"forces"            (N, 3)          (N, 3)
"stress"            (3, 3)          (M, 3, 3)
"virial"            (3, 3)          (M, 3, 3)
assuming an input of an AtomicGraph representing a single structure composed of N atoms, or an AtomicGraph composed of M structures and containing a total of N atoms (see is_batch() for more information about batching).

Implementations must override the forward() method to generate a dictionary of predictions for the given graph. As a minimum, this must include a per-atom energy contribution ("local_energies").

For any other properties not returned by the forward pass, the predict() method will automatically infer them from the local energies as required:

- "energy": the sum of the local energies per structure.
- "forces": the negative gradient of the energy with respect to the atomic positions.
- "stress": the negative gradient of the energy with respect to a symmetric expansion of the unit cell, normalised by the cell volume.
- "virial": -stress * volume.

For more details on how these are calculated, see Theory.
GraphPESModel objects save various pieces of extra metadata to the state_dict via the get_extra_state() and set_extra_state() methods. If you want to save additional extra state to the state_dict of your model, please implement the extra_state() property and corresponding setter to ensure that you do not overwrite these extra metadata items.
- Parameters:
cutoff (float) – The cutoff radius for the model.
implemented_properties (list[PropertyKey]) – The property predictions that the model implements in the forward pass. Must include at least "local_energies".
- abstract forward(graph)[source]
The model's forward pass. Generate all properties for the given graph that are in this model's implemented_properties list.
- Parameters:
graph (AtomicGraph) – The graph representation of the structure/s.
- Returns:
A dictionary mapping each implemented property to a tensor of predictions (see above for the expected shapes). Use is_batch() to check whether the graph is batched in the forward pass.
- Return type:
dict[PropertyKey, torch.Tensor]
- predict(graph, properties)[source]
Generate (optionally batched) predictions for the given properties and graph.
This method returns a dictionary mapping each requested property to a tensor of predictions, relying on the model's forward() implementation together with torch.autograd.grad() to automatically infer any missing properties.
- Parameters:
graph (AtomicGraph) – The graph representation of the structure/s.
properties (list[Literal['local_energies', 'forces', 'energy', 'stress', 'virial']]) – The properties to predict. Can be any combination of "energy", "forces", "stress", "virial", and "local_energies".
- Return type:
dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], torch.Tensor]
- pre_fit_all_components(graphs)[source]
Pre-fit the model, and all of its components, to the training data.
Some models require pre-fitting to the training data to set certain parameters. For example, the LennardJones model uses the distribution of interatomic distances in the training data to set its length-scale parameter.
In the graph-pes-train routine, this method is called before "normal" training begins (you can turn this off with a config option).
This method does two things:
- iterates over all of the model's Module components (including itself) and calls their pre_fit() method if it exists (see LearnableOffset for an example of a model-specific pre-fit method, and LocalEnergiesScaler for an example of a component-specific pre-fit method);
- registers all of the unique atomic numbers in the training data with all of the model's PerElementParameter instances, to ensure correct parameter counting.
If the model has already been pre-fitted, subsequent calls to pre_fit_all_components() will be ignored (and a warning will be raised).
- Parameters:
graphs (Sequence[AtomicGraph]) – The training data.
- pre_fit(graphs)[source]
Override this method to perform additional, model-specific pre-fitting steps. See LocalEnergiesScaler or EnergyOffset for examples of this.
- Parameters:
graphs (AtomicGraph) – The training data.
- non_decayable_parameters()[source]
Return a list of parameters that should not be decayed during training.
By default, this method recurses over all available sub-modules and calls their non_decayable_parameters() methods (where defined). See LocalEnergiesScaler for an example of this.
- get_all_PES_predictions(graph)[source]
Get all of the properties that the model can predict for the given graph.
- predict_local_energies(graph)[source]
Convenience method to predict just the local energies.
- Return type:
torch.Tensor
- final get_extra_state()[source]
Get the extra state of this instance. Please override the extra_state() property to add extra state here.
- final set_extra_state(state)[source]
Set the extra state of this instance using a dictionary mapping strings to values. Please override the extra_state() property setter to add extra state here.
- ase_calculator(device=None, skin=1.0)[source]
Return an ASE calculator wrapping this model. See GraphPESCalculator for more information.
- Parameters:
device (torch.device | str | None) – The device to use for the calculator. If None, the device of the model will be used.
skin (float) – The skin distance to use for the neighbour list. If all atoms have moved less than half of this distance between calls to calculate, the neighbour list will be reused, saving (in some cases) significant computation time.
- Return type:
GraphPESCalculator
Loading Models
- graph_pes.models.load_model(path)[source]
Load a model from a file.
- Parameters:
path (str | pathlib.Path) – The path to the file.
- Returns:
The model.
- Return type:
GraphPESModel
Examples
Use this function to load an existing model for further training using graph-pes-train:

```yaml
model:
  +load_model:
    path: path/to/model.pt
```

To account for some new energy offset in your training data, you could do something like this (see also load_model_component()):

```yaml
model:
  # add an offset to an existing model before fine-tuning
  offset: +LearnableOffset()
  many-body:
    +load_model:
      path: path/to/model.pt
```
- graph_pes.models.load_model_component(path, key)[source]
Load a component from an AdditionModel.
- Parameters:
path (str | pathlib.Path) – The path to the file.
key (str) – The key to load.
- Returns:
The component.
- Return type:
GraphPESModel
Examples
Train on data with a new energy offset:

```yaml
model:
  offset: +LearnableOffset()
  many-body:
    +load_model_component:
      path: path/to/model.pt
      key: many-body
```
Freezing Models
- class graph_pes.models.T
Type alias for TypeVar("T", bound=torch.nn.Module).
- graph_pes.models.freeze_matching(model, pattern)[source]
Freeze all parameters that match the given pattern.
- Parameters:
- Returns:
The model.
- Return type:
T
Examples
Freeze all the parameters in the first layer of a MACE-MP0 model from mace_mp() (which have names of the form "model.interactions.0.<name>"):

```yaml
model:
  +freeze_matching:
    model: +mace_mp()
    pattern: model\.interactions\.0\..*
```
- graph_pes.models.freeze_all_except(model, pattern)[source]
Freeze all parameters in a model except those matching a given pattern.
- Parameters:
- Returns:
The model.
- Return type:
T
Examples
Freeze all parameters in a MACE-MP0 model from mace_mp() except those in the read-out heads:

```yaml
model:
  +freeze_all_except:
    model: +mace_mp()
    pattern: model\.readouts.*
```