Config options

graph-pes-train is configured using a nested dictionary of options. The top-level keys that we look for are: model, data, loss, fitting, general and wandb.

You are free to add any additional top-level keys to your config files for your own purposes. This can be useful for easily referencing constants or repeated values using the = reference syntax.

# define a constant...
CUTOFF: 10.0

# ... and reference it later
model:
    +SchNet:
        cutoff: =/CUTOFF
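
To make the lookup concrete, here is a minimal Python sketch of how an absolute reference such as =/CUTOFF could be resolved against a parsed config. The resolve_references helper is hypothetical and written only for illustration: it is not how data2objects implements the feature, and it handles only absolute =/ paths.

```python
from typing import Any


def resolve_references(node: Any, root: Any) -> Any:
    """Replace every "=/a/b" string with the value stored at root["a"]["b"]."""
    if isinstance(node, dict):
        return {key: resolve_references(value, root) for key, value in node.items()}
    if isinstance(node, list):
        return [resolve_references(value, root) for value in node]
    if isinstance(node, str) and node.startswith("=/"):
        value = root
        for key in node[2:].split("/"):
            value = value[key]  # walk down from the top of the config
        return resolve_references(value, root)
    return node


# the parsed form of the YAML above, as a plain nested dictionary
config = {"CUTOFF": 10.0, "model": {"+SchNet": {"cutoff": "=/CUTOFF"}}}
resolved = resolve_references(config, config)
print(resolved["model"]["+SchNet"]["cutoff"])  # 10.0
```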

You will also notice the + syntax used throughout. Under the hood, we use the data2objects library to parse these config files, and this syntax is used to automatically instantiate objects.

You can use this syntax to reference arbitrary python functions, classes and objects:

# call your own functions/class constructors
# with the + syntax and keyword arguments
key:
    +my_module.my_function:
        foo: 1
        bar: 2

# syntactic sugar for calling a function
# with no arguments
key: +torch.nn.ReLU()

# reference arbitrary objects
# (note the lack of any keyword arguments or parentheses)
key: +my_module.my_object

By default, we will look for any objects in the graph_pes namespace, and hence +SchNet is shorthand for graph_pes.models.SchNet etc.
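
The three forms of the + syntax can be mimicked in a few lines of plain Python. This is only a sketch of the idea: the build helper is hypothetical, it is not the data2objects implementation, and it does not perform the default graph_pes namespace lookup (so every name must be fully qualified). Standard-library objects stand in for my_module.

```python
import importlib
from typing import Any, Optional


def build(spec: str, kwargs: Optional[dict] = None) -> Any:
    """Hypothetical helper: import the dotted path after "+" and optionally call it."""
    if not spec.startswith("+"):
        raise ValueError(f"expected a '+' prefix, got {spec!r}")
    path = spec[1:]
    call_without_args = path.endswith("()")
    if call_without_args:
        path = path[:-2]
    module_name, _, attribute = path.rpartition(".")
    target = getattr(importlib.import_module(module_name), attribute)
    if kwargs is not None:
        return target(**kwargs)  # "+name:" followed by keyword arguments
    if call_without_args:
        return target()  # "+name()" sugar for a call with no arguments
    return target  # "+name" references the object itself


quarter = build("+fractions.Fraction", {"numerator": 1, "denominator": 4})
empty = build("+collections.OrderedDict()")
pi = build("+math.pi")
```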

model

To specify the model to train, you need to point to something that instantiates a GraphPESModel:

# point to the in-built Lennard-Jones model
model:
    +LennardJones:
        sigma: 0.1
        epsilon: 1.0

# or point to a custom model
model: +my_model.SpecialModel()

…or pass a dictionary mapping custom names to GraphPESModel objects:

model:
    offset:
        +FixedOffset: { H: -123.4, C: -456.7 }
    many-body: +SchNet()

The latter approach will be used to instantiate an AdditionModel, in this case with FixedOffset and SchNet components. This is a useful approach for dealing with arbitrary offset energies.
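
As a rough illustration of what the summed model computes, the following Python sketch adds a per-element offset to a placeholder many-body term. It assumes that AdditionModel sums the energy predictions of its components and that FixedOffset contributes a constant energy per atom of each element. The functions below are hypothetical stand-ins, not the graph_pes implementations.

```python
from collections import Counter

# per-element offsets, as in the config above
OFFSETS = {"H": -123.4, "C": -456.7}


def fixed_offset(symbols):
    """Sum a constant energy for every atom, looked up by element."""
    counts = Counter(symbols)
    return sum(OFFSETS[element] * n for element, n in counts.items())


def many_body(symbols):
    """Placeholder for a learned model such as SchNet (not a real prediction)."""
    return -0.5 * len(symbols)


def addition_model(symbols):
    """Total energy as the sum of the component predictions."""
    return fixed_offset(symbols) + many_body(symbols)


methane = ["C", "H", "H", "H", "H"]
offset_energy = fixed_offset(methane)
total_energy = addition_model(methane)
```

With most of the total energy absorbed by the offset term, the many-body component only has to learn the comparatively small remainder.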

You can fine-tune an existing model by pointing graph-pes-train to its saved file:

model:
    +load_model:
        path: path/to/model.pt

You could also load in parts of a model if e.g. you are fine-tuning on a different level of theory with different offsets:

model:
    offset: +LearnableOffset()
    force_field:
        +load_model_component:
            path: path/to/model.pt

See the fine-tuning guide, load_model(), and load_model_component() for more details.

data

To specify the data you wish to use, you need to point to a dictionary that maps the keys "train" and "valid" to GraphDataset instances. A common way to do this is by using the file_dataset() function:

data:
    train:
        +file_dataset:
            path: data/train.xyz
            cutoff: 5.0
            n: 1000
            shuffle: true
            seed: 42
    valid:
        +file_dataset:
            path: data/valid.xyz
            cutoff: 5.0

Alternatively, you can point to a function that returns such a dictionary:

data:
    +my_module.my_fitting_data:
        cutoff: 5.0

This is what the load_atoms_dataset() function does:

data:
    +load_atoms_dataset:
        id: QM9
        cutoff: 5.0
        n_train: 10000
        n_val: 1000
        property_map:
            energy: U0

After training is finished, the graph-pes-train command will load the best model weights and re-test the model on the training and validation data. You can also test on other datasets at this point by including a "test" key in your config file. This should either point to:

  • a GraphDataset instance (in which case testing metrics will be logged to "best_model/test/<metric_name>")

    data:
        train: ...
        valid: ...
        test:
            +file_dataset:
                path: test_data.xyz
                cutoff: 5.0
    
  • a dictionary mapping custom names to GraphDataset instances (in which case testing metrics will be logged to "best_model/<custom_name>/<metric_name>")

    data:
        train: ...
        valid: ...
        test:
            dimers:
                +file_dataset:
                    path: data/dimers.xyz
                    cutoff: 5.0
            clusters:
                +file_dataset:
                    path: data/clusters.xyz
                    cutoff: 5.0
    

loss

This config section should either point to something that instantiates a single graph_pes.training.loss.Loss object…

# basic per-atom energy loss
loss: +PerAtomEnergyLoss()

# or more fine-grained control
loss:
    +PropertyLoss:
        property: stress
        metric: MAE  # defaults to RMSE if not specified

…or specify a list of Loss instances…

loss:
    # specify a loss with several components:
    - +PerAtomEnergyLoss()  # defaults to weight 1.0
    - +PropertyLoss:
        property: forces
        metric: MSE
        weight: 10.0

…or point to your own custom loss implementation, either in isolation:

loss:
    +my.module.CustomLoss: { alpha: 0.5 }

…or in conjunction with other components:

loss:
    - +PerAtomEnergyLoss()
    - +my.module.CustomLoss: { alpha: 0.5 }
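
The effect of the weight field can be sketched in Python as follows. This assumes that the total loss is the weighted sum of the individual components, which the text above does not state explicitly. The component names and values are made up for illustration.

```python
def total_loss(components):
    """Weighted sum over (weight, value) pairs, keyed by component name."""
    return sum(weight * value for weight, value in components.values())


# hypothetical per-batch values for the two components in the list example
components = {
    "per_atom_energy": (1.0, 0.5),  # default weight of 1.0
    "forces_mse": (10.0, 0.25),  # weight: 10.0
}
total = total_loss(components)
print(total)  # 3.0
```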

If you want to sweep over a loss component weight via the command line, you can use a dictionary mapping arbitrary strings to loss instances like so:

loss:
    energy: +PerAtomEnergyLoss()
    forces:
        +ForceRMSE:
            weight: 5.0

allowing you to run a command such as:

for weight in 0.1 0.5 1.0; do
    graph-pes-train config.yaml loss/forces/+ForceRMSE/weight=$weight
done
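
The override path mirrors the nesting of the config, with / separating the keys. A minimal Python sketch of applying such an override to a parsed config is shown below. The apply_override helper is hypothetical, handles only numeric values, and is not the parser used by graph-pes-train.

```python
def apply_override(config: dict, override: str) -> None:
    """Set the value at a "/"-separated key path, e.g. "a/b/c=1.0"."""
    path, _, raw_value = override.partition("=")
    *parents, last = path.split("/")
    node = config
    for key in parents:
        node = node[key]  # descend one level of nesting per path segment
    node[last] = float(raw_value)  # this sketch assumes a numeric value


config = {
    "loss": {
        "energy": "+PerAtomEnergyLoss()",
        "forces": {"+ForceRMSE": {"weight": 5.0}},
    }
}
apply_override(config, "loss/forces/+ForceRMSE/weight=0.1")
print(config["loss"]["forces"]["+ForceRMSE"]["weight"])  # 0.1
```

A list of loss components has no stable key for each entry, which is why the dictionary form is the convenient one for sweeps.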

fitting

The fitting section of the config is used to specify various hyperparameters and behaviours of the training process.

Optimizer

Configure the optimizer used to train the model by pointing to something that instantiates an Optimizer.

The default is:

fitting:
    optimizer:
        +Optimizer:
            name: Adam
            lr: 3e-3
            weight_decay: 0.0
            amsgrad: false

but you could also point to your own custom optimizer:

fitting:
    optimizer: +my.module.MagicOptimizer()

Learning rate scheduler

Configure the learning rate scheduler used to train the model by pointing to something that instantiates an LRScheduler.

For instance:

fitting:
    scheduler:
        +LRScheduler:
            name: ReduceLROnPlateau
            factor: 0.5
            patience: 10

No learning rate scheduler is used if you don’t specify one (the default), or if you explicitly specify null:

fitting:
    scheduler: null

Model pre-fitting

To turn off pre-fitting of the model, override the pre_fit_model field (default is true):

fitting:
    pre_fit_model: false

To set the maximum number of graphs to use for pre-fitting, override the max_n_pre_fit field (default is 5_000). These graphs will be randomly sampled from the training data. To use all the training data, set this to null:

fitting:
    max_n_pre_fit: 1000
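
The sampling rule described above can be sketched in Python as follows. The pre_fit_subset helper and its seed argument are hypothetical: how graph-pes draws the random sample is not specified here.

```python
import random


def pre_fit_subset(training_data, max_n_pre_fit, seed=42):
    """Return at most max_n_pre_fit randomly chosen items (None means use all)."""
    if max_n_pre_fit is None or len(training_data) <= max_n_pre_fit:
        return list(training_data)
    return random.Random(seed).sample(training_data, max_n_pre_fit)


training_data = list(range(5000))  # integers stand in for training graphs
subset = pre_fit_subset(training_data, max_n_pre_fit=1000)
everything = pre_fit_subset(training_data, max_n_pre_fit=None)
```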

Early stopping

Turn on early stopping by setting the early_stopping_patience field to an integer value (by default it is null, indicating that early stopping is disabled). This will stop training when the total validation loss ("valid/loss/total") has not improved for early_stopping_patience validation checks.

fitting:
    early_stopping_patience: 10
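
The patience rule can be illustrated with a short Python sketch. It assumes that lower losses are better and that any strict decrease counts as an improvement. The stopping_check helper is hypothetical and only mirrors the behaviour described above, not the actual callback code.

```python
def stopping_check(losses, patience):
    """Return the index of the validation check that triggers a stop, or None."""
    best = float("inf")
    checks_without_improvement = 0
    for index, loss in enumerate(losses):
        if loss < best:
            best = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return index
    return None


# hypothetical "valid/loss/total" values from successive validation checks
losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]
stopped_at = stopping_check(losses, patience=3)
never_stopped = stopping_check(losses, patience=10)
print(stopped_at)  # 4
```

Note that with a patience of 3, training stops before reaching the better value of 0.7 at the final check, which is the usual trade-off when choosing a small patience.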

To have more fine-grained control over early stopping, set this field to null and use the callbacks field to add an EarlyStopping Lightning callback:

fitting:
    early_stopping_patience: null
    callbacks:
        - +pytorch_lightning.callbacks.early_stopping.EarlyStopping:
            monitor: valid/loss/forces_rmse
            patience: 100
            min_delta: 0.01
            mode: min

Data loaders

Data loaders are responsible for sampling batches of data from the dataset. We use GraphDataLoader instances to do this. These inherit from the PyTorch DataLoader class, and hence you can pass any keyword arguments to the underlying loader by setting the loader_kwargs field:

fitting:
    loader_kwargs:
        seed: 42
        batch_size: 32
        persistent_workers: true
        num_workers: 4

See the PyTorch documentation for details.

We recommend using several persistent workers, since loading data can be a bottleneck, either due to expensive read operations from disk, or due to the time taken to convert the underlying data into AtomicGraph objects (calculating neighbour lists etc.).

Caution: setting the shuffle field here will have no effect: we always shuffle the training data, and keep the validation and testing data in order.

Stochastic weight averaging

Configure stochastic weight averaging (SWA) by specifying fields from the SWAConfig class, e.g.:

fitting:
    swa:
        lr: 1e-3
        start: 0.8
        anneal_epochs: 10

class graph_pes.config.training.SWAConfig

Configuration for Stochastic Weight Averaging.

Internally, this is handled by the PyTorch Lightning StochasticWeightAveraging callback.

lr: float

The learning rate to use during the SWA phase. If not specified, the learning rate from the end of the training phase will be used.

start: int | float = 0.8

The epoch at which to start SWA. If a float, it will be interpreted as a fraction of the total number of epochs.

anneal_epochs: int = 10

The number of epochs over which to linearly anneal the learning rate to zero.

strategy: Literal['linear', 'cos'] = 'linear'

The strategy to use for annealing the learning rate.
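
The two accepted forms of start can be sketched in Python as below. The swa_start_epoch helper is hypothetical, and truncating the fractional result to an integer is an assumption about the rounding rule.

```python
def swa_start_epoch(start, max_epochs):
    """Interpret a float as a fraction of max_epochs, and an int as an epoch."""
    if isinstance(start, float):
        return int(start * max_epochs)
    return start


from_fraction = swa_start_epoch(0.8, max_epochs=100)
from_epoch = swa_start_epoch(75, max_epochs=100)
print(from_fraction, from_epoch)  # 80 75
```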

Callbacks

PyTorch Lightning callbacks are a convenient way to add additional functionality to the training process. We implement several useful callbacks in graph_pes.training.callbacks (e.g. graph_pes.training.callbacks.OffsetLogger). Use the callbacks field to define a list of these, or any other Callback objects, that you wish to use:

fitting:
    callbacks:
        - +graph_pes.training.callbacks.OffsetLogger()
        - +my_module.my_callback: { foo: 1, bar: 2 }

PyTorch Lightning Trainer

You are free to configure the PyTorch Lightning trainer as you see fit using the trainer_kwargs field; these keyword arguments will be passed directly to the Trainer constructor. By default, we train for 100 epochs on the best device available (and disable model summaries):

fitting:
    trainer_kwargs:
        max_epochs: 100
        accelerator: auto
        enable_model_summary: false

You can use this functionality to configure any other PyTorch Lightning trainer options, including…

Gradient clipping

Use the trainer_kwargs field to configure gradient clipping, e.g.:

fitting:
    trainer_kwargs:
        gradient_clip_val: 1.0
        gradient_clip_algorithm: "norm"

Validation frequency

Use the trainer_kwargs field to configure validation frequency. For instance, to validate at 10%, 20%, 30% etc. of the way through each training epoch:

fitting:
    trainer_kwargs:
        val_check_interval: 0.1

See the PyTorch Lightning documentation for details.
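
To see what a fractional val_check_interval means in terms of batches, consider the Python sketch below. The validation_batches helper is hypothetical and assumes that the fraction is converted to a whole number of training batches by truncation.

```python
def validation_batches(n_batches, val_check_interval):
    """Batch counts within one epoch after which validation would run."""
    every = max(1, int(n_batches * val_check_interval))
    return list(range(every, n_batches + 1, every))


# an epoch of 200 training batches, validating every 10% of the epoch
checks = validation_batches(200, val_check_interval=0.1)
print(checks[:3], len(checks))  # [20, 40, 60] 10
```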

wandb

Disable Weights & Biases logging:

wandb: null

Otherwise, provide a dictionary of overrides to pass to Lightning’s WandbLogger:

wandb:
    project: my_project
    entity: my_entity
    tags: [my_tag]

general

Other miscellaneous configuration options are defined here:

Random seed

Set the global random seed for reproducibility by setting this to an integer value (by default it is 42). This is used to set the random seed for the torch, numpy and random modules.

general:
    seed: 42

Output location

The outputs from a training run (model weights, logs etc.) are stored in ./<root_dir>/<run_id> (relative to the current working directory when you run graph-pes-train). By default, we use:

general:
    root_dir: graph-pes-results
    run_id: null  # a random run ID will be generated

You are free to specify any other root directory, and any run ID. If the same run ID is specified for multiple runs, we add numbers to the end of the run ID to make it unique (i.e. my_run, my_run_1, my_run_2, etc.):

general:
    root_dir: my_results
    run_id: my_run
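
The numbering scheme for repeated run IDs can be sketched in Python as follows. The unique_run_id helper is hypothetical, and a set of names stands in for the directories that already exist under root_dir.

```python
def unique_run_id(run_id, existing):
    """Append _1, _2, ... until the run ID does not clash with an existing one."""
    if run_id not in existing:
        return run_id
    suffix = 1
    while f"{run_id}_{suffix}" in existing:
        suffix += 1
    return f"{run_id}_{suffix}"


existing = set()
assigned = []
for _ in range(3):  # three runs launched with the same run_id
    run_id = unique_run_id("my_run", existing)
    existing.add(run_id)
    assigned.append(run_id)
print(assigned)  # ['my_run', 'my_run_1', 'my_run_2']
```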

Logging verbosity

Set the logging verbosity for the training run by setting this to a string value (by default it is "INFO").

general:
    log_level: DEBUG

Progress bar

Set the progress bar style to use by setting this to either:

  • "rich": use the RichProgressBar implemented in PyTorch Lightning to display a progress bar. This will not be displayed in any logs.

  • "logged": prints the validation metrics to the console at the end of each validation check.

general:
    progress: logged

Torch options

Configure common PyTorch options by setting the general.torch field to a dictionary of values from the TorchConfig class, e.g.:

general:
    torch:
        dtype: float32
        float32_matmul_precision: high

class graph_pes.config.shared.TorchConfig

Configuration for PyTorch.

dtype: Literal['float16', 'float32', 'float64']

The dtype to use for all model parameters and graph properties. Defaults to "float32".

float32_matmul_precision: Literal['highest', 'high', 'medium']

The precision to use internally for float32 matrix multiplications. Refer to the PyTorch documentation for details.

Defaults to "high" to favour accelerated learning over numerical exactness for matmuls.