Callbacks

We have implemented a few useful PyTorch Lightning callbacks that you can use to monitor your training process:

class graph_pes.training.callbacks.WandbLogger(output_dir, log_epoch, **kwargs)

A subclass of PyTorch Lightning's WandbLogger that automatically sets the id and save_dir arguments.
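The exact defaults are not documented here. As a minimal sketch, assume the logger derives the run id from the name of `output_dir` and writes its files under `output_dir`. Under that assumption, the "automatic" part amounts to filling in two keyword arguments before they reach the parent class. The helper name `wandb_logger_kwargs` is hypothetical and not part of graph-pes:

```python
from __future__ import annotations

from pathlib import Path
from typing import Any


def wandb_logger_kwargs(output_dir: str | Path, **kwargs: Any) -> dict[str, Any]:
    """Hypothetical helper, not the graph-pes implementation.

    Assumption: the run ``id`` is the output directory's name and ``save_dir``
    is the output directory itself. Explicit keyword arguments take precedence.
    """
    output_dir = Path(output_dir)
    defaults = {"id": output_dir.name, "save_dir": str(output_dir)}
    # Later entries win, so user-supplied kwargs override the defaults.
    return {**defaults, **kwargs}


kwargs = wandb_logger_kwargs("runs/exp-1", project="graph-pes")
```

In this sketch, explicit arguments override the defaults. Whether the real class allows that is not stated above.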

class graph_pes.training.callbacks.OffsetLogger

Log any learned, per-element offsets of the model at the end of each validation epoch.

class graph_pes.training.callbacks.ScalesLogger

Log any learned, per-element scaling factors of the model at the end of each validation epoch.
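Both loggers turn a per-element parameter into one scalar per chemical element, which is the shape most experiment loggers expect. The following is a minimal sketch of that flattening step. It assumes a "<prefix>/<element>" naming scheme, and both that scheme and the helper name `per_element_metrics` are hypothetical. The numbers are arbitrary placeholders, not fitted values. Inside a real PyTorch Lightning callback, a dict like this could be passed to `pl_module.log_dict` from the `on_validation_epoch_end` hook:

```python
from typing import Dict, Mapping


def per_element_metrics(prefix: str, values: Mapping[str, float]) -> Dict[str, float]:
    """Flatten {element symbol: value} into {"<prefix>/<symbol>": value}.

    Hypothetical naming scheme; graph-pes may label its metrics differently.
    """
    # Sort by symbol so the metric order is deterministic across epochs.
    return {f"{prefix}/{symbol}": float(value) for symbol, value in sorted(values.items())}


# Arbitrary placeholder values standing in for learned parameters.
offsets = {"O": -2.0, "H": -1.0}
scales = {"H": 0.5, "O": 1.5}

metrics = {**per_element_metrics("offset", offsets), **per_element_metrics("scale", scales)}
```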

class graph_pes.training.callbacks.DumpModel(every_n_val_checks=10)

Dump the model to <output_dir>/dumps/model_{epoch}.pt at regular intervals.

Parameters:

every_n_val_checks (int) – The number of validation epochs between dumps.
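The schedule can be made concrete with a small sketch. It assumes that a dump is written on the n-th, 2n-th, ... validation check (rather than, say, on the first one). It also assumes that `{epoch}` is the epoch in which that check runs. `dump_paths` is a hypothetical helper, not part of graph-pes:

```python
from pathlib import Path
from typing import Iterable, List


def dump_paths(
    output_dir: str, val_check_epochs: Iterable[int], every_n_val_checks: int = 10
) -> List[Path]:
    """List the files a DumpModel-style callback would write (assumed schedule).

    ``val_check_epochs`` gives the epoch of each successive validation check.
    """
    if every_n_val_checks < 1:
        raise ValueError("every_n_val_checks must be a positive integer")
    paths = []
    # Count validation checks from 1 so the first dump lands on check n.
    for check, epoch in enumerate(val_check_epochs, start=1):
        if check % every_n_val_checks == 0:
            paths.append(Path(output_dir) / "dumps" / f"model_{epoch}.pt")
    return paths


# One validation check per epoch, for epochs 0..24.
paths = dump_paths("run", range(25), every_n_val_checks=10)
```

Assume one validation check per epoch and the default of 10. Over 25 epochs, this assumed schedule writes run/dumps/model_9.pt and run/dumps/model_19.pt.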

class graph_pes.training.callbacks.ModelTimer

Base class

class graph_pes.training.callbacks.GraphPESCallback

A base class for all callbacks that require access to useful information generated by the graph-pes-train command.
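The idea behind such a base class can be sketched without Lightning. The training command creates the callback first and then hands it the information it needs before training starts; in this sketch, that information is only an output directory. Everything here, including `TrainingInfo`, `InfoAwareCallback`, `register_info` and `DumpPathCallback`, is hypothetical and does not reproduce the real GraphPESCallback interface:

```python
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path


@dataclass
class TrainingInfo:
    """Hypothetical container for information the training command produces."""

    output_dir: Path


class InfoAwareCallback:
    """Hypothetical mimic of the GraphPESCallback idea, not its real API."""

    def __init__(self) -> None:
        self._info: TrainingInfo | None = None

    def register_info(self, info: TrainingInfo) -> None:
        # Called by the (hypothetical) training command before training starts.
        self._info = info

    @property
    def output_dir(self) -> Path:
        if self._info is None:
            raise RuntimeError("training info has not been registered yet")
        return self._info.output_dir


class DumpPathCallback(InfoAwareCallback):
    """Example subclass that relies on the registered output directory."""

    def path_for(self, epoch: int) -> Path:
        return self.output_dir / "dumps" / f"model_{epoch}.pt"


callback = DumpPathCallback()
callback.register_info(TrainingInfo(output_dir=Path("run")))
```

Raising on early access makes a missing registration fail loudly instead of silently writing to some default location.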