Config options¶
graph-pes-train is configured using a nested dictionary of options. The top-level keys that we look for are: model, data, loss, fitting, general and wandb.
You are free to add any additional top-level keys to your config files for your own purposes. This can be useful for easily referencing constants or repeated values using the = reference syntax.
# define a constant...
CUTOFF: 10.0

# ... and reference it later
model:
  +SchNet:
    cutoff: =/CUTOFF
You will also notice the + syntax used throughout. Under the hood, we use the data2objects library to parse these config files, and this syntax is used to automatically instantiate objects.
You can use this syntax to reference arbitrary Python functions, classes and objects:
# call your own functions/class constructors
# with the ``+`` syntax and keyword arguments
key:
  +my_module.my_function:
    foo: 1
    bar: 2

# syntactic sugar for calling a function
# with no arguments
key: +torch.nn.ReLU()

# reference arbitrary objects
# (note the lack of any keyword arguments or parentheses)
key: +my_module.my_object
By default, we will look for any objects in the graph_pes namespace, and hence +SchNet is shorthand for graph_pes.models.SchNet, etc.
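As an illustration of this shorthand, the following two snippets are intended to be equivalent ways of pointing to the same class (a sketch based on the namespace lookup described above):
# shorthand, resolved within the graph_pes namespace
model:
  +SchNet:
    cutoff: 5.0

# fully qualified, pointing to the same class explicitly
model:
  +graph_pes.models.SchNet:
    cutoff: 5.0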
model¶
To specify the model to train, you need to point to something that instantiates a GraphPESModel:
# point to the in-built Lennard-Jones model
model:
  +LennardJones:
    sigma: 0.1
    epsilon: 1.0

# or point to a custom model
model: +my_model.SpecialModel()
…or pass a dictionary mapping custom names to GraphPESModel objects:
model:
  offset:
    +FixedOffset: { H: -123.4, C: -456.7 }
  many-body: +SchNet()
The latter approach will be used to instantiate an AdditionModel, in this case with FixedOffset and SchNet components. This is a useful approach for dealing with arbitrary offset energies.
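For instance, a three-component model combining a learnable offset, a pairwise Lennard-Jones baseline and a many-body SchNet term might look something like this (a sketch assembled from the in-built models shown elsewhere on this page; the component names are arbitrary):
model:
  offset: +LearnableOffset()
  pairwise:
    +LennardJones:
      sigma: 0.1
      epsilon: 1.0
  many-body:
    +SchNet:
      cutoff: 5.0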
You can fine-tune an existing model by pointing graph-pes-train to a saved model:
model:
  +load_model:
    path: path/to/model.pt
You can also load parts of a model, e.g. if you are fine-tuning on a different level of theory with different offsets:
model:
  offset: +LearnableOffset()
  force_field:
    +load_model_component:
      path: path/to/model.pt
See the fine-tuning guide, load_model(), and load_model_component() for more details.
data¶
To specify the data you wish to use, you need to point to a dictionary that maps the keys "train" and "valid" to GraphDataset instances. A common way to do this is to use the file_dataset() function:
data:
  train:
    +file_dataset:
      path: data/train.xyz
      cutoff: 5.0
      n: 1000
      shuffle: true
      seed: 42
  valid:
    +file_dataset:
      path: data/valid.xyz
      cutoff: 5.0
Alternatively, you can point to a function that returns such a dictionary:
data:
  +my_module.my_fitting_data:
    cutoff: 5.0
This is what the load_atoms_dataset() function does:
data:
  +load_atoms_dataset:
    id: QM9
    cutoff: 5.0
    n_train: 10000
    n_val: 1000
    property_map:
      energy: U0
After training is finished, the graph-pes-train command will load the best model weights and re-test the model on the training and validation data. You can also test on other datasets at this point by including a "test" key in your config file. This should either point to:
- a GraphDataset instance, in which case testing metrics will be logged to "best_model/test/<metric_name>":
  data:
    train: ...
    valid: ...
    test:
      +file_dataset:
        path: test_data.xyz
        cutoff: 5.0
- a dictionary mapping custom names to GraphDataset instances, in which case testing metrics will be logged to "best_model/<custom_name>/<metric_name>":
  data:
    train: ...
    valid: ...
    test:
      dimers:
        +file_dataset:
          path: data/dimers.xyz
          cutoff: 5.0
      clusters:
        +file_dataset:
          path: data/clusters.xyz
          cutoff: 5.0
loss¶
This config section should either point to something that instantiates a single graph_pes.training.loss.Loss object…
# basic per-atom energy loss
loss: +PerAtomEnergyLoss()

# or more fine-grained control
loss:
  +PropertyLoss:
    property: stress
    metric: MAE # defaults to RMSE if not specified
…or specify a list of Loss instances…
loss:
  # specify a loss with several components:
  - +PerAtomEnergyLoss() # defaults to weight 1.0
  - +PropertyLoss:
      property: forces
      metric: MSE
      weight: 10.0
…or point to your own custom loss implementation, either in isolation:
loss:
  +my.module.CustomLoss: { alpha: 0.5 }
…or in conjunction with other components:
loss:
  - +PerAtomEnergyLoss()
  - +my.module.CustomLoss: { alpha: 0.5 }
If you want to sweep over a loss component weight via the command line, you can use a dictionary mapping arbitrary strings to loss instances like so:
loss:
  energy: +PerAtomEnergyLoss()
  forces:
    +ForceRMSE:
      weight: 5.0
allowing you to run a command such as:
for weight in 0.1 0.5 1.0; do
    graph-pes-train config.yaml loss/forces/+ForceRMSE/weight=$weight
done
fitting¶
The fitting section of the config is used to specify various hyperparameters and behaviours of the training process.
Optimizer¶
Configure the optimizer used to train the model by pointing to something that instantiates an Optimizer. The default is:
fitting:
  optimizer:
    +Optimizer:
      name: Adam
      lr: 3e-3
      weight_decay: 0.0
      amsgrad: false
but you could also point to your own custom optimizer:
fitting:
  optimizer: +my.module.MagicOptimizer()
Learning rate scheduler¶
Configure the learning rate scheduler used during training by pointing to something that instantiates an LRScheduler. For instance:
fitting:
  scheduler:
    +LRScheduler:
      name: ReduceLROnPlateau
      factor: 0.5
      patience: 10
If you don't specify a scheduler, or if you set this field to null, no learning rate scheduler is used:
fitting:
  scheduler: null
Model pre-fitting¶
To turn off pre-fitting of the model, override the pre_fit_model field (default is true):
fitting:
  pre_fit_model: false
To set the maximum number of graphs to use for pre-fitting, override the max_n_pre_fit field (default is 5_000). These graphs will be randomly sampled from the training data. To use all the training data, set this to null (see the second snippet below):
fitting:
  max_n_pre_fit: 1000
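…or, to pre-fit on every structure in the training data:
fitting:
  max_n_pre_fit: null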
Early stopping¶
Turn on early stopping by setting the early_stopping_patience field to an integer value (by default it is null, indicating that early stopping is disabled). This will stop training when the total validation loss ("valid/loss/total") has not improved for early_stopping_patience validation checks.
fitting:
  early_stopping_patience: 10
To have more fine-grained control over early stopping, set this field to null and use the callbacks field to add an EarlyStopping Lightning callback:
fitting:
  early_stopping_patience: null
  callbacks:
    - +pytorch_lightning.callbacks.early_stopping.EarlyStopping:
        monitor: valid/loss/forces_rmse
        patience: 100
        min_delta: 0.01
        mode: min
Data loaders¶
Data loaders are responsible for sampling batches of data from the dataset. We use GraphDataLoader instances to do this. These inherit from the PyTorch DataLoader class, and hence you can pass any keyword arguments to the underlying loader by setting the loader_kwargs field:
fitting:
  loader_kwargs:
    seed: 42
    batch_size: 32
    persistent_workers: true
    num_workers: 4
See the PyTorch documentation for details.
We recommend using several persistent workers, since loading data can be a bottleneck, either due to expensive read operations from disk, or due to the time taken to convert the underlying data into AtomicGraph objects (calculating neighbour lists etc.).
Caution: setting the shuffle field here will have no effect; we always shuffle the training data, and keep the validation and testing data in order.
Stochastic weight averaging¶
Configure stochastic weight averaging (SWA) by specifying fields from the SWAConfig class, e.g.:
fitting:
  swa:
    lr: 1e-3
    start: 0.8
    anneal_epochs: 10
- class graph_pes.config.training.SWAConfig[source]¶
Configuration for Stochastic Weight Averaging.
Internally, this is handled by PyTorch Lightning's StochasticWeightAveraging callback.
- lr: float¶
The learning rate to use during the SWA phase. If not specified, the learning rate from the end of the training phase will be used.
- start: int | float = 0.8¶
The epoch at which to start SWA. If a float, it will be interpreted as a fraction of the total number of epochs.
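Since start accepts either an integer epoch or a fraction of the total number of epochs, the example above could equally start SWA at a fixed epoch, e.g.:
fitting:
  swa:
    lr: 1e-3
    start: 100  # start SWA at epoch 100
    anneal_epochs: 10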
Callbacks¶
PyTorch Lightning callbacks are a convenient way to add additional functionality to the training process.
We implement several useful callbacks in graph_pes.training.callbacks (e.g. graph_pes.training.callbacks.OffsetLogger). Use the callbacks field to define a list of these, or any other Callback objects, that you wish to use:
fitting:
  callbacks:
    - +graph_pes.training.callbacks.OffsetLogger()
    - +my_module.my_callback: { foo: 1, bar: 2 }
PyTorch Lightning Trainer¶
You are free to configure the PyTorch Lightning trainer as you see fit using the trainer_kwargs field; these keyword arguments will be passed directly to the Trainer constructor. By default, we train for 100 epochs on the best device available (and disable model summaries):
fitting:
  trainer_kwargs:
    max_epochs: 100
    accelerator: auto
    enable_model_summary: false
You can use this functionality to configure any other PyTorch Lightning trainer options, including…
Gradient clipping¶
Use the trainer_kwargs field to configure gradient clipping, e.g.:
fitting:
  trainer_kwargs:
    gradient_clip_val: 1.0
    gradient_clip_algorithm: "norm"
Validation frequency¶
Use the trainer_kwargs field to configure validation frequency. For instance, to validate at 10%, 20%, 30% etc. of the way through the training dataset:
fitting:
  trainer_kwargs:
    val_check_interval: 0.1
See the PyTorch Lightning documentation for details.
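If you would rather validate every fixed number of training batches than at a fraction of each epoch, PyTorch Lightning's val_check_interval also accepts an integer, e.g. (a sketch; see the Lightning documentation for the exact semantics):
fitting:
  trainer_kwargs:
    # validate every 500 training batches
    val_check_interval: 500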
wandb¶
Disable Weights & Biases logging:
wandb: null
Otherwise, provide a dictionary of overrides to pass to Lightning's WandbLogger:
wandb:
  project: my_project
  entity: my_entity
  tags: [my_tag]
general¶
Other miscellaneous configuration options are defined here:
Random seed¶
Set the global random seed for reproducibility by setting the seed field to an integer value (default 42). This is used to set the random seed for the torch, numpy and random modules.
general:
  seed: 42
Output location¶
The outputs from a training run (model weights, logs etc.) are stored in ./<root_dir>/<run_id> (relative to the current working directory when you run graph-pes-train). By default, we use:
general:
  root_dir: graph-pes-results
  run_id: null # a random run ID will be generated
You are free to specify any other root directory and any run ID. If the same run ID is specified for multiple runs, we add numbers to the end of the run ID to make it unique (e.g. my_run, my_run_1, my_run_2, etc.):
general:
  root_dir: my_results
  run_id: my_run
Logging verbosity¶
Set the logging verbosity for the training run by setting the log_level field to a string value (default "INFO").
general:
  log_level: DEBUG
Progress bar¶
Set the progress bar style by setting the progress field to either:
- "rich": use the RichProgressBar implemented in PyTorch Lightning to display a progress bar. This will not be displayed in any logs.
- "logged": print the validation metrics to the console at the end of each validation check.
general:
  progress: logged
Torch options¶
Configure common PyTorch options by setting the general.torch field to a dictionary of values from the TorchConfig class, e.g.:
general:
  torch:
    dtype: float32
    float32_matmul_precision: high
Configuration for PyTorch.
- dtype¶
The dtype to use for all model parameters and graph properties. Default is "float32".
- float32_matmul_precision¶
The precision to use internally for float32 matrix multiplications. Refer to the PyTorch documentation for details. Defaults to "high" to favour accelerated learning over numerical exactness for matmuls.
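For example, to train in double precision (assuming "float64" is accepted as a dtype value, by analogy with the "float32" default), the config might look like:
general:
  torch:
    # hypothetical: double-precision parameters and graph properties
    dtype: float64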