Train a model

You can open this documentation as a Google Colab notebook to follow along interactively.

graph-pes-train provides a unified interface to train any GraphPESModel, including those packaged within graph_pes.models and those you define yourself.

For more information on the graph-pes-train command, and the many options you can specify in your config.yaml, see the CLI reference.

Below, we train a lightweight NequIP model on the C-GAP-17 dataset.

Installation

[1]:
!pip install graph-pes
Successfully installed graph-pes-0.0.7

We should now have access to the graph-pes-train command. We can check this by running:

[1]:
!graph-pes-train -h
usage: graph-pes-train [-h] [args ...]

Train a GraphPES model using PyTorch Lightning.

positional arguments:
  args        Config files and command line specifications. Config files
              should be YAML (.yaml/.yml) files. Command line specifications
              should be in the form my/nested/key=value. Final config is built
              up from these items in a left to right manner, with later items
              taking precedence over earlier ones in the case of conflicts.
              The data2objects package is used to resolve references and
              create objects directly from the config dictionary.

options:
  -h, --help  show this help message and exit

Copyright 2023-24, John Gardner
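
As the help text explains, the final config is built up from left to right, so you can combine a config file with command-line overrides of the form my/nested/key=value. For example (an illustrative override, not one we use in this tutorial), you could reuse the quickstart config defined below but with a larger batch size:

!graph-pes-train quickstart-cgap17.yaml fitting/loader_kwargs/batch_size=128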

Data definition

We use load-atoms to download and split the C-GAP-17 dataset into training, validation and test datasets:

[2]:
import ase.io
from load_atoms import load_dataset

structures = load_dataset("C-GAP-17")
train, val, test = structures.random_split([0.8, 0.1, 0.1])

ase.io.write("train-cgap17.xyz", train)
ase.io.write("val-cgap17.xyz", val)
ase.io.write("test-cgap17.xyz", test)

We can visualise the kinds of structures we’re training on using load_atoms.view:

[3]:
from load_atoms import view

view(train[0], show_bonds=True)
[3]:

Configuration

Great - now let's train a model. To do this, we have specified the following in our quickstart-cgap17.yaml file:

  • the model architecture to instantiate and train, here NequIP. Note that we also include a FixedOffset component to account for the fact that the C-GAP-17 labels have an arbitrary offset.

  • the data to train on, here the C-GAP-17 dataset we just downloaded

  • the loss function to use, here a combination of a per-atom energy loss and a per-atom force loss

  • and various other training hyperparameters (e.g. the learning rate, batch size, etc.)

quickstart-cgap17.yaml
general:
    run_id: quickstart-cgap17
    progress: logged

# train a lightweight NequIP model ...
model:
    offset:
        # note the "+" prefix syntax: refer to the
        # data2objects package for more details
        +FixedOffset: { C: -148.314002 }
    many-body:
        +NequIP:
            elements: [C]
            cutoff: 3.7 # radial cutoff in Å
            layers: 2
            features:
                channels: [16, 8, 4]
                l_max: 2
                use_odd_parity: true
            self_interaction: linear

# ... on structures from local files ...
data:
    train:
        +file_dataset:
            path: train-cgap17.xyz
            cutoff: 3.7
            n: 1280
            shuffle: false
    valid:
        +file_dataset:
            path: val-cgap17.xyz
            cutoff: 3.7

# ... on both energy and forces (weighted 1:1) ...
loss:
    - +PerAtomEnergyLoss()
    - +PropertyLoss: { property: forces, metric: RMSE }

# ... with the following settings ...
fitting:
    trainer_kwargs:
        max_epochs: 250
        accelerator: auto
        check_val_every_n_epoch: 5

    optimizer:
        +Optimizer:
            name: AdamW
            lr: 0.01

    scheduler:
        +LRScheduler:
            name: ReduceLROnPlateau
            factor: 0.5
            patience: 10

    loader_kwargs:
        batch_size: 64

# ... and log to Weights & Biases
wandb:
    project: graph-pes-quickstart

We can download this config file using wget:

[4]:
%%bash

if [ ! -f quickstart-cgap17.yaml ]; then
    wget https://tinyurl.com/graph-pes-quickstart-cgap17 -O quickstart-cgap17.yaml
fi

Training

You can see the output of the training run below, and in this Weights & Biases dashboard.

[6]:
!graph-pes-train quickstart-cgap17.yaml
[graph-pes INFO]: Started `graph-pes-train` at 2024-12-05 08:09:05.603
[graph-pes INFO]: Successfully parsed config.
[graph-pes INFO]: ID for this training run: quickstart-cgap17
[graph-pes INFO]:
Output for this training run can be found at:
   └─ graph-pes-results/quickstart-cgap17
      ├─ logs/rank-0.log    # find a verbose log here
      ├─ model.pt           # the best model
      ├─ lammps_model.pt    # the best model deployed to LAMMPS
      └─ train-config.yaml  # the complete config used for this run

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: jla-gardner. Use `wandb login --relogin` to force relogin
wandb: Run data is saved locally in graph-pes-results/wandb/run-20241205_080908-quickstart-cgap17
wandb: Run `wandb offline` to turn off syncing.
wandb: Starting run quickstart-cgap17
wandb: ⭐️ View project at https://wandb.ai/jla-gardner/graph-pes-quickstart
wandb: 🚀 View run at https://wandb.ai/jla-gardner/graph-pes-quickstart/runs/quickstart-cgap17
[graph-pes INFO]: Logging to graph-pes-results/quickstart-cgap17/logs/rank-0.log
[graph-pes INFO]: Preparing data
[graph-pes INFO]: Setting up datasets
[graph-pes INFO]: Pre-fitting the model on 1,280 samples
[graph-pes INFO]:
Number of learnable params:
    offset (FixedOffset): 0
    many-body (NequIP)  : 4,233

[graph-pes INFO]: Sanity checking the model...
[graph-pes INFO]: Starting fit...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
                            valid/metrics   valid/metrics   valid/metrics   timer/its_per_s   timer/its_per_s
   epoch      time   per_atom_energy_rmse     forces_rmse     energy_rmse             train             valid
       5       9.5                0.28588         1.63375        17.92273          21.73913          40.94640
      10      18.7                0.20442         1.03626        14.99166          18.86792          41.89713
      15      28.0                0.09091         0.95953         4.80609          18.86792          41.28917
      20      37.2                0.12216         0.91683         7.97960          17.85714          41.64530
      25      46.5                0.12741         0.88628         8.89986          20.40816          41.46724
      30      56.0                0.10043         0.87486         5.06527          20.00000          41.30190
      35      65.4                0.07815         0.86406         4.22427          18.86792          41.46724
      40      74.6                0.18409         0.85244        13.32946          19.23077          42.25326
      45      83.6                0.12486         0.84220         8.04440          18.18182          42.24054
      50      92.5                0.07419         0.83882         3.85131          18.86792          42.06248
      55     101.8                0.10297         0.83181         6.28646          18.51852          42.03170
      60     110.7                0.07544         0.82551         3.95342          17.85714          42.22400
      65     120.3                0.07011         0.82094         3.56080          18.51852          41.47995
      70     129.7                0.07098         0.82106         3.78168          16.39344          40.60300
      75     138.7                0.06978         0.81274         3.71009          18.51852          41.46724
      80     148.0                0.11663         0.82113         7.93304          18.18182          41.49420
      85     157.0                0.06806         0.81087         3.54228          18.86792          41.65954
      90     166.3                0.10854         0.82078         7.23582          18.18182          41.88441
      95     175.3                0.09466         0.80274         6.01067          17.85714          41.85363
     100     184.2                0.08501         0.79980         5.04500          20.40816          41.30189
     105     193.1                0.07817         0.80133         4.89728          19.60784          41.64530
     110     202.0                0.11911         0.80314         7.73768          18.51852          42.81924
     115     211.0                0.09696         0.79252         6.07346          18.18182          41.89713
     120     219.9                0.10645         0.79790         6.85942          18.86792          42.46312
     125     228.8                0.06499         0.79187         3.38392          18.18182          41.28917
     130     238.1                0.09359         0.78877         6.07430          20.00000          42.07520
     135     247.0                0.06430         0.78919         3.63812          18.51852          42.09275
     140     256.3                0.06514         0.79163         3.47451          19.23077          41.13655
     145     265.2                0.07644         0.78685         4.21608          20.40816          42.06248
     150     274.2                0.07534         0.78641         4.64563          17.85714          41.65802
     155     283.1                0.09428         0.79052         6.07322          18.51852          41.67557
     160     292.1                0.10394         0.78160         6.46291          19.60784          41.65954
     165     301.0                0.06789         0.79047         3.91855          18.86792          41.46724
     170     310.0                0.07481         0.77929         4.38113          17.54386          41.67557
     175     318.9                0.05980         0.77984         3.24851          19.23077          41.46724
     180     328.2                0.08898         0.78014         5.88639          18.86792          42.04594
     185     337.1                0.12437         0.78408         8.49052          18.86792          41.86788
     190     346.1                0.09237         0.78591         5.70621          18.18182          41.82336
     195     355.1                0.06462         0.77582         3.55072          17.85714          39.97809
     200     364.1                0.10566         0.77532         6.90219          18.51852          41.86788
     205     373.1                0.11203         0.80174         7.09306          21.27660          42.62694
     210     382.0                0.08404         0.77830         5.35065          20.83333          41.88441
     215     391.0                0.06217         0.77353         3.57670          18.86792          41.49751
     220     400.3                0.06551         0.77113         3.60655          18.51852          42.43285
     225     409.2                0.11056         0.77826         7.54414          18.86792          41.88441
     230     418.2                0.06231         0.77460         3.72732          18.18182          41.64530
     235     427.1                0.07255         0.77408         4.69319          20.00000          41.46724
     240     436.0                0.08572         0.76910         5.88647          17.85714          41.83760
     245     445.0                0.06867         0.76606         4.02220          18.86792          41.65802
     250     454.2                0.06577         0.76767         3.78755          18.51852          41.64530
`Trainer.fit` stopped: `max_epochs=250` reached.
[graph-pes INFO]: Loading best weights from "graph-pes-results/quickstart-cgap17/checkpoints/best.ckpt"
[graph-pes INFO]: Training complete. Awaiting final Lightning and W&B shutdown...
wandb: \ 0.015 MB of 0.015 MB uploaded
wandb: Run history:
wandb:                                    epoch ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:                   lr-AdamW/non-decayable ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                          lr-AdamW/normal ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                   n_learnable_parameters ▁
wandb:                             n_parameters ▁
wandb:                    timer/its_per_s/train █▃▅▃▄▃▅▄▄▇▄▁▄▅▄▅▄▄▄▄▄▄▅▅▄▆▅▄▅▅▄▆█▇▄▅▄▆▄▄
wandb:                    timer/its_per_s/valid ▃▆▄▅▄▅▇▇▆▇▅▃▅▅▆▆▅█▆▇▆▆▄▆▅▅▅▅▆▆▆▁█▆▅▇▅▅▆▅
wandb:             timer/step_duration_ms/train ▁▆▄▅▅▅▃▅▅▂▅█▅▄▅▃▅▅▅▅▅▅▃▃▅▃▄▅▄▄▅▃▁▁▅▄▅▃▅▅
wandb:             timer/step_duration_ms/valid ▅▅▅▄▅▅▃▃▂▂▅▇▅▄▅▃▄▁▅▂▄▄▆▄▄▄▅▄▂▃▃█▂▅▅▂▄▅▃▄
wandb:          train/loss/forces_rmse_weighted █▃▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: train/loss/per_atom_energy_rmse_weighted █▅▃▂▄▃▂▂▃▂▂▂▂▁▂▁▁▁▁▁▁▂▁▂▂▂▁▂▂▂▂▁▂▁▁▁▁▂▂▂
wandb:                         train/loss/total █▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                train/metrics/forces_rmse █▃▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:       train/metrics/per_atom_energy_rmse █▅▃▂▄▃▂▂▃▂▂▂▂▁▂▁▁▁▁▁▁▂▁▂▂▂▁▂▂▂▂▁▂▁▁▁▁▂▂▂
wandb:                      trainer/global_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:          valid/loss/forces_rmse_weighted █▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: valid/loss/per_atom_energy_rmse_weighted █▅▂▃▂▁▅▃▂▁▁▁▃▁▂▂▂▃▂▂▂▁▁▁▂▂▁▁▂▃▂▁▃▂▁▁▁▁▂▁
wandb:                         valid/loss/total █▄▂▂▂▂▂▂▂▁▁▁▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁
wandb:                valid/metrics/energy_rmse █▇▂▃▂▁▆▃▂▁▁▁▃▁▃▂▂▃▂▃▂▁▁▁▂▂▁▁▂▃▂▁▃▂▁▁▁▂▂▁
wandb:                valid/metrics/forces_rmse █▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:       valid/metrics/per_atom_energy_rmse █▅▂▃▂▁▅▃▂▁▁▁▃▁▂▂▂▃▂▂▂▁▁▁▂▂▁▁▂▃▂▁▃▂▁▁▁▁▂▁
wandb:
wandb: Run summary:
wandb:                                    epoch 249
wandb:                   lr-AdamW/non-decayable 0.01
wandb:                          lr-AdamW/normal 0.01
wandb:                   n_learnable_parameters 4233
wandb:                             n_parameters 4234
wandb:                    timer/its_per_s/train 18.51852
wandb:                    timer/its_per_s/valid 41.6453
wandb:             timer/step_duration_ms/train 54.0
wandb:             timer/step_duration_ms/valid 24.875
wandb:          train/loss/forces_rmse_weighted 0.78016
wandb: train/loss/per_atom_energy_rmse_weighted 0.08515
wandb:                         train/loss/total 0.86531
wandb:                train/metrics/forces_rmse 0.78016
wandb:       train/metrics/per_atom_energy_rmse 0.08515
wandb:                      trainer/global_step 4999
wandb:          valid/loss/forces_rmse_weighted 0.76767
wandb: valid/loss/per_atom_energy_rmse_weighted 0.06577
wandb:                         valid/loss/total 0.83344
wandb:                valid/metrics/energy_rmse 3.78755
wandb:                valid/metrics/forces_rmse 0.76767
wandb:       valid/metrics/per_atom_energy_rmse 0.06577
wandb:
wandb: 🚀 View run quickstart-cgap17 at: https://wandb.ai/jla-gardner/graph-pes-quickstart/runs/quickstart-cgap17
wandb: ⭐️ View project at: https://wandb.ai/jla-gardner/graph-pes-quickstart
wandb: Synced 4 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: graph-pes-results/wandb/run-20241205_080908-quickstart-cgap17/logs

Model analysis

Let’s load the best model from the above training run and evaluate it on the test dataset:

[7]:
from graph_pes.models import load_model

best_model = load_model("graph-pes-results/quickstart-cgap17/model.pt")
best_model
[7]:
AdditionModel(
  offset=FixedOffset(trainable=False),
  many-body=NequIP(
    (Z_embedding): AtomicOneHot(elements=['C'])
    (initial_node_embedding): PerElementEmbedding(dim=16, elements=[])
    (edge_embedding): SphericalHarmonics(1x1o -> 1x0e+1x1o+1x2e)
    (layers): UniformModuleList(
      (0): NequIPMessagePassingLayer(
        (pre_message_linear): Linear(16x0e -> 16x0e | 256 weights)
        (message_tensor_product): TensorProduct(16x0e x 1x0e+1x1o+1x2e -> 16x0e+16x1o+16x2e | 48 paths | 48 weights)
        (weight_generator): HaddamardProduct(
          (components): ModuleList(
            (0): Sequential(
              (0): Bessel(n_features=8, cutoff=3.700000047683716, trainable=True)
              (1): MLP(8 → 8 → 8 → 48, activation=SiLU())
            )
            (1): PolynomialEnvelope(cutoff=3.7, p=6)
          )
        )
        (aggregation): SumNeighbours()
        (non_linearity): Gate (28x0e+8x1o+4x2e -> 16x0e+8x1o+4x2e)
        (post_message_linear): Linear(16x0e+16x1o+16x2e -> 28x0e+8x1o+4x2e | 640 weights)
        (self_interaction): LinearSelfInteraction(
          (linear): Linear(16x0e -> 28x0e+8x1o+4x2e | 448 weights)
        )
      )
      (1): NequIPMessagePassingLayer(
        (pre_message_linear): Linear(16x0e+8x1o+4x2e -> 16x0e+8x1o+4x2e | 336 weights)
        (message_tensor_product): TensorProduct(16x0e+8x1o+4x2e x 1x0e+1x1o+1x2e -> 28x0e+36x1o+12x1e+12x2o+32x2e | 120 paths | 120 weights)
        (weight_generator): HaddamardProduct(
          (components): ModuleList(
            (0): Sequential(
              (0): Bessel(n_features=8, cutoff=3.700000047683716, trainable=True)
              (1): MLP(8 → 8 → 8 → 120, activation=SiLU())
            )
            (1): PolynomialEnvelope(cutoff=3.7, p=6)
          )
        )
        (aggregation): SumNeighbours()
        (non_linearity): Gate (16x0e -> 16x0e)
        (post_message_linear): Linear(28x0e+36x1o+12x1e+12x2o+32x2e -> 16x0e | 448 weights)
        (self_interaction): LinearSelfInteraction(
          (linear): Linear(16x0e+8x1o+4x2e -> 16x0e | 256 weights)
        )
      )
    )
    (energy_readout): LinearReadOut(16x0e -> 1x0e | 16 weights)
    (scaler): LocalEnergiesScaler(trainable=True)
  )
)

GraphPESModels act on AtomicGraph objects.

We can easily convert our ase.Atoms objects into AtomicGraph objects using AtomicGraph.from_ase (we could also use the GraphPESCalculator to act directly on the ase.Atoms objects; a rough sketch of that route is included below).
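
For completeness, here is a minimal sketch of the calculator route. The import path and constructor signature below are our assumptions, so check the graph-pes ASE documentation for the exact usage:

# NB: the import path and constructor used here are assumptions -
# consult the graph-pes ASE-interface docs for the exact usage
from graph_pes.utils.calculator import GraphPESCalculator

calc = GraphPESCalculator(best_model)  # assumed to wrap any GraphPESModel
atoms = test[0]                        # a plain ase.Atoms object
atoms.calc = calc
atoms.get_potential_energy(), atoms.get_forces().shape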

[8]:
from graph_pes.atomic_graph import AtomicGraph

test_graphs = [
    AtomicGraph.from_ase(structure, cutoff=3.7) for structure in test
]
test_graphs[0]
[8]:
AtomicGraph(
    atoms=64,
    edges=1124,
    has_cell=True,
    cutoff=3.7,
    properties=['energy', 'forces']
)

Our predictions look like this:

[9]:
{
    k: v.shape
    for k, v in best_model.get_all_PES_predictions(test_graphs[0]).items()
}
[9]:
{'energy': torch.Size([]),
 'forces': torch.Size([64, 3]),
 'local_energies': torch.Size([64]),
 'stress': torch.Size([3, 3]),
 'virial': torch.Size([3, 3])}

We can see from a single data point that our model has done a reasonable job of learning the potential:

[10]:
best_model.predict_energy(test_graphs[0]), test_graphs[0].properties["energy"]
[10]:
(tensor(-9994.0742), tensor(-9998.7080))
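
We can extend this spot check to the whole test set. The snippet below is an illustrative addition (not part of the original run) that uses only what we have already seen: predict_energy and the properties dictionary, reading the atom count off the forces array (one row per atom).

import torch

per_atom_errors = []
for graph in test_graphs:
    n_atoms = graph.properties["forces"].shape[0]  # one row per atom
    error = best_model.predict_energy(graph) - graph.properties["energy"]
    per_atom_errors.append(torch.abs(error) / n_atoms)

print(f"Test-set per-atom energy MAE: {torch.stack(per_atom_errors).mean().item():.3f} eV/atom")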

graph-pes provides a few utility functions for visualising model performance:

[11]:
import matplotlib.pyplot as plt

from graph_pes.atomic_graph import divide_per_atom
from graph_pes.utils.analysis import parity_plot

%config InlineBackend.figure_format = 'retina'

parity_plot(
    best_model,
    test_graphs,
    property="energy",
    transform=divide_per_atom,
    units="eV / atom",
    lw=0,
    s=12,
    color="crimson",
)
plt.xlim(-158.5, -155)
plt.ylim(-158.5, -155);
../_images/quickstart_quickstart_24_0.png
[12]:
parity_plot(
    best_model,
    test_graphs,
    property="forces",
    units="eV / Å",
    lw=0,
    s=2,
    alpha=0.5,
    color="crimson",
)
../_images/quickstart_quickstart_25_0.png
[13]:
from graph_pes.utils.analysis import dimer_curve

dimer_curve(best_model, system="CC", units="eV", rmin=0.7, rmax=4.0);
../_images/quickstart_quickstart_26_0.png

Fine-tuning

Let’s now take the model we trained above, and fine-tune it on the C-GAP-20U dataset.

[14]:
import ase.io
from load_atoms import load_dataset

structures = load_dataset("C-GAP-20U")
train, val, test = structures.random_split([0.8, 0.1, 0.1])

ase.io.write("train-cgap20u.xyz", train)
ase.io.write("val-cgap20u.xyz", val)
ase.io.write("test-cgap20u.xyz", test)

We can see that the C-GAP-20U labels clearly have a different arbitrary offset from the C-GAP-17 labels.

[15]:
cgap20_test_graphs = [
    AtomicGraph.from_ase(structure, cutoff=3.7) for structure in test
]

parity_plot(
    best_model,
    cgap20_test_graphs,
    property="energy",
    transform=divide_per_atom,
    units="eV / atom",
)
../_images/quickstart_quickstart_30_0.png

In fact, the energy labels on C-GAP-20U are formation energies, and hence the offset we used above is no longer necessary:

[16]:
from graph_pes.models import AdditionModel

assert isinstance(best_model, AdditionModel)
underlying_nequip = best_model["many-body"]
type(underlying_nequip)
[16]:
graph_pes.models.e3nn.nequip.NequIP
[17]:
parity_plot(
    underlying_nequip,
    cgap20_test_graphs,
    property="energy",
    transform=divide_per_atom,
    units="eV / atom",
    lw=0,
    s=12,
    color="crimson",
)
../_images/quickstart_quickstart_33_0.png

Here we can see that the pre-trained NequIP model is over-predicting the formation energies: this is to be expected as the DFT codes used to label C-GAP-17 and C-GAP-20U are different.
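
To put a rough number on this over-prediction, we can average the signed per-atom energy error over the C-GAP-20U test graphs. The mean_per_atom_offset helper below is our own illustrative addition (it is not part of graph-pes), and it reuses the original ase.Atoms objects for the atom counts:

def mean_per_atom_offset(model, graphs, structures):
    """Illustrative helper: average signed (prediction - label) energy error per atom."""
    diffs = [
        (model.predict_energy(graph) - graph.properties["energy"]).item() / len(atoms)
        for graph, atoms in zip(graphs, structures)
    ]
    return sum(diffs) / len(diffs)

print(f"{mean_per_atom_offset(underlying_nequip, cgap20_test_graphs, test):+.3f} eV/atom")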

Training run

To fine-tune a GraphPESModel, we can use the same graph-pes-train command, but with a modified config file in which we explicitly load the model we want to fine-tune, using either load_model or load_model_component.

finetune-cgap20u.yaml
general:
    run_id: finetune-cgap20u
    progress: logged

# finetune a pre-trained NequIP model ...
model:
    +load_model_component:
        path: <insert path to model>
        key: many-body

# ... on structures from local files ...
data:
    train:
        +file_dataset:
            # take the first 1280 structures from train-cgap20u.xyz
            path: train-cgap20u.xyz
            cutoff: 3.7
            n: 1280
            shuffle: false
    valid:
        +file_dataset:
            # use all structures from val-cgap20u.xyz
            path: val-cgap20u.xyz
            cutoff: 3.7

# ... on both energy and forces ...
loss:
    - +PerAtomEnergyLoss()
    - +PropertyLoss: { property: forces, metric: RMSE }

# ... with the following settings ...
fitting:
    trainer_kwargs:
        max_epochs: 150
        accelerator: auto

    pre_fit_model: false

    optimizer:
        +Optimizer:
            name: AdamW
            lr: 0.003

    scheduler:
        +LRScheduler:
            name: ReduceLROnPlateau
            factor: 0.5
            patience: 10

    loader_kwargs:
        batch_size: 64

# ... and log to Weights & Biases
wandb:
    project: graph-pes-quickstart
[18]:
%%bash

if [ ! -f finetune-cgap20u.yaml ]; then
    wget https://tinyurl.com/graph-pes-finetune-cgap20u -O finetune-cgap20u.yaml
fi
[19]:
!graph-pes-train finetune-cgap20u.yaml model/+load_model_component/path=graph-pes-results/quickstart-cgap17/model.pt  # noqa: E501
[graph-pes INFO]: Started `graph-pes-train` at 2024-12-05 08:19:40.426
[graph-pes INFO]: Successfully parsed config.
[graph-pes INFO]: ID for this training run: finetune-cgap20u
[graph-pes INFO]:
Output for this training run can be found at:
   └─ graph-pes-results/finetune-cgap20u
      ├─ logs/rank-0.log    # find a verbose log here
      ├─ model.pt           # the best model
      ├─ lammps_model.pt    # the best model deployed to LAMMPS
      └─ train-config.yaml  # the complete config used for this run

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: jla-gardner. Use `wandb login --relogin` to force relogin
wandb: Run data is saved locally in graph-pes-results/wandb/run-20241205_081944-finetune-cgap20u
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run finetune-cgap20u
wandb: ⭐️ View project at https://wandb.ai/jla-gardner/graph-pes-quickstart
wandb: 🚀 View run at https://wandb.ai/jla-gardner/graph-pes-quickstart/runs/finetune-cgap20u
[graph-pes INFO]: Logging to graph-pes-results/finetune-cgap20u/logs/rank-0.log
[graph-pes INFO]: Preparing data
[graph-pes INFO]: Caching neighbour lists for 1280 structures with cutoff 3.7
[graph-pes INFO]: Caching neighbour lists for 608 structures with cutoff 3.7
[graph-pes INFO]: Setting up datasets
[graph-pes INFO]: Number of learnable params : 4,233
[graph-pes INFO]: Sanity checking the model...
[graph-pes INFO]: Starting fit...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
                            valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   timer/its_per_s   timer/its_per_s
   epoch      time   per_atom_energy_rmse     forces_rmse     energy_rmse     stress_rmse     virial_rmse             train             valid
       5      12.0                0.07865         0.68254         3.96113         0.03609        17.63322          23.80952          37.49900
      10      22.7                0.07496         0.67690         3.82261         0.03304        17.34807          19.23077          36.97446
      15      33.2                0.08033         0.67894         4.78717         0.03800        19.28491          17.54386          37.11675
      20      44.0                0.07452         0.67271         3.73615         0.03089        14.90855          19.60784          37.52057
      25      54.8                0.08204         0.67367         5.36235         0.03168        15.47247          19.23077          37.50641
      30      65.3                0.07406         0.67675         4.30161         0.02724        12.27123          20.00000          37.43886
      35      75.5                0.07156         0.67334         3.95478         0.02931        15.73284          18.18182          37.22427
      40      86.4                0.07489         0.67066         3.60348         0.03342        18.05131          19.23077          37.58847
      45      96.6                0.07079         0.67161         3.41603         0.03000        14.41076          18.18182          37.63869
      50     107.1                0.07305         0.67295         3.56000         0.02818        13.76435          19.60784          37.49900
      55     117.3                0.07023         0.67409         3.76528         0.02724        13.51196          18.18182          37.12831
      60     127.8                0.07088         0.67273         4.01937         0.02963        14.29006          22.22222          37.25920
      65     138.3                0.06981         0.67033         3.71973         0.02761        12.70107          18.86792          37.24326
      70     148.5                0.07237         0.67212         3.86753         0.03468        18.37819          19.23077          36.85359
      75     159.0                0.07002         0.67110         3.79805         0.02902        13.88428          18.51852          36.54050
      80     169.6                0.06927         0.67006         3.40840         0.02902        13.61990          18.86792          36.96290
      85     179.7                0.07055         0.67000         3.48576         0.02826        14.21006          19.60784          37.37414
      90     189.9                0.07125         0.66895         3.87585         0.03137        16.16199          18.86792          37.37414
      95     201.1                0.07019         0.66909         3.95231         0.02830        13.82447          19.23077          37.37414
     100     211.2                0.07089         0.66866         4.09286         0.02903        14.54211          19.60784          37.13604
     105     221.8                0.06771         0.66860         3.38274         0.02796        13.15577          19.23077          37.28501
     110     232.0                0.07259         0.66760         4.15963         0.03054        15.05223          18.51852          37.11675
     115     242.2                0.06952         0.66778         3.75120         0.02948        14.06543          17.54386          37.01168
     120     252.8                0.06805         0.66722         3.30765         0.02783        13.14172          19.60784          37.25009
     125     264.0                0.06787         0.66716         3.33523         0.02702        12.45630          17.54386          37.03919
     130     274.9                0.06773         0.66712         3.33990         0.02738        12.71667          17.85714          37.01623
     135     285.1                0.06798         0.66733         3.37465         0.02904        13.94043          18.18182          36.27171
     140     295.7                0.06882         0.66725         3.70164         0.02855        13.71589          20.00000          37.25920
     145     305.8                0.06764         0.66698         3.48667         0.02762        13.05017          19.60784          36.99604
     150     316.6                0.06777         0.66686         3.49672         0.02758        12.96325          18.51852          37.25920
`Trainer.fit` stopped: `max_epochs=150` reached.
[graph-pes INFO]: Loading best weights from "graph-pes-results/finetune-cgap20u/checkpoints/best.ckpt"
[graph-pes INFO]: Training complete. Awaiting final Lightning and W&B shutdown...
wandb: \ 0.041 MB of 0.041 MB uploaded
wandb: Run history:
wandb:                                    epoch ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:                   lr-AdamW/non-decayable ███████████████▄▄▄▄▄▄▄▄▄▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
wandb:                          lr-AdamW/normal ███████████████▄▄▄▄▄▄▄▄▄▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
wandb:                   n_learnable_parameters ▁
wandb:                             n_parameters ▁
wandb:                    timer/its_per_s/train ▅█▃▂▃▄▃▃▃▂▃▂▁▄▂▂▃▃▃▂▂▃▄▂▄▃▄▂▃▂▁▃▄▁▂▃▂▄▄▂
wandb:                    timer/its_per_s/valid ▄▁▅▅▅▇▄▅▇▆▅▅▄▇▃▆▄▆▂▅▇▃▄▄▃▃▄▄▅▃▄▆█▆▃▇█▇▆▅
wandb:             timer/step_duration_ms/train ▃▁▅▆▆▅▅▆▅▇▅▇█▅▇▇▆▆▅▆▇▆▅▇▅▅▅▇▆▆█▅▅█▇▅▆▅▅▆
wandb:             timer/step_duration_ms/valid ▆█▃▄▄▂▆▄▂▃▄▃▆▂▆▃▄▃▇▄▂▇▄▄▇▇▆▄▃▆▆▃▁▃▇▂▁▂▃▄
wandb:          train/loss/forces_rmse_weighted ▅▃▆▇▆▂▄█▇█▄▅▁▃▅▅▇▅▂▃▅▄█▁▄▂▅▄▃▃▄▁▄▄▃▆▃▆▅▄
wandb: train/loss/per_atom_energy_rmse_weighted ▄▄▇▂█▄▇▅▇▄▄▇▅▅▅▂▃▃▅▅▃▂▆▄▃▄▇▃▃▂▃▃▄▁▅▂▃▆▆▄
wandb:                         train/loss/total ▅▃▇▆▇▃▅█▇▇▄▆▂▃▅▄▆▄▂▃▅▃█▁▄▂▆▄▃▃▄▁▄▃▄▅▃▇▅▄
wandb:                train/metrics/forces_rmse ▅▃▆▇▆▂▄█▇█▄▅▁▃▅▅▇▅▂▃▅▄█▁▄▂▅▄▃▃▄▁▄▄▃▆▃▆▅▄
wandb:       train/metrics/per_atom_energy_rmse ▄▄▇▂█▄▇▅▇▄▄▇▅▅▅▂▃▃▅▅▃▂▆▄▃▄▇▃▃▂▃▃▄▁▅▂▃▆▆▄
wandb:                      trainer/global_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
wandb:          valid/loss/forces_rmse_weighted █▅▃▄▃▂▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
wandb: valid/loss/per_atom_energy_rmse_weighted █▄▄▅▄▃▃▄▅▂█▂▃▃▂▂▃▂▂▂▂▁▁▂▂▁▂▁▁▁▁▁▁▁▁▂▁▂▁▁
wandb:                         valid/loss/total █▅▄▄▄▃▃▃▄▂▅▃▂▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                valid/metrics/energy_rmse ▅▃▃▆▄▂▄▃▅▂█▂▃▂▂▁▄▂▂▁▂▁▂▂▃▂▃▂▁▂▂▂▁▁▁▂▂▂▂▂
wandb:                valid/metrics/forces_rmse █▅▃▄▃▂▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
wandb:       valid/metrics/per_atom_energy_rmse █▄▄▅▄▃▃▄▅▂█▂▃▃▂▂▃▂▂▂▂▁▁▂▂▁▂▁▁▁▁▁▁▁▁▂▁▂▁▁
wandb:                valid/metrics/stress_rmse █▄▄▅▃▄▁▃▅▄▇▃▃▃▆▂▂▄▅▃▃▃▃▄▃▃▃▄▃▃▃▂▃▂▃▃▃▃▂▃
wandb:                valid/metrics/virial_rmse █▄▄▅▂▄▁▂▅▄▆▃▃▃▆▃▁▄▅▂▂▃▃▃▄▃▃▅▄▃▃▂▂▂▃▃▃▃▂▃
wandb:
wandb: Run summary:
wandb:                                    epoch 149
wandb:                   lr-AdamW/non-decayable 0.00019
wandb:                          lr-AdamW/normal 0.00019
wandb:                   n_learnable_parameters 4233
wandb:                             n_parameters 4233
wandb:                    timer/its_per_s/train 18.51852
wandb:                    timer/its_per_s/valid 37.2592
wandb:             timer/step_duration_ms/train 54.0
wandb:             timer/step_duration_ms/valid 27.4
wandb:          train/loss/forces_rmse_weighted 0.68451
wandb: train/loss/per_atom_energy_rmse_weighted 0.06854
wandb:                         train/loss/total 0.75305
wandb:                train/metrics/forces_rmse 0.68451
wandb:       train/metrics/per_atom_energy_rmse 0.06854
wandb:                      trainer/global_step 2999
wandb:          valid/loss/forces_rmse_weighted 0.66686
wandb: valid/loss/per_atom_energy_rmse_weighted 0.06777
wandb:                         valid/loss/total 0.73462
wandb:                valid/metrics/energy_rmse 3.49672
wandb:                valid/metrics/forces_rmse 0.66686
wandb:       valid/metrics/per_atom_energy_rmse 0.06777
wandb:                valid/metrics/stress_rmse 0.02758
wandb:                valid/metrics/virial_rmse 12.96325
wandb:
wandb: 🚀 View run finetune-cgap20u at: https://wandb.ai/jla-gardner/graph-pes-quickstart/runs/finetune-cgap20u
wandb: ⭐️ View project at: https://wandb.ai/jla-gardner/graph-pes-quickstart
wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: graph-pes-results/wandb/run-20241205_081944-finetune-cgap20u/logs

This fine-tuning process has aligned the model predictions with the C-GAP-20U formation energies:

[21]:
fine_tuned_model = load_model("graph-pes-results/finetune-cgap20u/model.pt")
parity_plot(
    fine_tuned_model,
    cgap20_test_graphs,
    property="energy",
    transform=divide_per_atom,
    units="eV / atom",
    lw=0,
    s=12,
    color="crimson",
)
plt.xlim(-8, -5)
plt.ylim(-8, -5);
../_images/quickstart_quickstart_39_0.png
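
As a final, illustrative check (again, not part of the original notebook), we can reuse the mean_per_atom_offset helper from above to see how far the fine-tuned model's mean per-atom energy offset has shrunk on the same test graphs:

print(f"{mean_per_atom_offset(fine_tuned_model, cgap20_test_graphs, test):+.3f} eV/atom")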