Fine-tuning

You can open this documentation as a Google Colab notebook to follow along interactively.

graph-pes provides access to several foundation models (with more being added as they are released) for you to use as-is, or to fine-tune on your own dataset. See the documentation for a growing list of the interfaces we provide.

Below, we:

  • install the relevant packages to use the MACE-MP-0 and Orb families of models

  • fine-tune these models on structures labelled at a different level of theory

Since graph-pes provides a unified interface to many different foundation models, swapping between them is incredibly simple!

Installation

To make use of the MACE-MP-0 and Orb families of models, we need to install mace-torch and orb-models alongside graph-pes:

[ ]:
!pip install graph-pes mace-torch orb-models

Loading and using foundation models

Let’s start by loading the MACE-MP-0-small model and using it to make predictions on some SiO\(_2\) structures. To do this, we’ll use load-atoms to download the SiO2-GAP-22 dataset.

[1]:
from load_atoms import load_dataset

dataset = load_dataset("SiO2-GAP-22")
dataset
[1]:
SiO2-GAP-22:
    structures: 3,074
    atoms: 268,118
    species:
        O: 66.47%
        Si: 33.53%
    properties:
        per atom: (forces)
        per structure: (config_type, energy, stress, virial)
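
Each entry in this dataset is an ASE Atoms object, so (following the usual ASE conventions that load-atoms adopts) the per-structure labels should be accessible via .info and the per-atom labels via .arrays — a quick, illustrative check:

[ ]:
first = dataset[0]
print(first.info["energy"])           # total energy label (eV)
print(first.arrays["forces"].shape)   # per-atom force labels (eV/Å)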

Let’s select the first structure from this dataset and visualise it:

[2]:
from load_atoms import view

structure = dataset[0]
view(structure, show_bonds=True)
[2]:

Now that we have a structure, let’s use the graph_pes.interfaces.mace_mp function to load the MACE-MP-0-small model and make some predictions:

[3]:
import torch
from graph_pes.interfaces import mace_mp

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

mp0 = mace_mp("small").eval().to(device)
sum(p.numel() for p in mp0.parameters())
Using device: cpu
[3]:
3847696

This model object is a GraphPESModel instance, and so can be used throughout the rest of the graph-pes ecosystem. For instance, we can inspect the dimer curves for this model:

[4]:
from graph_pes.utils.analysis import dimer_curve
import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina'

for dimer, c in [("Si-Si", "crimson"), ("O-O", "green"), ("Si-O", "royalblue")]:
    dimer_curve(
        mp0, dimer.replace("-", ""), units="eV", rmin=0.7, rmax=4.0, label=dimer, c=c
    )
plt.legend();
../_images/quickstart_fine-tuning_10_0.png

… and also use it as an ASE calculator to generate force predictions:

[5]:
from ase import Atoms
from graph_pes.graph_pes_model import GraphPESModel


def check_forces(model: GraphPESModel, structure: Atoms):
    # make an ASE calculator
    calculator = model.ase_calculator()

    force_predictions = calculator.get_forces(structure)
    force_labels = structure.arrays["forces"]

    plt.figure(figsize=(2.5, 2.5))
    plt.scatter(force_labels.flatten(), force_predictions.flatten(), s=6)
    plt.axline((0, 0), slope=1, color="black", ls="--", lw=1)
    plt.gca().set_aspect("equal")
    plt.xlabel("True force / eV/Å")
    plt.ylabel("Predicted force / eV/Å");


check_forces(mp0, dataset[0])
../_images/quickstart_fine-tuning_12_0.png

To demonstrate the architecture-agnostic nature of graph-pes, we can swap the MACE-MP-0-small model for Orb’s orb-d3-xs-v2 model very easily:

[6]:
from graph_pes.interfaces import orb_model

orbv2_xs = orb_model("orb-d3-xs-v2").eval().to(device)
sum(p.numel() for p in orbv2_xs.parameters())
[6]:
9443981
[7]:
check_forces(orbv2_xs, dataset[0])
../_images/quickstart_fine-tuning_15_0.png

It is explicitly not the point here to compare these two models. Instead, we are just demonstrating that these two models, created by different teams and originating from different packages, can be treated identically within the confines of graph-pes.

We now know how to use these foundation models out of the box - they behave just like any other graph-pes model!
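
For example, here is a minimal sketch of how one might relax a copy of the structure loaded above, using the model’s ASE calculator together with ASE’s BFGS optimizer (the convergence settings are just illustrative):

[ ]:
from ase.optimize import BFGS

relaxed = structure.copy()
relaxed.calc = mp0.ase_calculator()  # any GraphPESModel provides this
BFGS(relaxed).run(fmax=0.05, steps=20)
relaxed.get_potential_energy()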

Fine-tuning

Often we want to fine-tune foundation models on specific datasets to further improve their accuracy. graph-pes makes this easy!

We saw above that the force predictions from both foundation models were already in close agreement with the DFT labels of the SiO2-GAP-22 dataset. This is despite the fact that these foundation models were trained on data labelled with different functionals.

However, different functionals tend to have different “reference” energies: the energy of an isolated atom of element \(X\) takes some non-zero value, \(\varepsilon_X\), and in general \(\varepsilon_X^{(a)} \neq \varepsilon_X^{(b)}\) for two different functionals (a) and (b).
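
Concretely, for a structure containing \(N_X\) atoms of each element \(X\), the total energies predicted at two levels of theory therefore differ (to leading order) by a composition-dependent shift:

\[E^{(a)} - E^{(b)} \approx \sum_X N_X \left( \varepsilon_X^{(a)} - \varepsilon_X^{(b)} \right)\]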

To demonstrate this, we create an energy parity plot below:

[8]:
from graph_pes.utils.analysis import parity_plot

parity_plot(
    mp0,
    dataset[:20],
    property="energy_per_atom",
    units="eV/atom",
    c="crimson",
    label="MP0"
)
parity_plot(
    orbv2_xs,
    dataset[:20],
    property="energy_per_atom",
    units="eV/atom",
    c="royalblue",
    label="Orb"
)
plt.xlabel("Ground truth SCAN Energy (eV/atom)")
plt.ylabel("Predicted Energy (eV/atom)")
plt.legend(frameon=True);
../_images/quickstart_fine-tuning_17_0.png

graph-pes provides automated functionality to correct for exactly this kind of difference:

[9]:
from graph_pes.utils.shift_and_scale import add_auto_offset

adjusted_mp0 = add_auto_offset(mp0, dataset[:20])
adjusted_orb = add_auto_offset(orbv2_xs, dataset[:20])
parity_plot(
    adjusted_mp0,
    dataset[:20],
    property="energy_per_atom",
    units="eV/atom",
    c="crimson",
    label="MP0"
)
parity_plot(
    adjusted_orb,
    dataset[:20],
    property="energy_per_atom",
    units="eV/atom",
    c="royalblue",
    label="Orb"
)
plt.xlabel("Ground truth SCAN Energy (eV/atom)")
plt.ylabel("Predicted Energy (eV/atom)")
plt.legend(frameon=True);

[graph-pes INFO]:
Attempting to automatically detect the offset energy for each element.
We do this by first generating predictions for each training structure
(up to `config.fitting.max_n_pre_fit` if specified).
This is a slow process! If you already know the reference energies (or the
difference in reference energies if you are fine-tuning an existing model to a
different level of theory),
we recommend setting `config.fitting.auto_fit_reference_energies` to `False`
and manually specifying a `LearnableOffset` component of your model.

See the "Fine-tuning" tutorial in the docs for more information:
https://jla-gardner.github.io/graph-pes/quickstart/fine-tuning.html

[graph-pes WARNING]:
We are attempting to guess the mean per-element
contribution for a per-structure quantity (usually
the total energy).

However, the composition of the training set is such that
no unique solution is possible.

This is probably because you are training on structures
all with the same composition (e.g. all structures are
of the form n H2O). Consider explicitly setting the
per-element contributions if you know them, or
including a variety of structures of different
compositions in the training set.

[graph-pes INFO]:
Attempting to automatically detect the offset energy for each element.
We do this by first generating predictions for each training structure
(up to `config.fitting.max_n_pre_fit` if specified).
This is a slow process! If you already know the reference energies (or the
difference in reference energies if you are fine-tuning an existing model to a
different level of theory),
we recommend setting `config.fitting.auto_fit_reference_energies` to `False`
and manually specifying a `LearnableOffset` component of your model.

See the "Fine-tuning" tutorial in the docs for more information:
https://jla-gardner.github.io/graph-pes/quickstart/fine-tuning.html

[graph-pes WARNING]:
We are attempting to guess the mean per-element
contribution for a per-structure quantity (usually
the total energy).

However, the composition of the training set is such that
no unique solution is possible.

This is probably because you are training on structures
all with the same composition (e.g. all structures are
of the form n H2O). Consider explicitly setting the
per-element contributions if you know them, or
including a variety of structures of different
compositions in the training set.

../_images/quickstart_fine-tuning_19_1.png

Note the warning above! If we pass a set of structures that all have exactly the same composition (here \(n\cdot \mathrm{SiO_2}\)), it is not possible to decouple the difference in \(\varepsilon_O\) from that in \(\varepsilon_{Si}\). This means that if you use this adjusted model to predict energies for a structure with a different composition, e.g. \(\mathrm{Si_5O_9}\), the predictions may differ significantly from the ground truth.
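
To see why, here is a toy illustration (plain numpy, not graph-pes functionality): if every structure has composition \(n\cdot \mathrm{SiO_2}\), the composition matrix is rank-deficient, so no unique pair \((\varepsilon_{Si}, \varepsilon_O)\) can be recovered from the total energies alone.

[ ]:
import numpy as np

# rows = structures, columns = (number of Si atoms, number of O atoms)
compositions = np.array([[1, 2], [2, 4], [3, 6]])  # all of the form n * SiO2

# every row is a multiple of (1, 2), so the linear system the offset fitting
# has to solve has infinitely many solutions
np.linalg.matrix_rank(compositions)  # -> 1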

With that in mind, let’s get down to fine-tuning a model. As in our other fine-tuning guide, we start by selecting some structures to fine-tune on:

[10]:
train, valid = dataset[:20], dataset[20:25]
train.write("train.xyz")
valid.write("valid.xyz")

Next, we define a config file to describe what and how we want to fine-tune. Two vital features here are:

  1. the fitting/auto_fit_reference_energies=true flag - this performs the automated offset fitting described above before any fine-tuning takes place

  2. using a relatively small learning rate

[20]:
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
cutoff = mp0.cutoff.item()

training_config = f"""
data:
    train:
        +file_dataset:
            path: train.xyz
            cutoff: {cutoff}
    valid:
        +file_dataset:
            path: valid.xyz
            cutoff: {cutoff}

loss:
    - +PerAtomEnergyLoss()
    - +ForceRMSE()

fitting:
    trainer_kwargs:
        max_epochs: 20
        accelerator: {accelerator}

    optimizer:
        +Optimizer:
            name: Adam
            lr: 0.0001

    auto_fit_reference_energies: true

wandb: null
general:
    progress: logged
    run_id: mp0-fine-tune
"""

mp0_config = """
model:
    +mace_mp:
        model: small

general:
    run_id: mp0-fine-tune
"""

with open("fine-tune.yaml", "w") as f:
    f.write(training_config)

with open("mp0.yaml", "w") as f:
    f.write(mp0_config)

Now we can fine-tune. Note the use of config file stacking here - this lets us split the configuration across several files, which is often useful (see below).
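
Conceptually, stacking merges the supplied files into a single nested config. The sketch below (plain Python + PyYAML, purely illustrative - the exact merge semantics used by graph-pes may differ, so see the docs for details) shows the idea:

[ ]:
import yaml

def deep_merge(base: dict, update: dict) -> dict:
    """Recursively merge `update` into `base`, with `update` taking precedence."""
    merged = dict(base)
    for key, value in update.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

with open("fine-tune.yaml") as f:
    base = yaml.safe_load(f)
with open("mp0.yaml") as f:
    override = yaml.safe_load(f)

deep_merge(base, override)["general"]
# -> {'progress': 'logged', 'run_id': 'mp0-fine-tune'}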

[12]:
!graph-pes-train fine-tune.yaml mp0.yaml
[graph-pes INFO]: Started `graph-pes-train` at 2025-04-11 16:52:15.879
[graph-pes INFO]: Successfully parsed config.
[graph-pes INFO]: Logging to graph-pes-results/mp0-fine-tune/rank-0.log
[graph-pes INFO]: ID for this training run: mp0-fine-tune
[graph-pes INFO]:
Output for this training run can be found at:
   └─ graph-pes-results/mp0-fine-tune
      ├─ rank-0.log         # find a verbose log here
      ├─ model.pt           # the best model (according to valid/loss/total)
      ├─ lammps_model.pt    # the best model deployed to LAMMPS
      ├─ train-config.yaml  # the complete config used for this run
      └─ summary.yaml       # the summary of the training run

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[graph-pes INFO]: Preparing data
[graph-pes INFO]: Setting up datasets
[graph-pes INFO]: Pre-fitting the model on 20 samples
[graph-pes INFO]:
Attempting to automatically detect the offset energy for each element.
We do this by first generating predictions for each training structure
(up to `config.fitting.max_n_pre_fit` if specified).
This is a slow process! If you already know the reference energies (or the
difference in reference energies if you are fine-tuning an existing model to a
different level of theory),
we recommend setting `config.fitting.auto_fit_reference_energies` to `False`
and manually specifying a `LearnableOffset` component of your model.

See the "Fine-tuning" tutorial in the docs for more information:
https://jla-gardner.github.io/graph-pes/quickstart/fine-tuning.html

[graph-pes WARNING]:
We are attempting to guess the mean per-element
contribution for a per-structure quantity (usually
the total energy).

However, the composition of the training set is such that
no unique solution is possible.

This is probably because you are training on structures
all with the same composition (e.g. all structures are
of the form n H2O). Consider explicitly setting the
per-element contributions if you know them, or
including a variety of structures of different
compositions in the training set.

[graph-pes INFO]:
Number of learnable params:
    base (MACEWrapper)           : 3,847,696
    auto_offset (LearnableOffset): 2

[graph-pes INFO]: Sanity checking the model...
[graph-pes INFO]: Starting fit...
                            valid/metrics         valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   timer/its_per_s   timer/its_per_s
   epoch      time   per_atom_energy_rmse   per_atom_energy_mae     energy_rmse      energy_mae     forces_rmse      forces_mae     stress_rmse      stress_mae     virial_rmse      virial_mae             train             valid
       1       7.8                0.00716               0.00673         0.77302         0.72710         0.10150         0.07997         0.00925         0.00584       128.09291       102.27287           0.80064           4.89904
       2      16.1                0.00905               0.00825         0.97704         0.89060         0.05779         0.04414         0.00524         0.00346       128.24873       103.04964           0.70522           4.23109
       3      24.6                0.00174               0.00132         0.18810         0.14233         0.03947         0.02871         0.00277         0.00183       128.95222       103.92886           0.64683           4.50949
       4      32.7                0.00432               0.00419         0.46684         0.45242         0.03426         0.02702         0.00179         0.00117       130.01169       104.82194           0.71276           4.59975
       5      40.1                0.00162               0.00142         0.17489         0.15320         0.03540         0.02763         0.00139         0.00091       130.51797       105.21088           0.78493           4.31182
       6      48.0                0.00332               0.00312         0.35861         0.33645         0.02760         0.02140         0.00186         0.00125       129.84111       104.61426           0.67797           4.33375
       7      56.0                0.00225               0.00195         0.24346         0.21077         0.02652         0.02051         0.00180         0.00125       129.79610       104.55513           0.67476           4.14495
       8      64.0                0.00140               0.00090         0.15127         0.09751         0.02467         0.01905         0.00183         0.00127       129.90742       104.63938           0.75873           3.79053
       9      71.9                0.00206               0.00177         0.22223         0.19067         0.02435         0.01856         0.00168         0.00113       130.40732       105.05786           0.75075           3.73735
      10      79.7                0.00240               0.00231         0.25874         0.24932         0.02326         0.01803         0.00110         0.00075       130.71355       105.33549           0.68776           4.25472
      11      87.3                0.00208               0.00199         0.22486         0.21538         0.02311         0.01779         0.00106         0.00070       130.88290       105.47400           0.76104           4.32935
      12      95.0                0.00122               0.00101         0.13136         0.10957         0.02261         0.01729         0.00115         0.00075       130.93184       105.50145           0.78802           4.60993
      13     101.1                0.00068               0.00040         0.07350         0.04363         0.02162         0.01665         0.00108         0.00072       130.93443       105.49496           1.02987           5.04530
      14     107.4                0.00062               0.00047         0.06713         0.05095         0.02094         0.01612         0.00106         0.00070       131.02118       105.55566           0.76511           6.38889
      15     114.5                0.00092               0.00078         0.09890         0.08396         0.02081         0.01601         0.00087         0.00058       131.20523       105.70125           0.77821           6.11329
      16     123.3                0.00220               0.00217         0.23743         0.23433         0.02170         0.01662         0.00069         0.00045       131.44841       105.88866           0.65963           3.55058
      17     132.7                0.00166               0.00158         0.17911         0.17083         0.01936         0.01502         0.00088         0.00059       131.26625       105.73106           0.54230           3.37123
      18     140.1                0.00130               0.00126         0.14066         0.13652         0.01936         0.01496         0.00062         0.00042       131.49023       105.90979           0.76336           4.44092
      19     147.6                0.00111               0.00109         0.12030         0.11782         0.02065         0.01572         0.00047         0.00031       131.74437       106.11167           0.75529           4.32136
      20     155.2                0.00304               0.00303         0.32831         0.32720         0.01846         0.01435         0.00045         0.00033       131.59532       105.98393           0.80128           5.05800
`Trainer.fit` stopped: `max_epochs=20` reached.
[graph-pes INFO]: Loading best weights from "/Users/john/projects/graph-pes/docs/source/quickstart/graph-pes-results/mp0-fine-tune/checkpoints/best.ckpt"
[graph-pes INFO]: Training complete.
[graph-pes INFO]: Testing best model...
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃           Test metric                       DataLoader 0           ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/train/energy_mae              0.08073119819164276        │
│      test/train/energy_rmse             0.09959671646356583        │
│      test/train/forces_mae              0.014654574915766716       │
│      test/train/forces_rmse              0.0190340057015419        │
│ test/train/forces_rmse_batchwise        0.018963992595672607       │
│  test/train/per_atom_energy_mae        0.0007476329919882119       │
│ test/train/per_atom_energy_rmse        0.0009222375811077654       │
│      test/train/stress_mae             0.00044360198080539703      │
│      test/train/stress_rmse            0.0006201165379025042       │
│      test/train/virial_mae               97.18195343017578         │
│      test/train/virial_rmse              127.16856384277344        │
│      test/valid/energy_mae                                         │
│      test/valid/energy_rmse                                        │
│      test/valid/forces_mae                                         │
│      test/valid/forces_rmse                                        │
│ test/valid/forces_rmse_batchwise                                   │
│  test/valid/per_atom_energy_mae                                    │
│ test/valid/per_atom_energy_rmse                                    │
│      test/valid/stress_mae                                         │
│      test/valid/stress_rmse                                        │
│      test/valid/virial_mae                                         │
│      test/valid/virial_rmse                                        │
└──────────────────────────────────┴──────────────────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃           Test metric                       DataLoader 1           ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/train/energy_mae                                         │
│      test/train/energy_rmse                                        │
│      test/train/forces_mae                                         │
│      test/train/forces_rmse                                        │
│ test/train/forces_rmse_batchwise                                   │
│  test/train/per_atom_energy_mae                                    │
│ test/train/per_atom_energy_rmse                                    │
│      test/train/stress_mae                                         │
│      test/train/stress_rmse                                        │
│      test/train/virial_mae                                         │
│      test/train/virial_rmse                                        │
│      test/valid/energy_mae              0.13652344048023224        │
│      test/valid/energy_rmse             0.10905405879020691        │
│      test/valid/forces_mae              0.014955026097595692       │
│      test/valid/forces_rmse             0.01909947581589222        │
│ test/valid/forces_rmse_batchwise        0.019314901903271675       │
│  test/valid/per_atom_energy_mae         0.001264381455257535       │
│ test/valid/per_atom_energy_rmse        0.0010098782368004322       │
│      test/valid/stress_mae             0.00042308951378799975      │
│      test/valid/stress_rmse            0.0006206676480360329       │
│      test/valid/virial_mae               105.9097900390625         │
│      test/valid/virial_rmse              128.0445556640625         │
└──────────────────────────────────┴──────────────────────────────────┘
[graph-pes INFO]: Testing complete.
[graph-pes INFO]: Awaiting final Lightning and W&B shutdown...

Nice! Fine-tuning on just 20 structures has brought the validation error down significantly 😊 Of course, in real life one would train for longer, and perhaps make use of early stopping to prevent overfitting to such a small amount of data - see the docs for the relevant config options.

Let’s load our fine-tuned model and check that everything is working as expected:

[13]:
from graph_pes.models import load_model

test_set = dataset[50:100]

parity_plot(
    adjusted_mp0,
    test_set,
    property="energy_per_atom",
    units="eV/atom",
)
plt.title("Offset-adjusted base MP-0-small")
plt.show()

fine_tuned_mp0 = load_model("graph-pes-results/mp0-fine-tune/model.pt").eval()
parity_plot(
    fine_tuned_mp0,
    test_set,
    property="energy_per_atom",
    units="eV/atom",
)
plt.title("Fine-tuned MP-0-small")
../_images/quickstart_fine-tuning_27_0.png
[13]:
Text(0.5, 1.0, 'Fine-tuned MP-0-small')
../_images/quickstart_fine-tuning_27_2.png

Swapping out the MP-0 model for the Orb model is trivial! We just need to change the model definition in the config file. Here we’re being extra fancy and freezing all the model parameters except those in the read-out head (see graph_pes.models.freeze_all_except for details):
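
Conceptually, this kind of freezing just toggles requires_grad on each parameter according to a name pattern. The sketch below is purely illustrative (it is not the graph_pes.models.freeze_all_except implementation):

[ ]:
import re
import torch

def freeze_all_except_sketch(model: torch.nn.Module, pattern: str) -> torch.nn.Module:
    # turn off gradients for every parameter whose name doesn't match `pattern`
    for name, param in model.named_parameters():
        param.requires_grad_(re.match(pattern, name) is not None)
    return model

# e.g. leave only the read-out head of the Orb model trainable,
# mirroring the pattern used in the config below:
# freeze_all_except_sketch(orbv2_xs, r"_orb\.heads.*")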

[24]:
orb_config = """
model:
    +freeze_all_except:
        model:
            +orb_model:
                name: orb-d3-xs-v2
        pattern: _orb\\.heads.*

general:
    run_id: orb-fine-tune
"""

with open("orb.yaml", "w") as f:
    f.write(orb_config)
[25]:
!graph-pes-train fine-tune.yaml orb.yaml
/opt/miniconda3/envs/mace/lib/python3.11/site-packages/e3nn/o3/_wigner.py:10: UserWarning: Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.
  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))
[graph-pes INFO]: Started `graph-pes-train` at 2025-04-11 17:16:02.923
/opt/miniconda3/envs/mace/lib/python3.11/site-packages/orb_models/utils.py:30: UserWarning: Setting global torch default dtype to torch.float32.
  warnings.warn(f"Setting global torch default dtype to {torch_dtype}.")
[graph-pes INFO]: Successfully parsed config.
[graph-pes INFO]: Logging to graph-pes-results/orb-fine-tune/rank-0.log
[graph-pes INFO]: ID for this training run: orb-fine-tune
[graph-pes INFO]:
Output for this training run can be found at:
   └─ graph-pes-results/orb-fine-tune
      ├─ rank-0.log         # find a verbose log here
      ├─ model.pt           # the best model (according to valid/loss/total)
      ├─ lammps_model.pt    # the best model deployed to LAMMPS
      ├─ train-config.yaml  # the complete config used for this run
      └─ summary.yaml       # the summary of the training run

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[graph-pes INFO]: Preparing data
[graph-pes INFO]: Setting up datasets
[graph-pes INFO]: Pre-fitting the model on 20 samples
[graph-pes INFO]:
Attempting to automatically detect the offset energy for each element.
We do this by first generating predictions for each training structure
(up to `config.fitting.max_n_pre_fit` if specified).
This is a slow process! If you already know the reference energies (or the
difference in reference energies if you are fine-tuning an existing model to a
different level of theory),
we recommend setting `config.fitting.auto_fit_reference_energies` to `False`
and manually specifying a `LearnableOffset` component of your model.

See the "Fine-tuning" tutorial in the docs for more information:
https://jla-gardner.github.io/graph-pes/quickstart/fine-tuning.html

[graph-pes WARNING]:
We are attempting to guess the mean per-element
contribution for a per-structure quantity (usually
the total energy).

However, the composition of the training set is such that
no unique solution is possible.

This is probably because you are training on structures
all with the same composition (e.g. all structures are
of the form n H2O). Consider explicitly setting the
per-element contributions if you know them, or
including a variety of structures of different
compositions in the training set.

[graph-pes INFO]:
Number of learnable params:
    base (OrbWrapper)            : 200,064
    auto_offset (LearnableOffset): 2

[graph-pes INFO]: Sanity checking the model...
[graph-pes INFO]: Starting fit...
                            valid/metrics         valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   valid/metrics   timer/its_per_s   timer/its_per_s
   epoch      time   per_atom_energy_rmse   per_atom_energy_mae     energy_rmse      energy_mae     forces_rmse      forces_mae     stress_rmse      stress_mae     virial_rmse      virial_mae             train             valid
       1       1.8                0.01108               0.00818         1.19668         0.88313         0.16017         0.12256         0.02922         0.02326        76.28781        61.23480           3.80228           5.45729
       2       3.6                0.00770               0.00585         0.83182         0.63171         0.15506         0.11872         0.02922         0.02326        76.28781        61.23480           4.13223          11.88725
       3       5.3                0.00467               0.00438         0.50436         0.47307         0.15038         0.11539         0.02922         0.02326        76.28781        61.23480           3.86100          10.40802
       4       7.1                0.00369               0.00305         0.39796         0.32947         0.14623         0.11259         0.02922         0.02326        76.28781        61.23480           4.03226          10.95286
       5       8.9                0.00262               0.00232         0.28348         0.25066         0.14247         0.11010         0.02922         0.02326        76.28781        61.23480           3.04878          10.65872
       6      10.7                0.00185               0.00136         0.19947         0.14702         0.13893         0.10778         0.02922         0.02326        76.28781        61.23480           4.42478          11.21332
       7      12.5                0.00217               0.00192         0.23384         0.20730         0.13528         0.10521         0.02922         0.02326        76.28781        61.23480           3.90625          11.81803
       8      14.1                0.00266               0.00254         0.28749         0.27485         0.13179         0.10278         0.02922         0.02326        76.28781        61.23480           4.31034          10.62050
       9      16.0                0.00250               0.00214         0.27035         0.23108         0.12831         0.10037         0.02922         0.02326        76.28781        61.23480           3.80228           8.25962
      10      17.8                0.00269               0.00243         0.29018         0.26196         0.12483         0.09784         0.02922         0.02326        76.28781        61.23480           4.52489          10.93097
      11      19.5                0.00218               0.00198         0.23518         0.21355         0.12133         0.09523         0.02922         0.02326        76.28781        61.23480           4.40529          11.30330
      12      21.1                0.00259               0.00245         0.27953         0.26460         0.11804         0.09287         0.02922         0.02326        76.28781        61.23480           4.36681          11.35335
      13      22.8                0.00248               0.00226         0.26789         0.24412         0.11474         0.09049         0.02922         0.02326        76.28781        61.23480           4.08163          11.98733
      14      24.4                0.00215               0.00185         0.23196         0.19985         0.11158         0.08824         0.02922         0.02326        76.28781        61.23480           3.42466          11.44267
      15      26.0                0.00353               0.00302         0.38135         0.32583         0.10840         0.08593         0.02922         0.02326        76.28781        61.23480           3.41297          11.50141
      16      27.7                0.00368               0.00289         0.39733         0.31230         0.10509         0.08339         0.02922         0.02326        76.28781        61.23480           3.48432          11.62694
      17      29.4                0.00328               0.00277         0.35471         0.29954         0.10226         0.08139         0.02922         0.02326        76.28781        61.23480           4.42478           9.53071
      18      31.1                0.00258               0.00209         0.27859         0.22561         0.09923         0.07914         0.02922         0.02326        76.28781        61.23480           2.79330          11.77056
      19      32.8                0.00365               0.00313         0.39438         0.33806         0.09616         0.07678         0.02922         0.02326        76.28781        61.23480           2.94118          10.99205
      20      34.5                0.00269               0.00218         0.29059         0.23569         0.09309         0.07437         0.02922         0.02326        76.28781        61.23480           4.00000          10.98432
`Trainer.fit` stopped: `max_epochs=20` reached.
[graph-pes INFO]: Loading best weights from "/Users/john/projects/graph-pes/docs/source/quickstart/graph-pes-results/orb-fine-tune/checkpoints/best.ckpt"
[graph-pes INFO]: Training complete.
[graph-pes INFO]: Testing best model...
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃           Test metric                       DataLoader 0           ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/train/energy_mae              0.16972656548023224        │
│      test/train/energy_rmse             0.23350417613983154        │
│      test/train/forces_mae              0.07444514334201813        │
│      test/train/forces_rmse             0.09575104713439941        │
│ test/train/forces_rmse_batchwise        0.09573473781347275        │
│  test/train/per_atom_energy_mae        0.0015714168548583984       │
│ test/train/per_atom_energy_rmse         0.002161899348720908       │
│      test/train/stress_mae              0.02080981805920601        │
│      test/train/stress_rmse              0.0265880785882473        │
│      test/train/virial_mae               61.88399124145508         │
│      test/train/virial_rmse              92.37788391113281         │
│      test/valid/energy_mae                                         │
│      test/valid/energy_rmse                                        │
│      test/valid/forces_mae                                         │
│      test/valid/forces_rmse                                        │
│ test/valid/forces_rmse_batchwise                                   │
│  test/valid/per_atom_energy_mae                                    │
│ test/valid/per_atom_energy_rmse                                    │
│      test/valid/stress_mae                                         │
│      test/valid/stress_rmse                                        │
│      test/valid/virial_mae                                         │
│      test/valid/virial_rmse                                        │
└──────────────────────────────────┴──────────────────────────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃           Test metric                       DataLoader 1           ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/train/energy_mae                                         │
│      test/train/energy_rmse                                        │
│      test/train/forces_mae                                         │
│      test/train/forces_rmse                                        │
│ test/train/forces_rmse_batchwise                                   │
│  test/train/per_atom_energy_mae                                    │
│ test/train/per_atom_energy_rmse                                    │
│      test/train/stress_mae                                         │
│      test/train/stress_rmse                                        │
│      test/train/virial_mae                                         │
│      test/train/virial_rmse                                        │
│      test/valid/energy_mae              0.23569336533546448        │
│      test/valid/energy_rmse              0.2459844946861267        │
│      test/valid/forces_mae              0.07436826825141907        │
│      test/valid/forces_rmse             0.09522523730993271        │
│ test/valid/forces_rmse_batchwise        0.09307248890399933        │
│  test/valid/per_atom_energy_mae         0.002182579133659601       │
│ test/valid/per_atom_energy_rmse         0.002277533058077097       │
│      test/valid/stress_mae              0.023262472823262215       │
│      test/valid/stress_rmse             0.027134379372000694       │
│      test/valid/virial_mae               61.23480224609375         │
│      test/valid/virial_rmse              89.39186096191406         │
└──────────────────────────────────┴──────────────────────────────────┘
[graph-pes INFO]: Testing complete.
[graph-pes INFO]: Awaiting final Lightning and W&B shutdown...
[26]:
fine_tuned_orb = load_model("graph-pes-results/orb-fine-tune/model.pt").eval()
parity_plot(
    fine_tuned_orb,
    test_set,
    property="energy_per_atom",
    units="eV/atom",
)
plt.title("Fine-tuned orb-d3-xs-v2");
../_images/quickstart_fine-tuning_31_0.png

Again, we aren’t trying to make direct comparisons between these two models here - the instances we have chosen are both small models, intended as a proof of concept.

If you want to explore fine-tuning various foundation models, we strongly recommend that you tune the hyperparameters of the fine-tuning process for each one separately.
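
For example, since the optimizer settings live in their own section of the config, one option is to keep a small per-model override file and stack it on top of the others (a sketch only: the file name and learning rate are illustrative, and we’re assuming additional files can be stacked in the same way):

[ ]:
orb_lr_config = """
fitting:
    optimizer:
        +Optimizer:
            name: Adam
            lr: 0.00005  # e.g. a smaller learning rate for this model
"""

with open("orb-lr.yaml", "w") as f:
    f.write(orb_lr_config)

# then stack it on top of the other config files:
# !graph-pes-train fine-tune.yaml orb.yaml orb-lr.yaml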

Fine-tuning other models

Of course, you can also fine-tune models that you have trained yourself! In this case, rather than using the foundation model interfaces, use the load_model function to load your model from disk. Remember to set fitting/auto_fit_reference_energies=true (as below) so that the model correctly updates its reference energies:

model:
    +load_model:
        path: path/to/model.pt

fitting:
    auto_fit_reference_energies: true

# ... <-- other config options as normal
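
As a sketch of how this fits into the workflow above (the model path is a placeholder for your own trained model, and the run_id is just illustrative), you could write this config out and stack it with the shared fine-tuning options exactly as before:

[ ]:
my_model_config = """
model:
    +load_model:
        path: path/to/model.pt  # <-- placeholder: point this at your own model

fitting:
    auto_fit_reference_energies: true

general:
    run_id: my-model-fine-tune
"""

with open("my-model.yaml", "w") as f:
    f.write(my_model_config)

# !graph-pes-train fine-tune.yaml my-model.yaml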