Models
- class graph_pes.GraphPESModel(cutoff, implemented_properties, three_body_cutoff=None)
All models implemented in graph-pes are subclasses of GraphPESModel. These models make predictions (via the predict() method) of the following properties:

| Key | Single graph | Batch of graphs | Units |
| --- | --- | --- | --- |
| "local_energies" | (N,) | (N,) | [energy] |
| "energy" | () | (M,) | [energy] |
| "forces" | (N, 3) | (N, 3) | [energy / length] |
| "stress" | (3, 3) | (M, 3, 3) | [energy / length^3] |
| "virial" | (3, 3) | (M, 3, 3) | [energy] |

assuming an input of an AtomicGraph representing a single structure composed of N atoms, or an AtomicGraph composed of M structures and containing a total of N atoms (see is_batch() for more information about batching).
Note that graph-pes makes no assumptions as to the actual units of the energy and length quantities: these will depend on the labels the model has been trained on (e.g. eV and Å, kcal/mol and nm, or even J and m).
Implementations must override the forward() method to generate a dictionary of predictions for the given graph. At a minimum, this must include a per-atom energy contribution ("local_energies").
For any properties not returned by the forward pass, the predict() method will automatically infer them from the local energies as required:
- "energy": the sum of the local energies within each structure.
- "forces": the negative gradient of the energy with respect to the atomic positions.
- "stress": the gradient of the energy with respect to a symmetric expansion of the unit cell, normalised by the cell volume. In keeping with convention, a negative stress indicates that the system is under static compression (wants to expand).
- "virial": -stress * volume. A negative virial indicates that the system is under static tension (wants to contract).
For more details on how these are calculated, see Theory.
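The inference rules above can be sketched without torch for a toy pair potential. This is a stdlib-only illustration, not graph-pes code: the Lennard-Jones pair term, the equal split of each pair energy between its two atoms, the flat batch index list and the assumed cell volume are all choices made for this example, and central finite differences stand in for torch.autograd.grad(). It reproduces the shapes in the table and the sign conventions listed above: the compressed dimer ends up with a negative stress and a positive virial.

```python
import math

def pair_energy(r):
    # Toy Lennard-Jones pair term with epsilon = sigma = 1 (illustrative only).
    sr6 = (1.0 / r) ** 6
    return 4.0 * (sr6 * sr6 - sr6)

def local_energies(R, batch):
    # "local_energies", shape (N,): each pair energy is split equally between
    # its two atoms; atoms in different structures never interact.
    e = [0.0] * len(R)
    for i in range(len(R)):
        for j in range(i + 1, len(R)):
            if batch[i] == batch[j]:
                half = 0.5 * pair_energy(math.dist(R[i], R[j]))
                e[i] += half
                e[j] += half
    return e

def energy(R, batch, n_structures):
    # "energy", shape (M,): sum of the local energies within each structure.
    out = [0.0] * n_structures
    for e_i, s in zip(local_energies(R, batch), batch):
        out[s] += e_i
    return out

def forces(R, batch, h=1e-5):
    # "forces", shape (N, 3): F = -dE/dR via central finite differences.
    # Summing over all structures is fine: an atom only affects its own one.
    F = [[0.0] * 3 for _ in R]
    for i in range(len(R)):
        for a in range(3):
            plus = [list(r) for r in R]
            minus = [list(r) for r in R]
            plus[i][a] += h
            minus[i][a] -= h
            dE = sum(local_energies(plus, batch)) - sum(local_energies(minus, batch))
            F[i][a] = -dE / (2 * h)
    return F

def stress(R, volume, h=1e-5):
    # "stress", shape (3, 3), single structure: (1/V) dE/d(eps) with positions
    # strained as R -> R (I + eps), one entry of eps at a time. For a
    # rotation-invariant energy this equals the symmetric-strain derivative at
    # eps = 0. `volume` is an assumed cell volume for this non-periodic toy.
    batch = [0] * len(R)
    S = [[0.0] * 3 for _ in range(3)]
    for a in range(3):
        for b in range(3):
            plus = [list(r) for r in R]
            minus = [list(r) for r in R]
            for k, r in enumerate(R):
                plus[k][b] += h * r[a]
                minus[k][b] -= h * r[a]
            dE = sum(local_energies(plus, batch)) - sum(local_energies(minus, batch))
            S[a][b] = dE / (2 * h) / volume
    return S

# Two structures batched together: N = 4 atoms, M = 2 structures.
R = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0),    # structure 0: compressed dimer
     (10.0, 0.0, 0.0), (11.5, 0.0, 0.0)]  # structure 1: stretched dimer
batch = [0, 0, 1, 1]

E = energy(R, batch, n_structures=2)         # 2 values, one per structure
F = forces(R, batch)                         # 4 x 3 values
S = stress(R[:2], volume=10.0)               # structure 0 only
W = [[-s * 10.0 for s in row] for row in S]  # "virial" = -stress * volume
print(E, F[0], S[0][0], W[0][0])
```

Here S[0][0] comes out at about -2.4 (the dimer at r = 1 sits on the repulsive wall and wants to expand), so the virial component W[0][0] is about +24.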
GraphPESModel objects save various pieces of extra metadata to the state_dict via the get_extra_state() and set_extra_state() methods. If you want to save additional extra state to the state_dict of your model, please implement the extra_state() property and corresponding setter to ensure that you do not overwrite these extra metadata items.
- Parameters:
cutoff (float) – The cutoff radius for the model.
implemented_properties (list[PropertyKey]) – The property predictions that the model implements in the forward pass. Must include at least "local_energies".
three_body_cutoff (float | None) – The cutoff radius for this model’s three-body interactions, if applicable.
- abstract forward(graph)
The model’s forward pass. Generate all properties for the given graph that are in this model’s implemented_properties list.
- Parameters:
graph (AtomicGraph) – The graph representation of the structure/s.
- Returns:
A dictionary mapping each implemented property to a tensor of predictions (see above for the expected shapes). Use is_batch() to check if the graph is batched in the forward pass.
- Return type:
- predict(graph, properties)
Generate (optionally batched) predictions for the given properties and graph.
This method returns a dictionary mapping each requested property to a tensor of predictions, relying on the model’s forward() implementation together with torch.autograd.grad() to automatically infer any missing properties.
- Parameters:
graph (AtomicGraph) – The graph representation of the structure/s.
properties (list[Literal['local_energies', 'forces', 'energy', 'stress', 'virial']]) – The properties to predict. Can be any combination of "energy", "forces", "stress", "virial", and "local_energies".
- Return type:
dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], torch.Tensor]
- pre_fit_all_components(graphs)
Pre-fit the model, and all its components, to the training data.
Some models require pre-fitting to the training data to set certain parameters. For example, the LennardJones model uses the distribution of interatomic distances in the training data to set the length-scale parameter.
In the graph-pes-train routine, this method is called before “normal” training begins (you can turn this off with a config option).
This method does two things:
1. iterates over all the model’s Module components (including itself) and calls their pre_fit() method, if it exists (see for instance LearnableOffset for an example of a model-specific pre-fit method, and LocalEnergiesScaler for an example of a component-specific pre-fit method).
2. registers all the unique atomic numbers in the training data with all of the model’s PerElementParameter instances to ensure correct parameter counting.
If the model has already been pre-fitted, subsequent calls to pre_fit_all_components() will be ignored (and a warning will be raised).
- Parameters:
graphs (Sequence[AtomicGraph]) – The training data.
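The two pre-fitting steps and the repeat-call guard described above can be sketched with plain Python stand-ins. Everything here is hypothetical: ToyComponent, ToyScaler and ToyModel are not graph-pes classes, each "graph" is just a list of atomic numbers, and a registered_Z set stands in for a PerElementParameter.

```python
import warnings

class ToyComponent:
    # Hypothetical stand-in for a torch.nn.Module with sub-components.
    def __init__(self, children=()):
        self.children = list(children)
        self.registered_Z = set()  # stands in for PerElementParameter bookkeeping

    def modules(self):
        # Depth-first walk over this component and all of its sub-components.
        yield self
        for child in self.children:
            yield from child.modules()

class ToyScaler(ToyComponent):
    # A component that defines its own pre-fit step.
    def __init__(self):
        super().__init__()
        self.pre_fit_calls = 0

    def pre_fit(self, graphs):
        self.pre_fit_calls += 1

class ToyModel(ToyComponent):
    def __init__(self, children=()):
        super().__init__(children)
        self._pre_fitted = False

    def pre_fit_all_components(self, graphs):
        if self._pre_fitted:
            warnings.warn("model has already been pre-fitted: ignoring this call")
            return
        unique_Z = {z for graph in graphs for z in graph}
        for module in self.modules():
            # Step 1: call pre_fit() only where a component defines it.
            pre_fit = getattr(module, "pre_fit", None)
            if pre_fit is not None:
                pre_fit(graphs)
            # Step 2: register every atomic number seen in the training data.
            module.registered_Z |= unique_Z
        self._pre_fitted = True

scaler = ToyScaler()
model = ToyModel(children=[scaler, ToyComponent()])
model.pre_fit_all_components([[1, 8, 1], [6, 1]])
```

A second call to model.pre_fit_all_components() only emits the warning; it neither re-runs pre_fit() nor changes the registered atomic numbers.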
- pre_fit(graphs)
Override this method to perform additional pre-fitting steps.
See LocalEnergiesScaler or EnergyOffset for examples of this.
- Parameters:
graphs (AtomicGraph) – The training data.
- non_decayable_parameters()
Return a list of parameters that should not be decayed during training.
By default, this method recurses over all available sub-modules and calls their non_decayable_parameters() method (if it is defined).
See LocalEnergiesScaler for an example of this.
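One way a training loop might use non_decayable_parameters() is to place those parameters in an optimizer group with zero weight decay. The sketch below assumes that usage; weight_decay_groups() and the parameter names are hypothetical, and plain lists stand in for torch parameters so that it runs without torch. It builds the list-of-dicts parameter-group layout that torch.optim optimizers accept.

```python
def weight_decay_groups(named_params, non_decayable, weight_decay=1e-2):
    # Hypothetical helper: split parameters into two optimizer groups.
    # Membership is checked by identity, since tensors should not be
    # compared with ==.
    skip = {id(p) for p in non_decayable}
    decay = [p for _, p in named_params if id(p) not in skip]
    no_decay = [p for _, p in named_params if id(p) in skip]
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Stand-ins for a layer weight and a per-element scaling parameter.
weight = [0.1, -0.3]
scale = [1.0]
named_params = [("layer.weight", weight), ("scaler.per_element_scaling", scale)]
groups = weight_decay_groups(named_params, non_decayable=[scale])
```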
- get_all_PES_predictions(graph)
Get all the properties that the model can predict for the given graph.
- predict_local_energies(graph)
Convenience method to predict just the local energies.
- Return type:
- final get_extra_state()
Get the extra state of this instance. Please override the extra_state() property to add extra state here.
- final set_extra_state(state)
Set the extra state of this instance from a dictionary mapping strings to values (as returned by get_extra_state()). Please override the extra_state() property setter to restore any extra state you have added.
- ase_calculator(device=None, skin=1.0, cache_threebody=True)
Return an ASE calculator wrapping this model. See GraphPESCalculator for more information.
- Parameters:
device (torch.device | str | None) – The device to use for the calculator. If None, the device of the model will be used.
skin (float) – The skin to use for the neighbour list. If all atoms have moved less than half of this distance between calls to calculate(), the neighbour list will be reused, saving (in some cases) significant computation time.
cache_threebody (bool) – Whether to cache the three-body neighbour list entries. In many cases, this can accelerate MD simulations by avoiding these quite expensive recalculations. Tuning the skin parameter is important to optimise the trade-off between less frequent but more expensive neighbour list recalculations and more frequent but cheaper ones. This option is ignored if the model does not use three-body interactions.
- Return type:
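The reuse rule for skin amounts to a displacement check. This is a stdlib-only illustration with a hypothetical needs_rebuild() helper, not the GraphPESCalculator implementation; it assumes the neighbour list was built with a radius of cutoff + skin (the usual Verlet-list construction) and ignores periodic wrapping of positions.

```python
import math

def needs_rebuild(positions_at_last_build, positions, skin):
    # Reuse the list only while every atom has moved less than skin / 2.
    # A pair distance can then have changed by less than skin in total, so no
    # pair that was beyond cutoff + skin can have come inside the cutoff.
    max_disp = max(
        math.dist(p0, p) for p0, p in zip(positions_at_last_build, positions)
    )
    return max_disp >= skin / 2

reference = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
print(needs_rebuild(reference, [(0.4, 0.0, 0.0), (1.0, 0.0, 0.0)], skin=1.0))  # False
print(needs_rebuild(reference, [(0.5, 0.0, 0.0), (1.0, 0.0, 0.0)], skin=1.0))  # True
```

A larger skin lets the list survive more steps, at the price of carrying more candidate pairs through every evaluation.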
- torch_sim_model(device=None, dtype=torch.float64, *, compute_forces=True, compute_stress=True)
Return a model suitable for use with the torch_sim package.
Internally, we set this model to evaluation mode and wrap it in a class that is suitable for use with the torch_sim package.
- Parameters:
device (torch.device | None) – The device to use for the model. If None, the model will be placed on the best device available.
dtype (torch.dtype) – The dtype to use for the model.
compute_forces (bool) – Whether to compute forces. Set this to False if you only need to generate energies within the torch_sim integrator.
compute_stress (bool) – Whether to compute stress. Set this to False if you don’t need stress information from the model within the torch_sim integrator.
Loading Models
- graph_pes.models.load_model(path)
Load a model from a file.
- Parameters:
path (str | pathlib.Path) – The path to the file.
- Returns:
The model.
- Return type:
Examples
Use this function to load an existing model for further training using graph-pes-train:
```yaml
model:
  +load_model:
    path: path/to/model.pt
```
See fine-tuning for more details.
- graph_pes.models.load_model_component(path, key)
Load a component from an AdditionModel.
- Parameters:
path (str | pathlib.Path) – The path to the file.
key (str) – The key to load.
- Returns:
The component.
- Return type:
Freezing Models
- class graph_pes.models.T
Type alias for TypeVar("T", bound=torch.nn.Module).
- graph_pes.models.freeze_matching(model, pattern)
Freeze all parameters that match the given pattern.
- Parameters:
- Returns:
The model.
- Return type:
Examples
Freeze all the parameters in the first layer of a MACE-MP0 model from mace_mp() (which have names of the form "model.interactions.0.<name>"):
```yaml
model:
  +freeze_any_matching:
    model: +mace_mp()
    pattern: model\.interactions\.0\..*
```
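The matching step can be sketched with the standard re module. This is an illustration rather than the graph-pes implementation: it assumes each pattern is applied to a parameter's full dotted name with re.match, and a dict of made-up parameter names mapped to requires_grad flags stands in for model.named_parameters().

```python
import re

# Made-up parameter names mapped to their requires_grad flags.
requires_grad = {
    "model.node_embedding.linear.weight": True,
    "model.interactions.0.linear_up.weight": True,
    "model.interactions.0.conv_tp_weights.layer0.weight": True,
    "model.interactions.1.linear_up.weight": True,
    "model.interactions.10.linear_up.weight": True,
    "model.readouts.0.linear.weight": True,
}

def freeze_matching_names(flags, pattern):
    # Hypothetical helper: set the flag to False for every matching name.
    for name in flags:
        if re.match(pattern, name):
            flags[name] = False
    return flags

freeze_matching_names(requires_grad, r"model\.interactions\.0\..*")
```

Because the dot after the 0 is escaped and required, model.interactions.10.linear_up.weight is not matched, so only the first interaction block is frozen.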
- graph_pes.models.freeze_all_except(model, pattern)
Freeze all parameters in a model except those matching a given pattern.
- Parameters:
- Returns:
The model.
- Return type:
Examples
Freeze all parameters in a MACE-MP0 model from mace_mp() except those in the read-out heads:
```yaml
model:
  +freeze_all_except:
    model: +mace_mp()
    pattern: model\.readouts.*
```
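freeze_all_except() inverts the selection: only names that the pattern matches stay trainable. A stdlib-only sketch under the same assumptions as before (re.match on full dotted parameter names, made-up names in place of model.named_parameters(), and a hypothetical helper name):

```python
import re

def freeze_all_except_names(names, pattern):
    # Hypothetical helper: a name stays trainable only if `pattern` matches it.
    return {name: re.match(pattern, name) is not None for name in names}

names = [
    "model.node_embedding.linear.weight",
    "model.interactions.0.linear_up.weight",
    "model.interactions.1.linear_up.weight",
    "model.readouts.0.linear.weight",
    "model.readouts.1.linear_1.weight",
]
requires_grad = freeze_all_except_names(names, r"model\.readouts.*")
```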
Available Models
Unit Conversion
- class graph_pes.models.UnitConverter(model, energy_to_eV, length_to_A)
Bases: GraphPESModel
A wrapper that converts the units of the energy, forces and stress predictions of an underlying model.
- Parameters:
model (GraphPESModel) – The underlying model.
energy_to_eV (float) – The conversion factor for energy, such that model.predict_energy(graph) * energy_to_eV gives the energy prediction in eV.
length_to_A (float) – The conversion factor for length, such that model.predict_forces(graph) * (energy_to_eV / length_to_A) gives the force prediction in eV/Å.
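The two factors combine by dimensional analysis. A minimal sketch for a model trained in kcal/mol and nm, using a hypothetical convert_predictions() helper on plain lists rather than UnitConverter itself. The stress factor energy_to_eV / length_to_A**3 is an assumption derived from the [energy / length^3] units listed for "stress" above, not something stated for UnitConverter.

```python
# Illustrative factors for a model trained in kcal/mol and nm.
KCAL_PER_MOL_TO_EV = 0.0433641  # 1 kcal/mol ≈ 0.0433641 eV
NM_TO_ANGSTROM = 10.0           # 1 nm = 10 Å

def convert_predictions(energy, forces, stress, energy_to_eV, length_to_A):
    # Hypothetical helper: energy scales by energy_to_eV, forces by
    # energy_to_eV / length_to_A, and stress (assumed) by
    # energy_to_eV / length_to_A**3.
    e = energy * energy_to_eV
    f = [[c * energy_to_eV / length_to_A for c in row] for row in forces]
    s = [[c * energy_to_eV / length_to_A**3 for c in row] for row in stress]
    return e, f, s

e, f, s = convert_predictions(
    energy=100.0,
    forces=[[1.0, 0.0, -2.0]],
    stress=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    energy_to_eV=KCAL_PER_MOL_TO_EV,
    length_to_A=NM_TO_ANGSTROM,
)
```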