The basics

Under the hood, the graph-pes-train command performs the following steps:

  1. loads in your data, model, loss function, etc. This happens before anything else so that if you run into errors, you can quickly identify the source of the problem.

  2. (optional) “pre-fits” the model on the training data. Under the hood, any torch.nn.Module components of your model that define a pre_fit method are passed the training data so that they can make any adjustments or calculations before training commences (see pre_fit_all_components() for details). This is useful for, e.g., estimating energy scales and offsets from the training data.

  3. trains the model using a PyTorch Lightning trainer.

  4. saves the best model for later use, as well as “deploying” it for use in LAMMPS.

  5. tests the best model on the training and validation data, together with any other test data you have specified.
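
The pre-fitting step (step 2) can be pictured with the toy sketch below. It is not the graph-pes implementation: the EnergyOffset and Scale classes, the pre_fit_components helper, and the (n_atoms, total_energy) data format are all hypothetical stand-ins, and plain Python objects are used in place of torch.nn.Module components. The only idea taken from the text above is that components defining a pre_fit method receive the training data before training starts.

```python
from statistics import mean


class EnergyOffset:
    """Hypothetical component that estimates a constant per-atom energy offset."""

    def __init__(self):
        self.offset = 0.0

    def pre_fit(self, training_data):
        # training_data is a toy stand-in: a list of (n_atoms, total_energy) pairs
        self.offset = mean(energy / n_atoms for n_atoms, energy in training_data)


class Scale:
    """Hypothetical component that has nothing to pre-fit (no pre_fit method)."""


def pre_fit_components(components, training_data):
    # Only components that define a pre_fit method are given the training data.
    for component in components:
        if hasattr(component, "pre_fit"):
            component.pre_fit(training_data)


offset = EnergyOffset()
pre_fit_components([offset, Scale()], [(2, -4.0), (4, -8.0)])
print(offset.offset)
```

Here the offset component ends up holding the mean per-atom energy of the toy data, while the Scale component is skipped because it defines no pre_fit method.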

To control the behaviour of these steps, you need to pass graph-pes-train a nested dictionary of configuration options. These options are sourced from three places:

  1. the default values defined in training-defaults.yaml

  2. values you define in the config file/s you pass to graph-pes-train: <config-1.yaml> <config-2.yaml> ...

  3. additional command line arguments you pass to graph-pes-train: <nested/key=value> <nested/key=value> ...
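
To make the <nested/key=value> form concrete, here is a minimal sketch of how such an argument could be turned into a nested dictionary. The parse_override function is hypothetical, and parsing values with ast.literal_eval is an assumption made to keep the example dependency-free; graph-pes-train may well interpret values differently.

```python
import ast


def parse_override(arg):
    """Turn 'a/b/c=value' into {'a': {'b': {'c': value}}} (illustrative only)."""
    key_path, _, raw_value = arg.partition("=")
    try:
        # Interpret numbers, booleans, lists, etc. where possible...
        value = ast.literal_eval(raw_value)
    except (ValueError, SyntaxError):
        # ...and fall back to a plain string otherwise.
        value = raw_value
    nested = value
    # Wrap the value in one dictionary per path segment, innermost first.
    for key in reversed(key_path.split("/")):
        nested = {key: nested}
    return nested


print(parse_override("CUTOFF=3.5"))
print(parse_override("a/b/c=10"))
```

A top-level key such as CUTOFF needs no slashes, while each slash in the key adds one level of nesting.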

The following .yaml config file contains the bare minimum information you need to specify in order to train a model:

# train a SchNet model...
model:
    +SchNet:
        layers: 3
        channels: 64
        cutoff: =/CUTOFF

# ...using some of the QM7 structures...
data:
    +load_atoms_dataset:
        id: QM7
        cutoff: =/CUTOFF
        n_train: 5_000
        n_valid: 100

# on energy labels...
loss: +PerAtomEnergyLoss()

# ...using a cutoff of 5.0 Å
#    (referenced above)
CUTOFF: 5.0

To use this config file, while overriding the default CUTOFF value, you would run:

graph-pes-train minimal.yaml CUTOFF=3.5

The order of arguments matters!

The final nested configuration dictionary used by graph-pes-train starts from the default values. The arguments you pass to graph-pes-train are then applied from left to right, each one overriding any values set before it.

Hence above, the CUTOFF value is set to 3.5 in the final configuration dictionary, overriding the value of 5.0 specified in minimal.yaml.

If you instead ran graph-pes-train CUTOFF=3.5 minimal.yaml, the CUTOFF value would be set to 3.5 in an intermediate step, but would ultimately be overridden by the 5.0 in minimal.yaml.
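
The left-to-right rule can be sketched as a sequence of dictionary merges. This is a sketch under assumptions, not the graph-pes code: deep_merge is a hypothetical helper, recursive merging of nested dictionaries is assumed, and the defaults dictionary is a placeholder rather than the real contents of training-defaults.yaml.

```python
from functools import reduce


def deep_merge(base, override):
    """Return a new dict in which values from `override` take precedence."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dictionaries: merge them key by key.
            merged[key] = deep_merge(merged[key], value)
        else:
            # Otherwise the later value simply replaces the earlier one.
            merged[key] = value
    return merged


defaults = {"placeholder": {"option": 1}}  # stand-in for the real defaults
minimal_yaml = {"CUTOFF": 5.0}             # the relevant part of minimal.yaml
cli_override = {"CUTOFF": 3.5}             # from the argument CUTOFF=3.5

# graph-pes-train minimal.yaml CUTOFF=3.5
file_then_cli = reduce(deep_merge, [minimal_yaml, cli_override], defaults)
# graph-pes-train CUTOFF=3.5 minimal.yaml
cli_then_file = reduce(deep_merge, [cli_override, minimal_yaml], defaults)

print(file_then_cli["CUTOFF"], cli_then_file["CUTOFF"])
```

In both cases the defaults survive untouched wherever no later argument mentions them; only the final CUTOFF value depends on the argument order.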

A note on syntax

You may have noticed two special syntaxes in the config file above: =/CUTOFF and +SchNet/+load_atoms_dataset.

Under the hood, graph-pes-train turns the final nested config dictionary into a TrainingConfig object via a three-step process:

  1. all reference strings (of the form =/absolute/path/to/object or =../relative/path/to/object) are replaced with the corresponding value. For example:

    a: {b: "=/c"}  # absolute reference
    c: 2
    d: {e: "=../c"}  # relative reference

    will be transformed into:

    a: {b: 2}
    c: 2
    d: {e: 2}

  2. all “special” keys starting with a + are interpreted as references to Python objects. These are imported and replaced with the actual object they point to using the data2objects package (see that package's documentation for more details). You can use this format to point to any Python object, defined in graph-pes or otherwise! To point to your own custom classes/objects/functions, use the fully-qualified name, e.g. +my_module.MyModelClass. By default, graph-pes-train also looks for objects within a set of graph-pes modules, including graph_pes.models, and hence +SchNet is shorthand for +graph_pes.models.SchNet.

    Ending the name with () will call the function/class constructor with no arguments. Pointing the key to a nested dictionary will pass those values as keyword arguments to the constructor. Hence, above, +PerAtomEnergyLoss() will resolve to a PerAtomEnergyLoss object, while the SchNet model will be constructed with the keyword arguments specified in the config.

  3. the resulting dictionary of Python objects is then converted, using dacite, into a final nested TrainingConfig object.
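
The first two steps can be imitated with a short, self-contained sketch. It is a toy re-implementation for illustration only, not the data2objects code: resolve_references, resolve_objects, lookup and import_object are hypothetical names, standard-library objects (fractions.Fraction, collections.Counter, math.sqrt) stand in for graph-pes classes, only fully-qualified names are supported (the default module search is not modelled), and circular references are not handled.

```python
import importlib


def lookup(root, path):
    node = root
    for key in path:
        node = node[key]
    return node


def resolve_references(node, root, location=()):
    """Replace '=/abs/path' and '=../rel/path' strings with their targets."""
    if isinstance(node, dict):
        return {k: resolve_references(v, root, location + (k,)) for k, v in node.items()}
    if isinstance(node, str) and node.startswith("="):
        target = node[1:]
        # Absolute paths start at the root; relative ones start at the
        # dictionary that contains the reference string.
        path = [] if target.startswith("/") else list(location[:-1])
        for part in target.split("/"):
            if part == "..":
                path.pop()
            elif part not in ("", "."):
                path.append(part)
        return resolve_references(lookup(root, path), root, tuple(path))
    return node


def import_object(name):
    module_name, _, attribute = name.rpartition(".")
    return getattr(importlib.import_module(module_name), attribute)


def resolve_objects(node):
    """Replace '+name' strings and single-key '+name' dicts with Python objects."""
    if isinstance(node, dict):
        if len(node) == 1:
            ((key, value),) = node.items()
            if isinstance(key, str) and key.startswith("+"):
                # A nested dictionary supplies keyword arguments.
                return import_object(key[1:])(**resolve_objects(value))
        return {k: resolve_objects(v) for k, v in node.items()}
    if isinstance(node, str) and node.startswith("+"):
        if node.endswith("()"):
            return import_object(node[1:-2])()  # call with no arguments
        return import_object(node[1:])          # the object itself
    return node


config = {
    "value": {"+fractions.Fraction": {"numerator": "=/N", "denominator": 4}},
    "counter": "+collections.Counter()",
    "function": "+math.sqrt",
    "N": 3,
}
resolved = resolve_objects(resolve_references(config, config))
print(resolved["value"], resolved["counter"], resolved["function"](9))
```

References are resolved before objects are built, so the =/N reference is already replaced by 3 by the time the Fraction constructor is called with its keyword arguments.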