Losses

In graph-pes, we distinguish between metrics and losses:

  • A Loss is some function that takes a model, a batch of graphs, and some predictions, and returns a scalar value measuring something that training should seek to minimise. This could be a prediction error, a model weight penalty, or something else.

  • A Metric is some function that takes two tensors and returns a scalar value measuring the discrepancy between them.

Losses

class graph_pes.training.loss.Loss(weight, is_per_atom=False)

Bases: Module, ABC

A general base class for all loss functions in graph-pes.

Implementations must override:

  • forward() to compute the loss value.

  • name, a property returning the name of the loss function.

  • required_properties, a property returning the properties that this loss function needs to have available in order to compute its value.

Additionally, implementations can optionally override:

  • pre_fit() to perform any necessary operations before training commences.
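As a rough illustration of this contract, the following plain-Python sketch mirrors the shape of the interface without importing graph-pes. ToyLoss and MeanSquaredEnergyLoss are hypothetical names. Lists of floats stand in for tensors and plain dicts stand in for the graph and the predictions, so this shows the pattern rather than the library's implementation.

```python
from abc import ABC, abstractmethod
from typing import Any


class ToyLoss(ABC):
    """Hypothetical mirror of the Loss interface (not the graph-pes class)."""

    def __init__(self, weight: float = 1.0, is_per_atom: bool = False):
        self.weight = weight
        self.is_per_atom = is_per_atom

    @abstractmethod
    def forward(self, model: Any, graph: dict, predictions: dict) -> float:
        """Return the unweighted loss value."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Name used for logging."""

    @property
    @abstractmethod
    def required_properties(self) -> list[str]:
        """Properties that must be available to compute the loss."""

    def pre_fit(self, training_data: Any) -> None:
        """Optional hook: does nothing unless a subclass overrides it."""

    def __call__(self, model: Any, graph: dict, predictions: dict) -> float:
        # Calling the loss dispatches to forward(), as with a torch Module.
        return self.forward(model, graph, predictions)


class MeanSquaredEnergyLoss(ToyLoss):
    """Hypothetical concrete loss: mean squared error on total energies."""

    def forward(self, model, graph, predictions):
        pred, true = predictions["energy"], graph["energy"]
        return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true)

    @property
    def name(self) -> str:
        return "energy_mse"

    @property
    def required_properties(self) -> list[str]:
        return ["energy"]


loss = MeanSquaredEnergyLoss(weight=2.0)
# model is unused here; errors are [0.5, 0.0], so the value is 0.25 / 2
value = loss(None, {"energy": [1.0, 2.0]}, {"energy": [1.5, 2.0]})
print(value)  # 0.125
```

Declaring name and required_properties as abstract properties means a subclass that forgets either one cannot be instantiated, which catches incomplete implementations early.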

Parameters:
  • weight – a scalar multiplier for weighting the value returned by forward() as part of a TotalLoss.

  • is_per_atom (bool) – whether this loss returns a value that is normalised per atom. For instance, a loss that acts on "forces" is naturally per-atom, while one that acts on "energy", or on the model itself, is not. Specifying this correctly ensures that the effective batch size is chosen correctly when averaging over batches.
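To see why the effective batch size matters, consider a minimal sketch that assumes per-batch values are combined into an epoch-level value by a batch-size-weighted mean (PyTorch Lightning, for example, weights epoch-level averages of logged values by the batch_size passed to self.log). The helper epoch_average and the batch sizes below are hypothetical; the point is only that weighting a per-atom quantity by atom counts rather than structure counts changes the result.

```python
def epoch_average(values: list[float], batch_sizes: list[int]) -> float:
    """Batch-size-weighted mean of per-batch loss values."""
    total = sum(v * n for v, n in zip(values, batch_sizes))
    return total / sum(batch_sizes)


# Two hypothetical batches: per-batch loss values, structure counts, atom counts.
per_batch_loss = [1.0, 2.0]
n_structures = [2, 1]
n_atoms = [10, 40]

# A per-atom loss (e.g. on forces) weighted by the number of atoms:
print(epoch_average(per_batch_loss, n_atoms))       # (10 + 80) / 50 = 1.8
# The same values weighted by the number of structures:
print(epoch_average(per_batch_loss, n_structures))  # (2 + 2) / 3 ≈ 1.333
```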

abstract forward(model, graph, predictions)

Compute the unweighted loss value.

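Note that only Losses that return a tensor can be used for training: we reserve the use of TorchMetrics for evaluation metrics only.

Losses can act on any of the following inputs:

Parameters:
  • model – the model being trained (for example, for a model weight penalty).

  • graph – the batch of graphs.

  • predictions – the model's predictions for that batch.

Return type:

torch.Tensor | TorchMetric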

abstract property required_properties: list[Literal['local_energies', 'forces', 'energy', 'stress', 'virial']]

The properties that are required by this loss function.

abstract property name: str

The name of this loss function, for logging purposes.

pre_fit(training_data)

Perform any necessary operations before training commences.

For example, this could be used to pre-compute a standard deviation of some property in the training data, which could then be used in forward().

Parameters:

training_data (AtomicGraph) – The training data to pre-fit this loss function to.
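A minimal sketch of this pattern, assuming a loss that expresses energy errors in units of the spread of the training energies. ScaledEnergyLoss is a hypothetical class, and its pre_fit() takes a plain list of energies rather than an AtomicGraph, so it illustrates the idea without depending on graph-pes.

```python
import statistics


class ScaledEnergyLoss:
    """Hypothetical loss: squared energy errors divided by the training-set
    variance, using a standard deviation computed once in pre_fit()."""

    def __init__(self) -> None:
        self.std = 1.0  # fallback if pre_fit() is never called

    def pre_fit(self, training_energies: list[float]) -> None:
        # Pre-compute a dataset statistic once, before training starts.
        self.std = statistics.pstdev(training_energies)

    def forward(self, predicted: list[float], true: list[float]) -> float:
        # Errors are measured in units of the training-set standard deviation.
        errors = [(p - t) / self.std for p, t in zip(predicted, true)]
        return sum(e * e for e in errors) / len(errors)


loss = ScaledEnergyLoss()
loss.pre_fit([-10.0, -14.0])           # population std = 2.0
print(loss.forward([-11.0], [-12.0]))  # (1.0 / 2.0) ** 2 = 0.25
```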

class graph_pes.training.loss.PropertyLoss(property, metric='RMSE', weight=1.0)

Bases: Loss

A PropertyLoss instance applies its Metric to compare a model’s predictions to the true values for a given property of an AtomicGraph.

Parameters:
  • property (PropertyKey) – The property to apply the loss metric to.

  • metric (Metric | MetricName | TorchMetric) – The loss metric to use. Defaults to RMSE.

  • weight (float) – A scalar multiplier for this loss as part of a TotalLoss. Defaults to 1.0.

Examples

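from graph_pes.training.loss import RMSE, PropertyLoss

energy_rmse_loss = PropertyLoss("energy", RMSE())
energy_rmse_value = energy_rmse_loss(
    model,        # the model that produced the predictions
    graph,        # the batch of graphs, carrying the true values
    predictions,  # a dict mapping property keys (energy/forces/etc.) to values
)
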
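Conceptually, a PropertyLoss looks up one property key in both the predictions and the true values and applies its metric to that pair. The sketch below mirrors this in plain Python, assuming lists in place of tensors and a plain dict in place of the graph's stored labels; property_loss and rmse are hypothetical helpers, not graph-pes functions.

```python
import math
from typing import Callable, Dict, Sequence

# List-based stand-in for the Metric alias (tensors replaced by sequences).
ListMetric = Callable[[Sequence[float], Sequence[float]], float]


def rmse(pred: Sequence[float], true: Sequence[float]) -> float:
    """Root mean squared error between paired values."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true))


def property_loss(
    key: str,
    metric: ListMetric,
    predictions: Dict[str, Sequence[float]],
    labels: Dict[str, Sequence[float]],
) -> float:
    """Hypothetical helper: compare one predicted property to its true values."""
    if key not in labels:
        raise KeyError(f"no true values for {key!r} in this batch")
    return metric(predictions[key], labels[key])


predictions = {"energy": [1.0, 2.0], "forces": [0.1, -0.1]}
labels = {"energy": [1.0, 4.0], "forces": [0.0, 0.0]}
print(property_loss("energy", rmse, predictions, labels))  # sqrt(2) ≈ 1.414
```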
class graph_pes.training.loss.PerAtomEnergyLoss(metric='RMSE', weight=1.0)

A loss function that evaluates some metric on the total energy normalised by the number of atoms in the structure.

\[\mathcal{L} = \text{metric}\left( \bigoplus_i \frac{\hat{E}_i}{N_i}, \bigoplus_i\frac{E_i}{N_i} \right)\]

where \(\hat{E}_i\) is the predicted energy for structure \(i\), \(E_i\) is the true energy for structure \(i\), \(N_i\) is the number of atoms in structure \(i\) and \(\bigoplus_i\) denotes the concatenation over all structures in the batch.

Parameters:
  • metric (Metric | MetricName | TorchMetric) – The loss metric to use. Defaults to RMSE.

  • weight (float) – A scalar multiplier for this loss as part of a TotalLoss. Defaults to 1.0.
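A small worked example of the formula above, using RMSE as the metric. The function per_atom_energy_rmse is a hypothetical, list-based illustration rather than the graph-pes implementation: each energy is divided by its structure's atom count before the metric is applied to the concatenated values.

```python
import math
from typing import Sequence


def per_atom_energy_rmse(
    pred_energy: Sequence[float],
    true_energy: Sequence[float],
    n_atoms: Sequence[int],
) -> float:
    """RMSE between per-atom energies, concatenated over the batch."""
    pred = [e / n for e, n in zip(pred_energy, n_atoms)]  # E-hat_i / N_i
    true = [e / n for e, n in zip(true_energy, n_atoms)]  # E_i / N_i
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true))


# Hypothetical batch of two structures with 2 and 4 atoms.
value = per_atom_energy_rmse([-10.0, -20.0], [-12.0, -20.0], [2, 4])
print(value)  # per-atom errors are [1.0, 0.0], so sqrt(0.5) ≈ 0.707
```

Dividing by \(N_i\) keeps large structures from dominating the loss simply because their total energies, and hence their absolute errors, tend to grow with system size.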

Metrics

class graph_pes.training.loss.Metric

A type alias for any function that takes two input tensors and returns some scalar measure of the discrepancy between them.

Metric = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
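Any callable of this shape can serve as a metric. As a hypothetical example (not part of graph-pes), here is a maximum-absolute-error metric written against plain Python sequences instead of torch.Tensor so that it runs without PyTorch:

```python
from typing import Callable, Sequence

# List-based analogue of the Metric alias (tensors replaced by sequences).
ListMetric = Callable[[Sequence[float], Sequence[float]], float]


def max_abs_error(pred: Sequence[float], true: Sequence[float]) -> float:
    """Largest absolute discrepancy between paired values."""
    if len(pred) != len(true):
        raise ValueError("pred and true must have the same length")
    return max(abs(p - t) for p, t in zip(pred, true))


metric: ListMetric = max_abs_error
print(metric([1.0, 2.0, 3.0], [1.5, 2.0, 1.0]))  # 2.0
```

With real tensors, the same metric reduces to (pred - true).abs().max().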
class graph_pes.training.loss.RMSE

Root mean squared error metric:

\[\sqrt{ \frac{1}{N} \sum_i^N \left( \hat{P}_i - P_i \right)^2 }\]

Note

Metrics are computed per-batch in the Lightning Trainer. When aggregating this metric across multiple batches, we therefore get the mean of the per-batch RMSEs, which is not the same as the RMSE between all predictions and targets from all batches.

graph-pes avoids this issue for all RMSE-based metrics logged as "{valid|test}/metrics/..._rmse" by using a different, torchmetrics-based implementation.
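The difference between the two aggregation strategies is easy to see with two batches of unequal size. The plain-Python sketch below uses hypothetical error values; the running-total approach is the general idea behind streaming metrics such as torchmetrics' MeanSquaredError, but this code is an illustration, not the implementation graph-pes uses.

```python
import math


def rmse(errors: list[float]) -> float:
    return math.sqrt(sum(e * e for e in errors) / len(errors))


# Hypothetical prediction errors from two validation batches of unequal size.
batches = [[1.0, 1.0], [3.0]]

# Naive aggregation: mean of the per-batch RMSEs, (1 + 3) / 2.
mean_of_rmses = sum(rmse(b) for b in batches) / len(batches)

# Streaming aggregation: accumulate the sum of squared errors and the count
# across batches, then take a single square root at the end.
sum_sq, count = 0.0, 0
for b in batches:
    sum_sq += sum(e * e for e in b)
    count += len(b)
global_rmse = math.sqrt(sum_sq / count)  # sqrt(11 / 3)

print(mean_of_rmses)  # 2.0
print(global_rmse)    # ≈ 1.915
```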

class graph_pes.training.loss.MAE

Mean absolute error metric:

\[\frac{1}{N} \sum_i^N \left| \hat{P}_i - P_i \right|\]

class graph_pes.training.loss.MSE

Mean squared error metric:

\[\frac{1}{N} \sum_i^N \left( \hat{P}_i - P_i \right)^2\]
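For reference, the MAE and MSE formulas translate directly into plain Python. These are illustrative, list-based functions rather than the graph-pes classes:

```python
import math
from typing import Sequence


def mae(pred: Sequence[float], true: Sequence[float]) -> float:
    # (1/N) * sum_i |P-hat_i - P_i|
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(true)


def mse(pred: Sequence[float], true: Sequence[float]) -> float:
    # (1/N) * sum_i (P-hat_i - P_i)^2
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true)


pred, true = [1.0, 2.0, 4.0], [1.0, 3.0, 2.0]
print(mae(pred, true))             # (0 + 1 + 2) / 3 = 1.0
print(mse(pred, true))             # (0 + 1 + 4) / 3 ≈ 1.667
print(math.sqrt(mse(pred, true)))  # RMSE is the square root of MSE
```

Because errors are squared, the single error of 2 contributes 4 of the 5 units of squared error, which is why MSE (and RMSE) weight large errors more heavily than MAE does.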

Helpers

class graph_pes.training.loss.TotalLoss(losses)

Bases: Module

A lightweight wrapper around a collection of losses.

\[\mathcal{L}_{\text{total}} = \sum_i w_i \mathcal{L}_i\]

where \(\mathcal{L}_i\) is the \(i\)-th loss and \(w_i\) is the corresponding weight.

graph-pes models are trained by minimising a TotalLoss value.

Parameters:

losses (Sequence[Loss]) – The collection of losses to aggregate.
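The weighted sum above amounts to very little code. In this sketch, total_loss and the numeric values are hypothetical, and each loss is represented by a (weight, unweighted value) pair rather than a Loss object:

```python
from typing import Sequence, Tuple


def total_loss(weighted_values: Sequence[Tuple[float, float]]) -> float:
    """Sum of w_i * L_i over (weight, unweighted value) pairs."""
    return sum(w * value for w, value in weighted_values)


# Hypothetical unweighted values for an energy loss and a forces loss.
energy_loss, forces_loss = 0.5, 0.25
print(total_loss([(1.0, energy_loss), (4.0, forces_loss)]))  # 0.5 + 1.0 = 1.5
```

Because each weight simply scales its term, raising the weight on a forces loss relative to an energy loss shifts the optimisation toward force accuracy.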

class graph_pes.training.loss.MetricName

A type alias for a Literal["RMSE", "MAE", "MSE"].