Models

class graph_pes.GraphPESModel(cutoff, implemented_properties)[source]

Bases: Module, ABC

All models implemented in graph-pes are subclasses of GraphPESModel.

These models make predictions (via the predict() method) of the following properties:

Key                 Single graph    Batch of graphs
"local_energies"    (N,)            (N,)
"energy"            ()              (M,)
"forces"            (N, 3)          (N, 3)
"stress"            (3, 3)          (M, 3, 3)
"virial"            (3, 3)          (M, 3, 3)

assuming an input AtomicGraph representing either a single structure composed of N atoms, or a batch of M structures containing a total of N atoms (see is_batch() for more information about batching).

Implementations must override the forward() method to generate a dictionary of predictions for the given graph. As a minimum, this must include a per-atom energy contribution ("local_energies").

For any other properties not returned by the forward pass, the predict() method will automatically infer these properties from the local energies as required:

  • "energy": as the sum of the local energies per structure.

  • "forces": as the negative gradient of the energy with respect to the atomic positions.

  • "stress": as the negative gradient of the energy with respect to a symmetric expansion of the unit cell, normalised by the cell volume.

  • "virial": as -stress * volume.

For more details on how these are calculated, see Theory.
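
For illustration, here is a minimal sketch of a custom model (not part of the library) that implements only "local_energies" in its forward pass; the top-level import locations and the graph.Z attribute for atomic numbers are assumptions based on the rest of these docs:

import torch

from graph_pes import AtomicGraph, GraphPESModel


class ConstantLocalEnergies(GraphPESModel):
    """Toy model: every atom contributes the same learnable energy."""

    def __init__(self):
        super().__init__(cutoff=3.0, implemented_properties=["local_energies"])
        self.per_atom_energy = torch.nn.Parameter(torch.zeros(1))

    def forward(self, graph: AtomicGraph) -> dict[str, torch.Tensor]:
        n_atoms = graph.Z.shape[0]  # assumed: graph.Z holds the per-atom atomic numbers
        return {"local_energies": self.per_atom_energy.expand(n_atoms)}

predict() can then derive "energy" (and, since these local energies do not depend on positions, zero "forces") from this forward pass as described above.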

GraphPESModel objects save various pieces of extra metadata to the state_dict via the get_extra_state() and set_extra_state() methods. If you want to save additional extra state to the state_dict of your model, please implement the extra_state() property and corresponding setter to ensure that you do not overwrite these extra metadata items.
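
As a sketch, a hypothetical subclass could expose its own extra state like so (the _my_flag attribute is purely illustrative):

class MyModel(GraphPESModel):
    ...

    @property
    def extra_state(self) -> dict:
        # merged into the state_dict alongside graph-pes's own metadata
        return {"my_flag": self._my_flag}

    @extra_state.setter
    def extra_state(self, state: dict) -> None:
        self._my_flag = state["my_flag"]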

Parameters:
  • cutoff (float) – The cutoff radius for the model.

  • implemented_properties (list[PropertyKey]) – The property predictions that the model implements in the forward pass. Must include at least "local_energies".

abstract forward(graph)[source]

The model’s forward pass. Generate all properties for the given graph that are in this model’s implemented_properties list.

Parameters:

graph (AtomicGraph) – The graph representation of the structure/s.

Returns:

A dictionary mapping each implemented property to a tensor of predictions (see above for the expected shapes). Use is_batch() to check if the graph is batched in the forward pass.

Return type:

dict[PropertyKey, torch.Tensor]

predict(graph, properties)[source]

Generate (optionally batched) predictions for the given properties and graph.

This method returns a dictionary mapping each requested property to a tensor of predictions, relying on the model’s forward() implementation together with torch.autograd.grad() to automatically infer any missing properties.

Parameters:
  • graph (AtomicGraph) – The graph representation of the structure/s.

  • properties (list[Literal['local_energies', 'forces', 'energy', 'stress', 'virial']]) – The properties to predict. Can be any combination of "energy", "forces", "stress", "virial", and "local_energies".

Return type:

dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], torch.Tensor]
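
A usage sketch, assuming the built-in LennardJones model and the AtomicGraph.from_ase converter (the copper supercell and 5.0 Å cutoff are illustrative):

from ase.build import bulk

from graph_pes import AtomicGraph
from graph_pes.models import LennardJones

model = LennardJones()
graph = AtomicGraph.from_ase(bulk("Cu") * (2, 2, 2), cutoff=5.0)

predictions = model.predict(graph, properties=["energy", "forces"])
print(predictions["energy"].shape)  # ()      - one energy for the structure
print(predictions["forces"].shape)  # (N, 3)  - one force vector per atom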

pre_fit_all_components(graphs)[source]

Pre-fit the model, and all its components, to the training data.

Some models require pre-fitting to the training data to set certain parameters. For example, the LennardJones model uses the distribution of interatomic distances in the training data to set the length-scale parameter.

In the graph-pes-train routine, this method is called before “normal” training begins (you can turn this off with a config option).

This method does two things:

  1. iterates over all the model’s Module components (including itself) and calls their pre_fit() method if it exists (see, for instance, LearnableOffset for an example of a model-specific pre-fit method, and LocalEnergiesScaler for an example of a component-specific pre-fit method).

  2. registers all the unique atomic numbers in the training data with all of the model’s PerElementParameter instances to ensure correct parameter counting.

If the model has already been pre-fitted, subsequent calls to pre_fit_all_components() will be ignored (and a warning will be raised).

Parameters:

graphs (Sequence[AtomicGraph]) – The training data.
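
A sketch of calling this directly on a model, where training_structures is a hypothetical list of labelled ase.Atoms objects:

from graph_pes import AtomicGraph

graphs = [AtomicGraph.from_ase(s, cutoff=5.0) for s in training_structures]
model.pre_fit_all_components(graphs)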

pre_fit(graphs)[source]

Override this method to perform additional pre-fitting steps.

See LocalEnergiesScaler or EnergyOffset for examples of this.

Parameters:

graphs (AtomicGraph) – The training data.

non_decayable_parameters()[source]

Return a list of parameters that should not be decayed during training.

By default, this method recurses over all available sub-modules and calls their non_decayable_parameters() (if it is defined).

See LocalEnergiesScaler for an example of this.

Return type:

list[Parameter]
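
A sketch of extending this in a hypothetical subclass with a scale parameter that should not be decayed:

class ScaledModel(GraphPESModel):
    ...

    def non_decayable_parameters(self) -> list[torch.nn.Parameter]:
        # keep the default (recursive) behaviour and add our own parameter
        return super().non_decayable_parameters() + [self.scale]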

get_all_PES_predictions(graph)[source]

Get all the properties that the model can predict for the given graph.

Return type:

dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], torch.Tensor]

predict_energy(graph)[source]

Convenience method to predict just the energy.

Return type:

Tensor

predict_forces(graph)[source]

Convenience method to predict just the forces.

Return type:

Tensor

predict_stress(graph)[source]

Convenience method to predict just the stress.

Return type:

Tensor

predict_virial(graph)[source]

Convenience method to predict just the virial.

Return type:

Tensor

predict_local_energies(graph)[source]

Convenience method to predict just the local energies.

Return type:

Tensor
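
A usage sketch, reusing the model and graph objects from the predict() example above:

forces = model.predict_forces(graph)                # (N, 3)
energy = model.predict_energy(graph)                # ()
everything = model.get_all_PES_predictions(graph)   # all five properties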

property elements_seen: list[str]

The elements that the model has seen during training.

final get_extra_state()[source]

Get the extra state of this instance. Please override the extra_state() property to add extra state here.

Return type:

dict[str, Any]

final set_extra_state(state)[source]

Set the extra state of this instance from a dictionary mapping strings to values. Please override the extra_state() property setter to add extra state here.

property extra_state: Any

Override this property to add extra state to the model’s state_dict.

ase_calculator(device=None, skin=1.0)[source]

Return an ASE calculator wrapping this model. See GraphPESCalculator for more information.

Parameters:
  • device (torch.device | str | None) – The device to use for the calculator. If None, the device of the model will be used.

  • skin (float) – The skin to use for the neighbour list. If all atoms have moved less than half of this distance between calls to calculate, the neighbour list will be reused, saving (in some cases) significant computation time.

Return type:

GraphPESCalculator
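
A usage sketch with ASE (the water molecule is illustrative):

from ase.build import molecule

from graph_pes.models import LennardJones

model = LennardJones()
atoms = molecule("H2O")
atoms.calc = model.ase_calculator(device="cpu", skin=1.0)

print(atoms.get_potential_energy())
print(atoms.get_forces())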

Loading Models

graph_pes.models.load_model(path)[source]

Load a model from a file.

Parameters:

path (str | pathlib.Path) – The path to the file.

Returns:

The model.

Return type:

GraphPESModel

Examples

Use this function to load an existing model for further training using graph-pes-train:

model:
    +load_model:
        path: path/to/model.pt

To account for a new energy offset in your training data, you could do something like the following (see also load_model_component()):

model:
    # add an offset to an existing model before fine-tuning
    offset: +LearnableOffset()
    many-body:
        +load_model:
            path: path/to/model.pt
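
Equivalently, in Python (the path is illustrative):

from graph_pes.models import load_model

model = load_model("path/to/model.pt")
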
graph_pes.models.load_model_component(path, key)[source]

Load a component from an AdditionModel.

Parameters:
  • path (str | pathlib.Path) – The path to the file.

  • key (str) – The key of the component to load from the AdditionModel.

Returns:

The component.

Return type:

GraphPESModel

Examples

Train on data with a new energy offset:

model:
    offset: +LearnableOffset()
    many-body:
        +load_model_component:
            path: path/to/model.pt
            key: many-body
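
Or in Python, again with an illustrative path:

from graph_pes.models import load_model_component

many_body = load_model_component("path/to/model.pt", key="many-body")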

Freezing Models

class graph_pes.models.T

A TypeVar bound to torch.nn.Module: the type of the model passed to, and returned from, the freezing functions below.

graph_pes.models.freeze(model)[source]

Freeze all parameters in a module.

Parameters:

model (T) – The model to freeze.

Returns:

The model.

Return type:

T
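
A Python sketch, composing freeze() with load_model() (the path is illustrative):

from graph_pes.models import freeze, load_model

frozen = freeze(load_model("path/to/model.pt"))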

graph_pes.models.freeze_matching(model, pattern)[source]

Freeze all parameters that match the given pattern.

Parameters:
  • model (T) – The model to freeze.

  • pattern (str) – The regular expression to match the names of the parameters to freeze.

Returns:

The model.

Return type:

T

Examples

Freeze all the parameters in the first layer of a MACE-MP0 model from mace_mp() (which have names of the form "model.interactions.0.<name>"):

model:
    +freeze_matching:
        model: +mace_mp()
        pattern: model\.interactions\.0\..*
graph_pes.models.freeze_all_except(model, pattern)[source]

Freeze all parameters in a model except those matching a given pattern.

Parameters:
  • model (T) – The model to freeze.

  • pattern (str | list[str]) – The pattern/s to match.

Returns:

The model.

Return type:

T

Examples

Freeze all parameters in a MACE-MP0 model from mace_mp() except those in the read-out heads:

model:
    +freeze_all_except:
        model: +mace_mp()
        pattern: model\.readouts.*
graph_pes.models.freeze_any_matching(model, patterns)[source]

Freeze all parameters that match any of the given patterns.

Parameters:
  • model (T) – The model to freeze.

  • patterns (list[str]) – The patterns to match.

Returns:

The model.

Return type:

T
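
A Python sketch, applied to an existing model object (the parameter-name patterns are illustrative):

from graph_pes.models import freeze_any_matching

model = freeze_any_matching(model, patterns=[r".*embedding.*", r".*interactions\.0.*"])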

Available Models