Models

class graph_pes.GraphPESModel(cutoff, implemented_properties, three_body_cutoff=None)

Bases: Module, ABC

All models implemented in graph-pes are subclasses of GraphPESModel.

These models make predictions (via the predict() method) of the following properties:

Key                 Single graph    Batch of graphs    Units
"local_energies"    (N,)            (N,)               [energy]
"energy"            ()              (M,)               [energy]
"forces"            (N, 3)          (N, 3)             [energy / length]
"stress"            (3, 3)          (M, 3, 3)          [energy / length^3]
"virial"            (3, 3)          (M, 3, 3)          [energy]

assuming an input of an AtomicGraph representing a single structure composed of N atoms, or an AtomicGraph composed of M structures and containing a total of N atoms (see is_batch() for more information about batching).
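The "energy" and "local_energies" shapes in the table are related by a per-structure sum. The sketch below shows that bookkeeping with plain Python lists instead of tensors. The `structure_index` list is a hypothetical stand-in for however a batched AtomicGraph records which structure each atom belongs to; it is not graph-pes API.

```python
def sum_per_structure(local_energies, structure_index, n_structures):
    """Sum per-atom energies (length N) into per-structure energies (length M)."""
    energies = [0.0] * n_structures
    for e, s in zip(local_energies, structure_index):
        energies[s] += e
    return energies


# A batch of M=2 structures: atoms 0-2 belong to structure 0, atoms 3-4 to structure 1.
local_energies = [-1.0, -1.5, -0.5, -2.0, -2.5]  # shape (N,) with N=5
structure_index = [0, 0, 0, 1, 1]  # hypothetical per-atom structure labels
energy = sum_per_structure(local_energies, structure_index, n_structures=2)
print(energy)  # [-3.0, -4.5], i.e. shape (M,)
```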

Note that graph-pes makes no assumptions as to the actual units of the energy and length quantities; these will depend on the labels the model has been trained on (e.g. eV and Å, kcal/mol and nm, or even J and m).

Implementations must override the forward() method to generate a dictionary of predictions for the given graph. At a minimum, this must include a per-atom energy contribution ("local_energies").

For any other properties not returned by the forward pass, the predict() method will automatically infer these properties from the local energies as required:

  • "energy": as the sum of the local energies per structure.

  • "forces": as the negative gradient of the energy with respect to the atomic positions.

  • "stress": as the gradient of the energy with respect to a symmetric expansion (strain) of the unit cell, normalised by the cell volume. In keeping with convention, a negative stress indicates that the system is under static compression (it wants to expand).

  • "virial": as -stress * volume. A negative virial indicates the system is under static tension (wants to contract).

For more details on how these are calculated, see Theory.
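To make the inference rules above concrete, the sketch below applies them by hand to a toy system. None of it is graph-pes code. It assumes a hypothetical harmonic pair potential with parameters K and R0, no periodic images, and plain Python lists in place of AtomicGraph tensors, and it uses central finite differences in place of torch.autograd.grad(). Stress is taken as (1/V) times the derivative of the energy with respect to a symmetric strain, which is the sign choice under which a compressed system has a negative stress, as stated above.

```python
import itertools

K, R0 = 1.0, 1.0  # hypothetical harmonic pair-potential parameters


def local_energies(positions):
    """Split each pair energy 0.5 * K * (r - R0)**2 equally between the two atoms."""
    out = [0.0] * len(positions)
    for i, j in itertools.combinations(range(len(positions)), 2):
        r = sum((a - b) ** 2 for a, b in zip(positions[i], positions[j])) ** 0.5
        e = 0.5 * K * (r - R0) ** 2
        out[i] += e / 2
        out[j] += e / 2
    return out


def energy(positions):
    # "energy": the sum of the local energies of a single structure
    return sum(local_energies(positions))


def forces(positions, h=1e-6):
    # "forces": F = -dE/dR, here by central finite differences
    result = []
    for i in range(len(positions)):
        row = []
        for d in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][d] += h
            minus[i][d] -= h
            row.append(-(energy(plus) - energy(minus)) / (2 * h))
        result.append(row)
    return result


def strained(positions, eps):
    # Apply the deformation (I + eps) to every position vector
    return [
        [sum(((1.0 if a == b else 0.0) + eps[a][b]) * p[b] for b in range(3)) for a in range(3)]
        for p in positions
    ]


def stress(positions, volume, h=1e-6):
    # "stress": sigma_ab = (1/V) dE/d(eps_ab) for a symmetric strain eps
    sigma = [[0.0] * 3 for _ in range(3)]
    for a in range(3):
        for b in range(3):
            eps = [[0.0] * 3 for _ in range(3)]
            eps[a][b] += h / 2  # symmetric perturbation: the diagonal
            eps[b][a] += h / 2  # receives h, off-diagonals h/2 each
            minus = [[-x for x in row] for row in eps]
            dE = energy(strained(positions, eps)) - energy(strained(positions, minus))
            sigma[a][b] = dE / (2 * h) / volume
    return sigma


def virial(positions, volume):
    # "virial": -stress * volume
    return [[-s * volume for s in row] for row in stress(positions, volume)]


# A compressed dimer (r = 0.9 < R0) in a hypothetical cell of volume 10.
pos = [[0.0, 0.0, 0.0], [0.9, 0.0, 0.0]]
F = forces(pos)
sigma = stress(pos, volume=10.0)
print(round(F[1][0], 6), round(sigma[0][0], 6))  # 0.1 -0.009
```

The compressed dimer gives forces that push the two atoms apart and a negative xx stress (with a correspondingly positive xx virial), matching the sign conventions described above.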

GraphPESModel objects save various pieces of extra metadata to the state_dict via the get_extra_state() and set_extra_state() methods. If you want to save additional extra state to the state_dict of your model, please implement the extra_state property and its corresponding setter so that you do not overwrite these metadata items.

Parameters:
  • cutoff (float) – The cutoff radius for the model.

  • implemented_properties (list[PropertyKey]) – The property predictions that the model implements in the forward pass. Must include at least "local_energies".

  • three_body_cutoff (float | None) – The cutoff radius for this model’s three-body interactions, if applicable.

abstract forward(graph)

The model’s forward pass. Generate all properties for the given graph that are in this model’s implemented_properties list.

Parameters:

graph (AtomicGraph) – The graph representation of the structure/s.

Returns:

A dictionary mapping each implemented property to a tensor of predictions (see above for the expected shapes). Use is_batch() to check if the graph is batched in the forward pass.

Return type:

dict[PropertyKey, torch.Tensor]

predict(graph, properties)

Generate (optionally batched) predictions for the given properties and graph.

This method returns a dictionary mapping each requested property to a tensor of predictions, relying on the model’s forward() implementation together with torch.autograd.grad() to automatically infer any missing properties.

Parameters:
  • graph (AtomicGraph) – The graph representation of the structure/s.

  • properties (list[Literal['local_energies', 'forces', 'energy', 'stress', 'virial']]) – The properties to predict. Can be any combination of "energy", "forces", "stress", "virial", and "local_energies".

Return type:

dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], torch.Tensor]

pre_fit_all_components(graphs)

Pre-fit the model, and all its components, to the training data.

Some models require pre-fitting to the training data to set certain parameters. For example, the LennardJones model uses the distribution of interatomic distances in the training data to set the length-scale parameter.

In the graph-pes-train routine, this method is called before “normal” training begins (you can turn this off with a config option).

This method does two things:

  1. iterates over all the model’s Module components (including itself) and calls their pre_fit() method, if it exists (see LearnableOffset for an example of a model-specific pre-fit method, and LocalEnergiesScaler for an example of a component-specific pre-fit method).

  2. registers all the unique atomic numbers in the training data with all of the model’s PerElementParameter instances to ensure correct parameter counting.

If the model has already been pre-fitted, subsequent calls to pre_fit_all_components() will be ignored (and a warning will be raised).

Parameters:

graphs (Sequence[AtomicGraph]) – The training data.
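The two steps above can be sketched with stand-in classes. This is a hypothetical illustration, not graph-pes code: `Node` mimics the recursive walk of torch.nn.Module.modules(), structures are plain lists of atomic numbers, `ElementParams` stands in for a per-element parameter and just records the elements it must cover, and `Scaler` is a component with its own pre_fit() hook.

```python
import warnings


class Node:
    """Hypothetical stand-in for torch.nn.Module: a tree of child components."""

    def __init__(self, *children):
        self.children = list(children)

    def modules(self):
        # Depth-first walk over self and all descendants, like nn.Module.modules()
        yield self
        for child in self.children:
            yield from child.modules()


class ElementParams(Node):
    """Stand-in for a per-element parameter: records which elements it must cover."""

    def __init__(self):
        super().__init__()
        self.elements = set()


class Scaler(Node):
    """Stand-in component that defines its own pre_fit() hook."""

    def pre_fit(self, graphs):
        self.n_structures_seen = len(graphs)


class Model(Node):
    def __init__(self, *children):
        super().__init__(*children)
        self.pre_fitted = False


def pre_fit_all_components(model, graphs):
    """graphs: a list of structures, each given as a list of atomic numbers."""
    if model.pre_fitted:
        warnings.warn("model has already been pre-fitted: ignoring this call")
        return
    # Step 1: call pre_fit() on every component (including the model) that has one
    for module in model.modules():
        if hasattr(module, "pre_fit"):
            module.pre_fit(graphs)
    # Step 2: register all unique atomic numbers with every per-element parameter
    zs = {z for graph in graphs for z in graph}
    for module in model.modules():
        if isinstance(module, ElementParams):
            module.elements |= zs
    model.pre_fitted = True


params, scaler = ElementParams(), Scaler()
model = Model(scaler, Node(params))
data = [[1, 1, 8], [6, 1, 1, 1, 1]]  # H2O and CH4 as atomic numbers
pre_fit_all_components(model, data)
print(sorted(params.elements), scaler.n_structures_seen)  # [1, 6, 8] 2
```

Registering elements in a separate pass means that even a deeply nested per-element parameter (here inside an anonymous `Node`) learns about every element in the training data.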

pre_fit(graphs)

Override this method to perform additional pre-fitting steps.

See LocalEnergiesScaler or EnergyOffset for examples of this.

Parameters:

graphs (AtomicGraph) – The training data.

non_decayable_parameters()

Return a list of parameters that should not be decayed during training.

By default, this method recurses over all available sub-modules and calls their non_decayable_parameters() (if it is defined).

See LocalEnergiesScaler for an example of this.

Return type:

list[Parameter]

get_all_PES_predictions(graph)

Get all the properties that the model can predict for the given graph.

Return type:

dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], torch.Tensor]

predict_energy(graph)

Convenience method to predict just the energy.

Return type:

Tensor

predict_forces(graph)

Convenience method to predict just the forces.

Return type:

Tensor

predict_stress(graph)

Convenience method to predict just the stress.

Return type:

Tensor

predict_virial(graph)

Convenience method to predict just the virial.

Return type:

Tensor

predict_local_energies(graph)

Convenience method to predict just the local energies.

Return type:

Tensor

property elements_seen: list[str]

The elements that the model has seen during training.

final get_extra_state()

Get the extra state of this instance. Please override the extra_state() property to add extra state here.

Return type:

dict[str, Any]

final set_extra_state(state)

Set the extra state of this instance from a dictionary mapping strings to values, as returned by get_extra_state(). Please override the extra_state property setter to add extra state here.

property extra_state: Any

Override this property to add extra state to the model’s state_dict.
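The reason for overriding the extra_state property rather than the two final methods can be illustrated with a stand-in class. In PyTorch, get_extra_state() and set_extra_state() are the hooks that state_dict() and load_state_dict() call. The sketch below is a hypothetical illustration of the pattern, not graph-pes's implementation: the "metadata"/"extra" keys, the `cutoff` field and `MyModel` are invented for the example.

```python
class BaseModel:
    """Hypothetical stand-in for the base class's extra-state handling."""

    def __init__(self, cutoff):
        self.cutoff = cutoff  # stands in for metadata the base class saves itself

    # Subclasses override this property (and its setter), never the methods below.
    @property
    def extra_state(self):
        return None

    @extra_state.setter
    def extra_state(self, state):
        pass

    def get_extra_state(self):
        # Bundle the base-class metadata with whatever the subclass wants to save
        return {"metadata": {"cutoff": self.cutoff}, "extra": self.extra_state}

    def set_extra_state(self, state):
        self.cutoff = state["metadata"]["cutoff"]
        self.extra_state = state["extra"]


class MyModel(BaseModel):
    def __init__(self, cutoff):
        super().__init__(cutoff)
        self.training_notes = ""

    @property
    def extra_state(self):
        return {"notes": self.training_notes}

    @extra_state.setter
    def extra_state(self, state):
        self.training_notes = state["notes"]


a = MyModel(cutoff=5.0)
a.training_notes = "fine-tuned on water"
saved = a.get_extra_state()
b = MyModel(cutoff=3.0)
b.set_extra_state(saved)
print(b.cutoff, b.training_notes)  # 5.0 fine-tuned on water
```

Because the subclass only supplies the "extra" part, the base-class metadata survives the round trip untouched.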

ase_calculator(device=None, skin=1.0)

Return an ASE calculator wrapping this model. See GraphPESCalculator for more information.

Parameters:
  • device (torch.device | str | None) – The device to use for the calculator. If None, the device of the model will be used.

  • skin (float) – The skin to use for the neighbour list. If all atoms have moved less than half of this distance between calls to calculate, the neighbour list will be reused, saving (in some cases) significant computation time.

Return type:

GraphPESCalculator
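The reuse rule described for the skin parameter can be sketched as below. This illustrates the stated rule under the assumption that the neighbour list is built with radius cutoff + skin; GraphPESCalculator's actual check may differ in detail (for example, in how it treats changes to the unit cell), and the function name is hypothetical.

```python
import math


def can_reuse_neighbour_list(old_positions, new_positions, skin):
    """Reuse is safe if every atom has moved less than skin / 2.

    A list built with radius (cutoff + skin) then cannot miss a pair: two atoms
    that each moved less than skin / 2 have closed in by less than skin.
    """
    max_move = max(math.dist(p, q) for p, q in zip(old_positions, new_positions))
    return max_move < skin / 2


old = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
print(can_reuse_neighbour_list(old, [[0.3, 0.0, 0.0], [1.0, 0.0, 0.0]], skin=1.0))  # True
print(can_reuse_neighbour_list(old, [[0.6, 0.0, 0.0], [1.0, 0.0, 0.0]], skin=1.0))  # False
```

A larger skin lets the list be reused for more steps, at the cost of a longer list (more candidate pairs) each time it is rebuilt.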

torch_sim_model(device=None, dtype=torch.float64, *, compute_forces=True, compute_stress=True)

Return a model suitable for use with the torch_sim package.

Internally, we set this model to evaluation mode, and wrap it in a class that is suitable for use with the torch_sim package.

Parameters:
  • device (torch.device | None) – The device to use for the model. If None, the model will be placed on the best device available.

  • dtype (torch.dtype) – The dtype to use for the model.

  • compute_forces (bool) – Whether to compute forces. Set this to False if you only need to generate energies within the torch_sim integrator.

  • compute_stress (bool) – Whether to compute stress. Set this to False if you don’t need stress information from the model within the torch_sim integrator.

Loading Models

graph_pes.models.load_model(path)

Load a model from a file.

Parameters:

path (str | pathlib.Path) – The path to the file.

Returns:

The model.

Return type:

GraphPESModel

Examples

Use this function to load an existing model for further training using graph-pes-train:

model:
    +load_model:
        path: path/to/model.pt

See fine-tuning for more details.

graph_pes.models.load_model_component(path, key)

Load a component from an AdditionModel.

Parameters:
  • path (str | pathlib.Path) – The path to the file.

  • key (str) – The key of the component to load.

Returns:

The component.

Return type:

GraphPESModel

Freezing Models

class graph_pes.models.T

Type alias for TypeVar("T", bound=torch.nn.Module).

graph_pes.models.freeze(model)

Freeze all parameters in a module.

Parameters:

model (T) – The model to freeze.

Returns:

The model.

Return type:

T

graph_pes.models.freeze_matching(model, pattern)

Freeze all parameters that match the given pattern.

Parameters:
  • model (T) – The model to freeze.

  • pattern (str) – The regular expression to match the names of the parameters to freeze.

Returns:

The model.

Return type:

T

Examples

Freeze all the parameters in the first layer of a MACE-MP0 model from mace_mp() (which have names of the form "model.interactions.0.<name>"):

model:
    +freeze_matching:
        model: +mace_mp()
        pattern: model\.interactions\.0\..*

graph_pes.models.freeze_all_except(model, pattern)

Freeze all parameters in a model except those matching a given pattern.

Parameters:
  • model (T) – The model to freeze.

  • pattern (str | list[str]) – The pattern/s to match.

Returns:

The model.

Return type:

T

Examples

Freeze all parameters in a MACE-MP0 model from mace_mp() except those in the read-out heads:

model:
    +freeze_all_except:
        model: +mace_mp()
        pattern: model\.readouts.*

graph_pes.models.freeze_any_matching(model, patterns)

Freeze all parameters that match any of the given patterns.

Parameters:
  • model (T) – The model to freeze.

  • patterns (list[str]) – The patterns to match.

Returns:

The model.

Return type:

T
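The selection logic behind freeze_matching(), freeze_any_matching() and freeze_all_except() can be sketched on plain parameter names. Assumptions: patterns are applied with re.match (anchored at the start of the name; the library may use a different matching function, although patterns ending in ".*" behave the same either way), and the parameter names below are hypothetical rather than taken from a real MACE-MP0 checkpoint.

```python
import re


def select_frozen(names, patterns, mode):
    """Return the parameter names that would be frozen.

    mode="matching":   freeze names matching any pattern
                       (freeze_matching / freeze_any_matching)
    mode="all_except": freeze names matching none of the patterns
                       (freeze_all_except)
    """
    if isinstance(patterns, str):
        patterns = [patterns]

    def matches(name):
        # re.match anchors at the start of the name (an assumption, see above)
        return any(re.match(p, name) for p in patterns)

    if mode == "matching":
        return [n for n in names if matches(n)]
    if mode == "all_except":
        return [n for n in names if not matches(n)]
    raise ValueError(f"unknown mode: {mode!r}")


names = [  # hypothetical parameter names
    "model.interactions.0.linear.weight",
    "model.interactions.1.linear.weight",
    "model.readouts.0.linear.weight",
    "model.node_embedding.weight",
]
print(select_frozen(names, r"model\.interactions\.0\..*", "matching"))
print(select_frozen(names, r"model\.readouts.*", "all_except"))
```

In PyTorch, freezing a parameter typically means calling requires_grad_(False) on the tensors from model.named_parameters() whose names are selected in this way.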

Available Models

Unit Conversion

class graph_pes.models.UnitConverter(model, energy_to_eV, length_to_A)

Bases: GraphPESModel

A wrapper that converts the units of the energy, forces and stress predictions of an underlying model.

Parameters:
  • model (GraphPESModel) – The underlying model.

  • energy_to_eV (float) – The conversion factor for energy, such that model.predict_energy(graph) * energy_to_eV gives the energy prediction in eV.

  • length_to_A (float) – The conversion factor for length, such that the model.predict_forces(graph) * (energy_to_eV / length_to_A) gives the force prediction in eV/Å.
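As a worked example of the two conversion factors, the sketch below converts hypothetical predictions from a model trained in kcal/mol and nm (1 kcal/mol ≈ 0.0433641 eV, 1 nm = 10 Å). The stress factor energy_to_eV / length_to_A**3 is not stated for UnitConverter above; it follows from stress having units of [energy / length^3]. The function name is illustrative only.

```python
KCAL_PER_MOL_IN_EV = 0.0433641  # 1 kcal/mol ≈ 0.0433641 eV
NM_IN_A = 10.0  # 1 nm = 10 Å


def to_ev_units(energy, force, stress, energy_to_eV, length_to_A):
    """Convert (energy, force, stress) from model units to eV, eV/Å and eV/Å^3."""
    return (
        energy * energy_to_eV,
        force * energy_to_eV / length_to_A,
        stress * energy_to_eV / length_to_A**3,
    )


# Hypothetical predictions from a model trained in kcal/mol and nm
e, f, s = to_ev_units(10.0, 10.0, 1000.0, KCAL_PER_MOL_IN_EV, NM_IN_A)
print(e, f, s)
```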