Analysis

graph-pes provides a number of utilities for analysing models:

graph_pes.utils.analysis.parity_plot(
model,
structures,
property='energy',
transform=None,
units=None,
ax=None,
batch_size=5,
**scatter_kwargs,
)

A nicely formatted parity plot of model predictions vs ground truth for the given property.

Parameters:
  • model (GraphPESModel | GraphPESCalculator) – The model to use for generating predictions.

  • structures (Iterable[AtomicGraph] | Iterable[ase.Atoms]) – The structures to make predictions on.

  • property (PropertyKey) – The property to plot, e.g. "energy".

  • transform (Transform | None) – The transform to apply to the predictions and labels before plotting. If not provided, no transform is applied.

  • units (str | None) – The units of the property, for labelling the axes. If not provided, no units are used.

  • ax (matplotlib.axes.Axes | None) – The axes to plot on. If not provided, the current axes are used.

  • batch_size (int) – The size of the batch to use for making predictions.

  • scatter_kwargs – Keyword arguments to pass to plt.scatter.
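
To make the roles of `transform` and `batch_size` concrete, here is a pure-Python sketch of the data preparation a parity plot needs: predictions are made in batches of `batch_size`, and the same transform is applied to both the label and the prediction of each structure before plotting. This is an illustration under stated assumptions, not graph-pes's implementation: structures are plain dicts, `predict_batch` is a hypothetical stand-in for the model, `parity_pairs` is a hypothetical name, and the plotting step itself is omitted.

```python
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple

# Hypothetical stand-ins: a "structure" is a dict holding its labels and atom
# count, and a transform takes (value, structure) -> value.
Structure = Dict[str, float]
ToyTransform = Callable[[float, Structure], float]


def parity_pairs(
    predict_batch: Callable[[Sequence[Structure]], List[float]],
    structures: Iterable[Structure],
    property: str = "energy",
    transform: Optional[ToyTransform] = None,
    batch_size: int = 5,
) -> List[Tuple[float, float]]:
    """Collect (label, prediction) pairs, predicting in batches and applying
    the same transform to both values, as a parity plot would need."""
    structures = list(structures)
    pairs = []
    for start in range(0, len(structures), batch_size):
        batch = structures[start : start + batch_size]
        predictions = predict_batch(batch)  # one "model call" per batch
        for s, pred in zip(batch, predictions):
            label = s[property]
            if transform is not None:
                label, pred = transform(label, s), transform(pred, s)
            pairs.append((label, pred))
    return pairs


# Usage: a fake "model" that over-predicts every energy by 0.5
def per_atom(x: float, s: Structure) -> float:
    return x / s["n_atoms"]


data = [{"energy": -8.0, "n_atoms": 2}, {"energy": -12.0, "n_atoms": 4}]
pairs = parity_pairs(
    lambda batch: [s["energy"] + 0.5 for s in batch],
    data,
    transform=per_atom,
    batch_size=1,
)
print(pairs)  # [(-4.0, -3.75), (-3.0, -2.875)]
```

Transforming labels and predictions with the same function keeps the two axes directly comparable, so a perfect model still lies on the line y = x.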

Examples

Default settings (no units, transforms or custom scatter keywords):

parity_plot(model, train, "energy")
[Figure: parity plot generated with the default settings]

Custom settings, as seen in this example notebook:

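import matplotlib.pyplot as plt

from graph_pes.atomic_graph import DividePerAtom
from graph_pes.utils.analysis import parity_plot

# `model`, `train` and `test` are assumed to be defined already
for name, data, colour in zip(
    ["Train", "Test"],
    [train, test],
    ["royalblue", "crimson"],
):
    # extra keyword arguments (label, c) are forwarded to plt.scatter
    parity_plot(
        model,
        data,
        "energy",
        transform=DividePerAtom(),
        units="eV / atom",
        label=name,
        c=colour,
    )

plt.legend(loc="upper left", fancybox=False)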
[Figure: train and test parity plots of per-atom energy]

graph_pes.utils.analysis.dimer_curve(
model,
system,
units=None,
set_to_zero=True,
rmin=0.9,
rmax=None,
ax=None,
auto_lim=True,
**plot_kwargs,
)

A nicely formatted dimer curve plot for the given system.

Parameters:
  • model (GraphPESModel) – The model for generating predictions.

  • system (str) – The dimer system. Should be one of: a single element, e.g. "Cu", or a pair of elements, e.g. "CuO".

  • units (str | None) – The units of the energy, for labelling the axes. If not provided, no units are used.

  • set_to_zero (bool) – Whether to set the energy of the dimer at rmax to be zero.

  • rmin (float) – The minimum separation to consider.

  • rmax (float | None) – The maximum separation to consider.

  • ax (matplotlib.axes.Axes | None) – The axes to plot on. If not provided, the current axes are used.

  • plot_kwargs – Keyword arguments to pass to plt.plot.

Return type:

matplotlib.lines.Line2D
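
The curve itself is the dimer's energy evaluated over a range of separations, optionally shifted so that the value at rmax is zero (the set_to_zero behaviour described above). The stdlib-only sketch below assumes a standard 12-6 Lennard-Jones pair energy as a stand-in "model", an evenly spaced grid of `n_points` separations (a hypothetical parameter), and an explicit `rmax`, since the default used when rmax=None is not documented here. `lj_pair_energy` and `dimer_energies` are hypothetical names; this is not graph-pes's implementation, and graph-pes's LennardJones model may be parameterised differently.

```python
from typing import List, Tuple


def lj_pair_energy(r: float, sigma: float, epsilon: float) -> float:
    """Standard 12-6 Lennard-Jones pair energy, 4*eps*((s/r)^12 - (s/r)^6)."""
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6 * sr6 - sr6)


def dimer_energies(
    sigma: float,
    epsilon: float,
    rmin: float,
    rmax: float,
    n_points: int = 100,
    set_to_zero: bool = True,
) -> List[Tuple[float, float]]:
    """Evaluate a dimer's energy on an evenly spaced grid from rmin to rmax."""
    if n_points < 2:
        raise ValueError("need at least two grid points")
    step = (rmax - rmin) / (n_points - 1)
    rs = [rmin + i * step for i in range(n_points)]
    energies = [lj_pair_energy(r, sigma, epsilon) for r in rs]
    if set_to_zero:
        # shift the whole curve so that the energy at rmax is exactly zero
        offset = energies[-1]
        energies = [e - offset for e in energies]
    return list(zip(rs, energies))


# Usage: the same sigma/epsilon as the LennardJones example below
curve = dimer_energies(sigma=1.3, epsilon=0.5, rmin=0.9, rmax=4.0)
r_best, e_best = min(curve, key=lambda p: p[1])
```

Shifting by the last grid value, rather than an analytic E(rmax), guarantees that the plotted curve ends exactly at zero.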

Examples

from graph_pes.utils.analysis import dimer_curve
from graph_pes.models import LennardJones

dimer_curve(LennardJones(sigma=1.3, epsilon=0.5), system="OH", units="eV")
[Figure: dimer curve for the Lennard-Jones "OH" example]

class graph_pes.utils.analysis.Transform

Alias for Callable[[Tensor, AtomicGraph], Tensor].

Transforms map a property, \(x\), to a target property, \(y\), conditioned on an AtomicGraph, \(\mathcal{G}\):

\[T: (x; \mathcal{G}) \mapsto y\]
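
Because Transform is only an alias for a callable taking a value and its graph, any function with that shape can serve as one. The sketch below uses plain floats and a hypothetical `ToyGraph` dataclass in place of `Tensor` and `AtomicGraph` so that it runs without graph-pes installed; `divide_per_atom` (in the spirit of the `DividePerAtom()` transform used in the parity-plot example, whose actual implementation is not shown here) and `per_atom_reference` are illustrative names, not graph-pes functions.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyGraph:
    """Hypothetical stand-in for an AtomicGraph: only the atom count is kept."""

    n_atoms: int


# Same shape as Transform = Callable[[Tensor, AtomicGraph], Tensor],
# with floats in place of tensors.
ToyTransform = Callable[[float, ToyGraph], float]


def divide_per_atom(x: float, graph: ToyGraph) -> float:
    """T: (x; G) -> x / N, turning a total property into a per-atom one."""
    return x / graph.n_atoms


def per_atom_reference(e_ref: float) -> ToyTransform:
    """Build a transform T: (x; G) -> x - e_ref * N that removes a fixed
    reference energy for every atom in the graph."""

    def transform(x: float, graph: ToyGraph) -> float:
        return x - e_ref * graph.n_atoms

    return transform


# Usage: both transforms depend on the graph, not just on the value x
g = ToyGraph(n_atoms=4)
shifted = per_atom_reference(-2.5)(-12.0, g)
per_atom = divide_per_atom(shifted, g)
```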