Train a model¶
You can open this documentation as a Google Colab notebook to follow along interactively.
graph-pes-train provides a unified interface to train any GraphPESModel, including those packaged within graph_pes.models and those defined by you, the user.
For more information on the graph-pes-train command, and the many options you can specify in your config.yaml, see the CLI reference.
Below, we train a lightweight NequIP model on the C-GAP-17 dataset.
Installation¶
[1]:
!pip install graph-pes
Successfully installed graph-pes-0.0.7
We should now have access to the graph-pes-train command. We can check this by running:
[1]:
!graph-pes-train -h
usage: graph-pes-train [-h] [args ...]
Train a GraphPES model using PyTorch Lightning.
positional arguments:
args Config files and command line specifications. Config files
should be YAML (.yaml/.yml) files. Command line specifications
should be in the form my/nested/key=value. Final config is built
up from these items in a left to right manner, with later items
taking precedence over earlier ones in the case of conflicts.
The data2objects package is used to resolve references and
create objects directly from the config dictionary.
options:
-h, --help show this help message and exit
Copyright 2023-24, John Gardner
Data definition¶
We use load-atoms to download and split the C-GAP-17 dataset into training, validation and test datasets:
[2]:
import ase.io
from load_atoms import load_dataset
structures = load_dataset("C-GAP-17")
train, val, test = structures.random_split([0.8, 0.1, 0.1])
ase.io.write("train-cgap17.xyz", train)
ase.io.write("val-cgap17.xyz", val)
ase.io.write("test-cgap17.xyz", test)
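Before going further, it is worth a quick sanity check on what we have just written to disk. The splits behave like sequences of ase.Atoms (we index into train below), so standard ASE inspection applies:
# a quick look at the splits (sequences of ase.Atoms)
print(len(train), len(val), len(test))
first = train[0]
print(first)                        # an ase.Atoms object
print(sorted(first.info.keys()))    # per-structure labels (e.g. the DFT energy)
print(sorted(first.arrays.keys()))  # per-atom arrays (positions, forces, ...)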
We can visualise the kinds of structures we’re training on using load_atoms.view:
[3]:
from load_atoms import view
view(train[0], show_bonds=True)
[3]:
Configuration¶
Great! Now let's train a model. To do this, we have specified the following in our quickstart-cgap17.yaml file:
- the model architecture to instantiate and train, here NequIP. Note that we also include a FixedOffset component to account for the fact that the C-GAP-17 labels have an arbitrary offset
- the data to train on, here the C-GAP-17 dataset we just downloaded
- the loss function to use, here a combination of a per-atom energy loss and a per-atom force loss
- and various other training hyperparameters (e.g. the learning rate, batch size, etc.)
quickstart-cgap17.yaml
general:
run_id: quickstart-cgap17
progress: logged
# train a lightweight NequIP model ...
model:
offset:
# note the "+" prefix syntax: refer to the
# data2objects package for more details
+FixedOffset: { C: -148.314002 }
many-body:
+NequIP:
elements: [C]
cutoff: 3.7 # radial cutoff in Å
layers: 2
features:
channels: [16, 8, 4]
l_max: 2
use_odd_parity: true
self_interaction: linear
# ... on structures from local files ...
data:
train:
+file_dataset:
path: train-cgap17.xyz
cutoff: 3.7
n: 1280
shuffle: false
valid:
+file_dataset:
path: val-cgap17.xyz
cutoff: 3.7
# ... on both energy and forces (weighted 1:1) ...
loss:
- +PerAtomEnergyLoss()
- +PropertyLoss: { property: forces, metric: RMSE }
# ... with the following settings ...
fitting:
trainer_kwargs:
max_epochs: 250
accelerator: auto
check_val_every_n_epoch: 5
optimizer:
+Optimizer:
name: AdamW
lr: 0.01
scheduler:
+LRScheduler:
name: ReduceLROnPlateau
factor: 0.5
patience: 10
loader_kwargs:
batch_size: 64
# ... and log to Weights & Biases
wandb:
project: graph-pes-quickstart
We can download this config file using wget:
[4]:
%%bash
if [ ! -f quickstart-cgap17.yaml ]; then
wget https://tinyurl.com/graph-pes-quickstart-cgap17 -O quickstart-cgap17.yaml
fi
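The downloaded file is plain YAML, so if you want to double-check its contents before training you can load it directly (this sketch assumes PyYAML is available, which it will be alongside graph-pes):
import yaml

with open("quickstart-cgap17.yaml") as f:
    config = yaml.safe_load(f)

print(list(config))                    # ['general', 'model', 'data', 'loss', 'fitting', 'wandb']
print(config["model"]["many-body"])    # the +NequIP specification shown above
print(config["fitting"]["optimizer"])  # {'+Optimizer': {'name': 'AdamW', 'lr': 0.01}}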
Training¶
You can see the output of this training run below, as well as in this Weights and Biases dashboard.
[6]:
!graph-pes-train quickstart-cgap17.yaml
[graph-pes INFO]: Started `graph-pes-train` at 2024-12-05 08:09:05.603
[graph-pes INFO]: Successfully parsed config.
[graph-pes INFO]: ID for this training run: quickstart-cgap17
[graph-pes INFO]:
Output for this training run can be found at:
└─ graph-pes-results/quickstart-cgap17
├─ logs/rank-0.log # find a verbose log here
├─ model.pt # the best model
├─ lammps_model.pt # the best model deployed to LAMMPS
└─ train-config.yaml # the complete config used for this run
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: jla-gardner. Use `wandb login --relogin` to force relogin
wandb: Run data is saved locally in graph-pes-results/wandb/run-20241205_080908-quickstart-cgap17
wandb: Run `wandb offline` to turn off syncing.
wandb: Starting run quickstart-cgap17
wandb: ⭐️ View project at https://wandb.ai/jla-gardner/graph-pes-quickstart
wandb: 🚀 View run at https://wandb.ai/jla-gardner/graph-pes-quickstart/runs/quickstart-cgap17
[graph-pes INFO]: Logging to graph-pes-results/quickstart-cgap17/logs/rank-0.log
[graph-pes INFO]: Preparing data
[graph-pes INFO]: Setting up datasets
[graph-pes INFO]: Pre-fitting the model on 1,280 samples
[graph-pes INFO]:
Number of learnable params:
offset (FixedOffset): 0
many-body (NequIP) : 4,233
[graph-pes INFO]: Sanity checking the model...
[graph-pes INFO]: Starting fit...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
valid/metrics valid/metrics valid/metrics timer/its_per_s timer/its_per_s
epoch time per_atom_energy_rmse forces_rmse energy_rmse train valid
5 9.5 0.28588 1.63375 17.92273 21.73913 40.94640
10 18.7 0.20442 1.03626 14.99166 18.86792 41.89713
15 28.0 0.09091 0.95953 4.80609 18.86792 41.28917
20 37.2 0.12216 0.91683 7.97960 17.85714 41.64530
25 46.5 0.12741 0.88628 8.89986 20.40816 41.46724
30 56.0 0.10043 0.87486 5.06527 20.00000 41.30190
35 65.4 0.07815 0.86406 4.22427 18.86792 41.46724
40 74.6 0.18409 0.85244 13.32946 19.23077 42.25326
45 83.6 0.12486 0.84220 8.04440 18.18182 42.24054
50 92.5 0.07419 0.83882 3.85131 18.86792 42.06248
55 101.8 0.10297 0.83181 6.28646 18.51852 42.03170
60 110.7 0.07544 0.82551 3.95342 17.85714 42.22400
65 120.3 0.07011 0.82094 3.56080 18.51852 41.47995
70 129.7 0.07098 0.82106 3.78168 16.39344 40.60300
75 138.7 0.06978 0.81274 3.71009 18.51852 41.46724
80 148.0 0.11663 0.82113 7.93304 18.18182 41.49420
85 157.0 0.06806 0.81087 3.54228 18.86792 41.65954
90 166.3 0.10854 0.82078 7.23582 18.18182 41.88441
95 175.3 0.09466 0.80274 6.01067 17.85714 41.85363
100 184.2 0.08501 0.79980 5.04500 20.40816 41.30189
105 193.1 0.07817 0.80133 4.89728 19.60784 41.64530
110 202.0 0.11911 0.80314 7.73768 18.51852 42.81924
115 211.0 0.09696 0.79252 6.07346 18.18182 41.89713
120 219.9 0.10645 0.79790 6.85942 18.86792 42.46312
125 228.8 0.06499 0.79187 3.38392 18.18182 41.28917
130 238.1 0.09359 0.78877 6.07430 20.00000 42.07520
135 247.0 0.06430 0.78919 3.63812 18.51852 42.09275
140 256.3 0.06514 0.79163 3.47451 19.23077 41.13655
145 265.2 0.07644 0.78685 4.21608 20.40816 42.06248
150 274.2 0.07534 0.78641 4.64563 17.85714 41.65802
155 283.1 0.09428 0.79052 6.07322 18.51852 41.67557
160 292.1 0.10394 0.78160 6.46291 19.60784 41.65954
165 301.0 0.06789 0.79047 3.91855 18.86792 41.46724
170 310.0 0.07481 0.77929 4.38113 17.54386 41.67557
175 318.9 0.05980 0.77984 3.24851 19.23077 41.46724
180 328.2 0.08898 0.78014 5.88639 18.86792 42.04594
185 337.1 0.12437 0.78408 8.49052 18.86792 41.86788
190 346.1 0.09237 0.78591 5.70621 18.18182 41.82336
195 355.1 0.06462 0.77582 3.55072 17.85714 39.97809
200 364.1 0.10566 0.77532 6.90219 18.51852 41.86788
205 373.1 0.11203 0.80174 7.09306 21.27660 42.62694
210 382.0 0.08404 0.77830 5.35065 20.83333 41.88441
215 391.0 0.06217 0.77353 3.57670 18.86792 41.49751
220 400.3 0.06551 0.77113 3.60655 18.51852 42.43285
225 409.2 0.11056 0.77826 7.54414 18.86792 41.88441
230 418.2 0.06231 0.77460 3.72732 18.18182 41.64530
235 427.1 0.07255 0.77408 4.69319 20.00000 41.46724
240 436.0 0.08572 0.76910 5.88647 17.85714 41.83760
245 445.0 0.06867 0.76606 4.02220 18.86792 41.65802
250 454.2 0.06577 0.76767 3.78755 18.51852 41.64530
`Trainer.fit` stopped: `max_epochs=250` reached.
[graph-pes INFO]: Loading best weights from "graph-pes-results/quickstart-cgap17/checkpoints/best.ckpt"
[graph-pes INFO]: Training complete. Awaiting final Lightning and W&B shutdown...
wandb: \ 0.015 MB of 0.015 MB uploaded
wandb: Run history:
wandb: epoch ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb: lr-AdamW/non-decayable ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: lr-AdamW/normal ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: n_learnable_parameters ▁
wandb: n_parameters ▁
wandb: timer/its_per_s/train █▃▅▃▄▃▅▄▄▇▄▁▄▅▄▅▄▄▄▄▄▄▅▅▄▆▅▄▅▅▄▆█▇▄▅▄▆▄▄
wandb: timer/its_per_s/valid ▃▆▄▅▄▅▇▇▆▇▅▃▅▅▆▆▅█▆▇▆▆▄▆▅▅▅▅▆▆▆▁█▆▅▇▅▅▆▅
wandb: timer/step_duration_ms/train ▁▆▄▅▅▅▃▅▅▂▅█▅▄▅▃▅▅▅▅▅▅▃▃▅▃▄▅▄▄▅▃▁▁▅▄▅▃▅▅
wandb: timer/step_duration_ms/valid ▅▅▅▄▅▅▃▃▂▂▅▇▅▄▅▃▄▁▅▂▄▄▆▄▄▄▅▄▂▃▃█▂▅▅▂▄▅▃▄
wandb: train/loss/forces_rmse_weighted █▃▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: train/loss/per_atom_energy_rmse_weighted █▅▃▂▄▃▂▂▃▂▂▂▂▁▂▁▁▁▁▁▁▂▁▂▂▂▁▂▂▂▂▁▂▁▁▁▁▂▂▂
wandb: train/loss/total █▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: train/metrics/forces_rmse █▃▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: train/metrics/per_atom_energy_rmse █▅▃▂▄▃▂▂▃▂▂▂▂▁▂▁▁▁▁▁▁▂▁▂▂▂▁▂▂▂▂▁▂▁▁▁▁▂▂▂
wandb: trainer/global_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb: valid/loss/forces_rmse_weighted █▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: valid/loss/per_atom_energy_rmse_weighted █▅▂▃▂▁▅▃▂▁▁▁▃▁▂▂▂▃▂▂▂▁▁▁▂▂▁▁▂▃▂▁▃▂▁▁▁▁▂▁
wandb: valid/loss/total █▄▂▂▂▂▂▂▂▁▁▁▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁
wandb: valid/metrics/energy_rmse █▇▂▃▂▁▆▃▂▁▁▁▃▁▃▂▂▃▂▃▂▁▁▁▂▂▁▁▂▃▂▁▃▂▁▁▁▂▂▁
wandb: valid/metrics/forces_rmse █▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: valid/metrics/per_atom_energy_rmse █▅▂▃▂▁▅▃▂▁▁▁▃▁▂▂▂▃▂▂▂▁▁▁▂▂▁▁▂▃▂▁▃▂▁▁▁▁▂▁
wandb:
wandb: Run summary:
wandb: epoch 249
wandb: lr-AdamW/non-decayable 0.01
wandb: lr-AdamW/normal 0.01
wandb: n_learnable_parameters 4233
wandb: n_parameters 4234
wandb: timer/its_per_s/train 18.51852
wandb: timer/its_per_s/valid 41.6453
wandb: timer/step_duration_ms/train 54.0
wandb: timer/step_duration_ms/valid 24.875
wandb: train/loss/forces_rmse_weighted 0.78016
wandb: train/loss/per_atom_energy_rmse_weighted 0.08515
wandb: train/loss/total 0.86531
wandb: train/metrics/forces_rmse 0.78016
wandb: train/metrics/per_atom_energy_rmse 0.08515
wandb: trainer/global_step 4999
wandb: valid/loss/forces_rmse_weighted 0.76767
wandb: valid/loss/per_atom_energy_rmse_weighted 0.06577
wandb: valid/loss/total 0.83344
wandb: valid/metrics/energy_rmse 3.78755
wandb: valid/metrics/forces_rmse 0.76767
wandb: valid/metrics/per_atom_energy_rmse 0.06577
wandb:
wandb: 🚀 View run quickstart-cgap17 at: https://wandb.ai/jla-gardner/graph-pes-quickstart/runs/quickstart-cgap17
wandb: ⭐️ View project at: https://wandb.ai/jla-gardner/graph-pes-quickstart
wandb: Synced 4 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: graph-pes-results/wandb/run-20241205_080908-quickstart-cgap17/logs
Model analysis¶
Let’s load the best model from the above training run and evaluate it on the test dataset:
[7]:
from graph_pes.models import load_model
best_model = load_model("graph-pes-results/quickstart-cgap17/model.pt")
best_model
[7]:
AdditionModel(
offset=FixedOffset(trainable=False),
many-body=NequIP(
(Z_embedding): AtomicOneHot(elements=['C'])
(initial_node_embedding): PerElementEmbedding(dim=16, elements=[])
(edge_embedding): SphericalHarmonics(1x1o -> 1x0e+1x1o+1x2e)
(layers): UniformModuleList(
(0): NequIPMessagePassingLayer(
(pre_message_linear): Linear(16x0e -> 16x0e | 256 weights)
(message_tensor_product): TensorProduct(16x0e x 1x0e+1x1o+1x2e -> 16x0e+16x1o+16x2e | 48 paths | 48 weights)
(weight_generator): HaddamardProduct(
(components): ModuleList(
(0): Sequential(
(0): Bessel(n_features=8, cutoff=3.700000047683716, trainable=True)
(1): MLP(8 → 8 → 8 → 48, activation=SiLU())
)
(1): PolynomialEnvelope(cutoff=3.7, p=6)
)
)
(aggregation): SumNeighbours()
(non_linearity): Gate (28x0e+8x1o+4x2e -> 16x0e+8x1o+4x2e)
(post_message_linear): Linear(16x0e+16x1o+16x2e -> 28x0e+8x1o+4x2e | 640 weights)
(self_interaction): LinearSelfInteraction(
(linear): Linear(16x0e -> 28x0e+8x1o+4x2e | 448 weights)
)
)
(1): NequIPMessagePassingLayer(
(pre_message_linear): Linear(16x0e+8x1o+4x2e -> 16x0e+8x1o+4x2e | 336 weights)
(message_tensor_product): TensorProduct(16x0e+8x1o+4x2e x 1x0e+1x1o+1x2e -> 28x0e+36x1o+12x1e+12x2o+32x2e | 120 paths | 120 weights)
(weight_generator): HaddamardProduct(
(components): ModuleList(
(0): Sequential(
(0): Bessel(n_features=8, cutoff=3.700000047683716, trainable=True)
(1): MLP(8 → 8 → 8 → 120, activation=SiLU())
)
(1): PolynomialEnvelope(cutoff=3.7, p=6)
)
)
(aggregation): SumNeighbours()
(non_linearity): Gate (16x0e -> 16x0e)
(post_message_linear): Linear(28x0e+36x1o+12x1e+12x2o+32x2e -> 16x0e | 448 weights)
(self_interaction): LinearSelfInteraction(
(linear): Linear(16x0e+8x1o+4x2e -> 16x0e | 256 weights)
)
)
)
(energy_readout): LinearReadOut(16x0e -> 1x0e | 16 weights)
(scaler): LocalEnergiesScaler(trainable=True)
)
)
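Every GraphPESModel is a torch.nn.Module under the hood, so standard PyTorch introspection applies. For example, we can recover the parameter counts reported in the training log above:
n_total = sum(p.numel() for p in best_model.parameters())
n_learnable = sum(p.numel() for p in best_model.parameters() if p.requires_grad)
print(n_total, n_learnable)  # the log reported 4,234 total and 4,233 learnable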
GraphPESModels act on AtomicGraph objects.
We can easily convert our ase.Atoms objects into AtomicGraph objects using AtomicGraph.from_ase (we could also use the GraphPESCalculator to act directly on the ase.Atoms objects if we wanted to).
[8]:
from graph_pes.atomic_graph import AtomicGraph
test_graphs = [
AtomicGraph.from_ase(structure, cutoff=3.7) for structure in test
]
test_graphs[0]
[8]:
AtomicGraph(
atoms=64,
edges=1124,
has_cell=True,
cutoff=3.7,
properties=['energy', 'forces']
)
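As an aside: if you would rather stay entirely within ASE, the GraphPESCalculator mentioned above wraps a model as a standard ASE calculator. Here is a minimal sketch; note that the import path and constructor signature below are assumptions, so check the graph-pes API reference for the exact details:
# NOTE: import path and constructor are assumed here - see the graph-pes API docs
from graph_pes.utils.calculator import GraphPESCalculator

atoms = test[0]                              # a plain ase.Atoms object
atoms.calc = GraphPESCalculator(best_model)  # assumed: wraps a GraphPESModel
print(atoms.get_potential_energy())          # standard ASE interface
print(atoms.get_forces().shape)
For the rest of this notebook, though, we will keep working with AtomicGraph objects directly.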
Our predictions look like this:
[9]:
{
k: v.shape
for k, v in best_model.get_all_PES_predictions(test_graphs[0]).items()
}
[9]:
{'energy': torch.Size([]),
'forces': torch.Size([64, 3]),
'local_energies': torch.Size([64]),
'stress': torch.Size([3, 3]),
'virial': torch.Size([3, 3])}
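The local_energies entry is the per-atom decomposition that models of this sum-of-local-energies form add up to give the total energy. We can check that relationship directly (a sanity check rather than a guarantee, since we are assuming the decomposition is exact here):
import torch

preds = best_model.get_all_PES_predictions(test_graphs[0])
# does the sum of the per-atom local energies reproduce the total energy?
print(preds["local_energies"].sum(), preds["energy"])
print(torch.allclose(preds["local_energies"].sum(), preds["energy"]))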
We can see from a single data point that our model has done a reasonable job of learning the potential:
[10]:
best_model.predict_energy(test_graphs[0]), test_graphs[0].properties["energy"]
[10]:
(tensor(-9994.0742), tensor(-9998.7080))
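To put rough numbers on this across the whole test set, we can compute error metrics directly from the prediction calls shown above (a sketch; the exact values will vary from run to run):
import torch

energy_errors, force_errors = [], []
for graph in test_graphs:
    preds = best_model.get_all_PES_predictions(graph)
    n_atoms = graph.properties["forces"].shape[0]
    # per-atom energy error for this structure
    energy_errors.append(
        ((preds["energy"] - graph.properties["energy"]) / n_atoms).detach()
    )
    # per-component force errors
    force_errors.append((preds["forces"] - graph.properties["forces"]).detach())

energy_rmse = torch.stack(energy_errors).pow(2).mean().sqrt()
force_rmse = torch.cat([f.flatten() for f in force_errors]).pow(2).mean().sqrt()
print(f"per-atom energy RMSE: {energy_rmse:.3f} eV/atom")
print(f"force RMSE:           {force_rmse:.3f} eV/Å")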
graph-pes provides a few utility functions for visualising model performance:
[11]:
import matplotlib.pyplot as plt
from graph_pes.atomic_graph import divide_per_atom
from graph_pes.utils.analysis import parity_plot
%config InlineBackend.figure_format = 'retina'
parity_plot(
best_model,
test_graphs,
property="energy",
transform=divide_per_atom,
units="eV / atom",
lw=0,
s=12,
color="crimson",
)
plt.xlim(-158.5, -155)
plt.ylim(-158.5, -155);

[12]:
parity_plot(
best_model,
test_graphs,
property="forces",
units="eV / Å",
lw=0,
s=2,
alpha=0.5,
color="crimson",
)

[13]:
from graph_pes.utils.analysis import dimer_curve
dimer_curve(best_model, system="CC", units="eV", rmin=0.7, rmax=4.0);

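For reference, this is roughly what dimer_curve does under the hood: build an isolated dimer at a range of separations and ask the model for its energy. A sketch using only the calls introduced above (the real utility may differ in its plotting and handling of baselines):
import matplotlib.pyplot as plt
import numpy as np
from ase import Atoms

distances = np.linspace(0.7, 4.0, 100)
energies = []
for d in distances:
    dimer = Atoms("CC", positions=[[0, 0, 0], [d, 0, 0]])  # isolated C-C dimer
    graph = AtomicGraph.from_ase(dimer, cutoff=3.7)
    energies.append(best_model.predict_energy(graph).item())

plt.plot(distances, energies)
plt.xlabel("C-C distance (Å)")
plt.ylabel("energy (eV)");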
Fine-tuning¶
Let’s now take the model we trained above, and fine-tune it on the C-GAP-20U dataset.
[14]:
import ase.io
from load_atoms import load_dataset
structures = load_dataset("C-GAP-20U")
train, val, test = structures.random_split([0.8, 0.1, 0.1])
ase.io.write("train-cgap20u.xyz", train)
ase.io.write("val-cgap20u.xyz", val)
ase.io.write("test-cgap20u.xyz", test)
We can see that the C-GAP-20U dataset clearly has labels with a different arbitrary offset to the C-GAP-17 dataset.
[15]:
cgap20_test_graphs = [
AtomicGraph.from_ase(structure, cutoff=3.7) for structure in test
]
parity_plot(
best_model,
cgap20_test_graphs,
property="energy",
transform=divide_per_atom,
units="eV / atom",
)

In fact, the energy labels on C-GAP-20U are formation energies, and hence the offset we used above is no longer necessary:
[16]:
from graph_pes.models import AdditionModel
assert isinstance(best_model, AdditionModel)
underlying_nequip = best_model["many-body"]
type(underlying_nequip)
[16]:
graph_pes.models.e3nn.nequip.NequIP
[17]:
parity_plot(
underlying_nequip,
cgap20_test_graphs,
property="energy",
transform=divide_per_atom,
units="eV / atom",
lw=0,
s=12,
color="crimson",
)

Here we can see that the pre-trained NequIP model is over-predicting the formation energies: this is to be expected, as the DFT codes used to label C-GAP-17 and C-GAP-20U are different.
Training run¶
To fine-tune a GraphPESModel, we can use the same graph-pes-train command, but with a modified config file where we explicitly load in the model we want to fine-tune using either load_model or load_model_component.
finetune-cgap20u.yaml
general:
run_id: finetune-cgap20u
progress: logged
# finetune a pre-trained NequIP model ...
model:
+load_model_component:
path: <insert path to model>
key: many-body
# ... on structures from local files ...
data:
train:
+file_dataset:
# take the first 1280 structures from train-cgap20u.xyz
path: train-cgap20u.xyz
cutoff: 3.7
n: 1280
shuffle: false
valid:
+file_dataset:
# use all structures from val-cgap20u.xyz
path: val-cgap20u.xyz
cutoff: 3.7
# ... on both energy and forces ...
loss:
- +PerAtomEnergyLoss()
- +PropertyLoss: { property: forces, metric: RMSE }
# ... with the following settings ...
fitting:
trainer_kwargs:
max_epochs: 150
accelerator: auto
pre_fit_model: false
optimizer:
+Optimizer:
name: AdamW
lr: 0.003
scheduler:
+LRScheduler:
name: ReduceLROnPlateau
factor: 0.5
patience: 10
loader_kwargs:
batch_size: 64
# ... and log to Weights & Biases
wandb:
project: graph-pes-quickstart
[18]:
%%bash
if [ ! -f finetune-cgap20u.yaml ]; then
wget https://tinyurl.com/graph-pes-finetune-cgap20u -O finetune-cgap20u.yaml
fi
[19]:
!graph-pes-train finetune-cgap20u.yaml model/+load_model_component/path=graph-pes-results/quickstart-cgap17/model.pt # noqa: E501
[graph-pes INFO]: Started `graph-pes-train` at 2024-12-05 08:19:40.426
[graph-pes INFO]: Successfully parsed config.
[graph-pes INFO]: ID for this training run: finetune-cgap20u
[graph-pes INFO]:
Output for this training run can be found at:
└─ graph-pes-results/finetune-cgap20u
├─ logs/rank-0.log # find a verbose log here
├─ model.pt # the best model
├─ lammps_model.pt # the best model deployed to LAMMPS
└─ train-config.yaml # the complete config used for this run
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: jla-gardner. Use `wandb login --relogin` to force relogin
wandb: Run data is saved locally in graph-pes-results/wandb/run-20241205_081944-finetune-cgap20u
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run finetune-cgap20u
wandb: ⭐️ View project at https://wandb.ai/jla-gardner/graph-pes-quickstart
wandb: 🚀 View run at https://wandb.ai/jla-gardner/graph-pes-quickstart/runs/finetune-cgap20u
[graph-pes INFO]: Logging to graph-pes-results/finetune-cgap20u/logs/rank-0.log
[graph-pes INFO]: Preparing data
[graph-pes INFO]: Caching neighbour lists for 1280 structures with cutoff 3.7
[graph-pes INFO]: Caching neighbour lists for 608 structures with cutoff 3.7
[graph-pes INFO]: Setting up datasets
[graph-pes INFO]: Number of learnable params : 4,233
[graph-pes INFO]: Sanity checking the model...
[graph-pes INFO]: Starting fit...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
valid/metrics valid/metrics valid/metrics valid/metrics valid/metrics timer/its_per_s timer/its_per_s
epoch time per_atom_energy_rmse forces_rmse energy_rmse stress_rmse virial_rmse train valid
5 12.0 0.07865 0.68254 3.96113 0.03609 17.63322 23.80952 37.49900
10 22.7 0.07496 0.67690 3.82261 0.03304 17.34807 19.23077 36.97446
15 33.2 0.08033 0.67894 4.78717 0.03800 19.28491 17.54386 37.11675
20 44.0 0.07452 0.67271 3.73615 0.03089 14.90855 19.60784 37.52057
25 54.8 0.08204 0.67367 5.36235 0.03168 15.47247 19.23077 37.50641
30 65.3 0.07406 0.67675 4.30161 0.02724 12.27123 20.00000 37.43886
35 75.5 0.07156 0.67334 3.95478 0.02931 15.73284 18.18182 37.22427
40 86.4 0.07489 0.67066 3.60348 0.03342 18.05131 19.23077 37.58847
45 96.6 0.07079 0.67161 3.41603 0.03000 14.41076 18.18182 37.63869
50 107.1 0.07305 0.67295 3.56000 0.02818 13.76435 19.60784 37.49900
55 117.3 0.07023 0.67409 3.76528 0.02724 13.51196 18.18182 37.12831
60 127.8 0.07088 0.67273 4.01937 0.02963 14.29006 22.22222 37.25920
65 138.3 0.06981 0.67033 3.71973 0.02761 12.70107 18.86792 37.24326
70 148.5 0.07237 0.67212 3.86753 0.03468 18.37819 19.23077 36.85359
75 159.0 0.07002 0.67110 3.79805 0.02902 13.88428 18.51852 36.54050
80 169.6 0.06927 0.67006 3.40840 0.02902 13.61990 18.86792 36.96290
85 179.7 0.07055 0.67000 3.48576 0.02826 14.21006 19.60784 37.37414
90 189.9 0.07125 0.66895 3.87585 0.03137 16.16199 18.86792 37.37414
95 201.1 0.07019 0.66909 3.95231 0.02830 13.82447 19.23077 37.37414
100 211.2 0.07089 0.66866 4.09286 0.02903 14.54211 19.60784 37.13604
105 221.8 0.06771 0.66860 3.38274 0.02796 13.15577 19.23077 37.28501
110 232.0 0.07259 0.66760 4.15963 0.03054 15.05223 18.51852 37.11675
115 242.2 0.06952 0.66778 3.75120 0.02948 14.06543 17.54386 37.01168
120 252.8 0.06805 0.66722 3.30765 0.02783 13.14172 19.60784 37.25009
125 264.0 0.06787 0.66716 3.33523 0.02702 12.45630 17.54386 37.03919
130 274.9 0.06773 0.66712 3.33990 0.02738 12.71667 17.85714 37.01623
135 285.1 0.06798 0.66733 3.37465 0.02904 13.94043 18.18182 36.27171
140 295.7 0.06882 0.66725 3.70164 0.02855 13.71589 20.00000 37.25920
145 305.8 0.06764 0.66698 3.48667 0.02762 13.05017 19.60784 36.99604
150 316.6 0.06777 0.66686 3.49672 0.02758 12.96325 18.51852 37.25920
`Trainer.fit` stopped: `max_epochs=150` reached.
[graph-pes INFO]: Loading best weights from "graph-pes-results/finetune-cgap20u/checkpoints/best.ckpt"
[graph-pes INFO]: Training complete. Awaiting final Lightning and W&B shutdown...
wandb: \ 0.041 MB of 0.041 MB uploaded
wandb: Run history:
wandb: epoch ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb: lr-AdamW/non-decayable ███████████████▄▄▄▄▄▄▄▄▄▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
wandb: lr-AdamW/normal ███████████████▄▄▄▄▄▄▄▄▄▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
wandb: n_learnable_parameters ▁
wandb: n_parameters ▁
wandb: timer/its_per_s/train ▅█▃▂▃▄▃▃▃▂▃▂▁▄▂▂▃▃▃▂▂▃▄▂▄▃▄▂▃▂▁▃▄▁▂▃▂▄▄▂
wandb: timer/its_per_s/valid ▄▁▅▅▅▇▄▅▇▆▅▅▄▇▃▆▄▆▂▅▇▃▄▄▃▃▄▄▅▃▄▆█▆▃▇█▇▆▅
wandb: timer/step_duration_ms/train ▃▁▅▆▆▅▅▆▅▇▅▇█▅▇▇▆▆▅▆▇▆▅▇▅▅▅▇▆▆█▅▅█▇▅▆▅▅▆
wandb: timer/step_duration_ms/valid ▆█▃▄▄▂▆▄▂▃▄▃▆▂▆▃▄▃▇▄▂▇▄▄▇▇▆▄▃▆▆▃▁▃▇▂▁▂▃▄
wandb: train/loss/forces_rmse_weighted ▅▃▆▇▆▂▄█▇█▄▅▁▃▅▅▇▅▂▃▅▄█▁▄▂▅▄▃▃▄▁▄▄▃▆▃▆▅▄
wandb: train/loss/per_atom_energy_rmse_weighted ▄▄▇▂█▄▇▅▇▄▄▇▅▅▅▂▃▃▅▅▃▂▆▄▃▄▇▃▃▂▃▃▄▁▅▂▃▆▆▄
wandb: train/loss/total ▅▃▇▆▇▃▅█▇▇▄▆▂▃▅▄▆▄▂▃▅▃█▁▄▂▆▄▃▃▄▁▄▃▄▅▃▇▅▄
wandb: train/metrics/forces_rmse ▅▃▆▇▆▂▄█▇█▄▅▁▃▅▅▇▅▂▃▅▄█▁▄▂▅▄▃▃▄▁▄▄▃▆▃▆▅▄
wandb: train/metrics/per_atom_energy_rmse ▄▄▇▂█▄▇▅▇▄▄▇▅▅▅▂▃▃▅▅▃▂▆▄▃▄▇▃▃▂▃▃▄▁▅▂▃▆▆▄
wandb: trainer/global_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
wandb: valid/loss/forces_rmse_weighted █▅▃▄▃▂▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
wandb: valid/loss/per_atom_energy_rmse_weighted █▄▄▅▄▃▃▄▅▂█▂▃▃▂▂▃▂▂▂▂▁▁▂▂▁▂▁▁▁▁▁▁▁▁▂▁▂▁▁
wandb: valid/loss/total █▅▄▄▄▃▃▃▄▂▅▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁
wandb: valid/metrics/energy_rmse ▅▃▃▆▄▂▄▃▅▂█▂▃▂▂▁▄▂▂▁▂▁▂▂▃▂▃▂▁▂▂▂▁▁▁▂▂▂▂▂
wandb: valid/metrics/forces_rmse █▅▃▄▃▂▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
wandb: valid/metrics/per_atom_energy_rmse █▄▄▅▄▃▃▄▅▂█▂▃▃▂▂▃▂▂▂▂▁▁▂▂▁▂▁▁▁▁▁▁▁▁▂▁▂▁▁
wandb: valid/metrics/stress_rmse █▄▄▅▃▄▁▃▅▄▇▃▃▃▆▂▂▄▅▃▃▃▃▄▃▃▃▄▃▃▃▂▃▂▃▃▃▃▂▃
wandb: valid/metrics/virial_rmse █▄▄▅▂▄▁▂▅▄▆▃▃▃▆▃▁▄▅▂▂▃▃▃▄▃▃▅▄▃▃▂▂▂▃▃▃▃▂▃
wandb:
wandb: Run summary:
wandb: epoch 149
wandb: lr-AdamW/non-decayable 0.00019
wandb: lr-AdamW/normal 0.00019
wandb: n_learnable_parameters 4233
wandb: n_parameters 4233
wandb: timer/its_per_s/train 18.51852
wandb: timer/its_per_s/valid 37.2592
wandb: timer/step_duration_ms/train 54.0
wandb: timer/step_duration_ms/valid 27.4
wandb: train/loss/forces_rmse_weighted 0.68451
wandb: train/loss/per_atom_energy_rmse_weighted 0.06854
wandb: train/loss/total 0.75305
wandb: train/metrics/forces_rmse 0.68451
wandb: train/metrics/per_atom_energy_rmse 0.06854
wandb: trainer/global_step 2999
wandb: valid/loss/forces_rmse_weighted 0.66686
wandb: valid/loss/per_atom_energy_rmse_weighted 0.06777
wandb: valid/loss/total 0.73462
wandb: valid/metrics/energy_rmse 3.49672
wandb: valid/metrics/forces_rmse 0.66686
wandb: valid/metrics/per_atom_energy_rmse 0.06777
wandb: valid/metrics/stress_rmse 0.02758
wandb: valid/metrics/virial_rmse 12.96325
wandb:
wandb: 🚀 View run finetune-cgap20u at: https://wandb.ai/jla-gardner/graph-pes-quickstart/runs/finetune-cgap20u
wandb: ⭐️ View project at: https://wandb.ai/jla-gardner/graph-pes-quickstart
wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: graph-pes-results/wandb/run-20241205_081944-finetune-cgap20u/logs
This fine-tuning process has aligned the model predictions with the C-GAP-20U formation energies:
[21]:
fine_tuned_model = load_model("graph-pes-results/finetune-cgap20u/model.pt")
parity_plot(
fine_tuned_model,
cgap20_test_graphs,
property="energy",
transform=divide_per_atom,
units="eV / atom",
lw=0,
s=12,
color="crimson",
)
plt.xlim(-8, -5)
plt.ylim(-8, -5);

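We can quantify this alignment using only the predict_energy calls introduced earlier, by comparing the mean absolute per-atom energy error on the C-GAP-20U test set before and after fine-tuning. A sketch: the exact numbers depend on your runs, and we assume the bare NequIP component exposes the same predict_energy interface as the full model.
import torch

def per_atom_energy_mae(model, graphs, structures):
    # mean absolute error of the predicted energy, per atom
    errors = []
    for atoms, graph in zip(structures, graphs):
        pred = model.predict_energy(graph)
        errors.append(((pred - graph.properties["energy"]) / len(atoms)).abs().detach())
    return torch.stack(errors).mean()

# `test` currently holds the C-GAP-20U test split, in the same order as cgap20_test_graphs
print("pre-trained NequIP:", per_atom_energy_mae(underlying_nequip, cgap20_test_graphs, test))
print("fine-tuned model  :", per_atom_energy_mae(fine_tuned_model, cgap20_test_graphs, test))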