Losses
In graph-pes, we distinguish between metrics and losses:
- A Loss is some function that takes a model, a batch of graphs, and some predictions, and returns a scalar value measuring something that training should seek to minimise. This could be a prediction error, a model weight penalty, or something else.
- A Metric is some function that takes two tensors and returns a scalar value measuring the discrepancy between them.
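A minimal, self-contained illustration of the distinction in plain PyTorch follows. The functions mae and energy_loss are hypothetical (they are not part of graph-pes), and a plain dict of true values stands in for an AtomicGraph.

```python
import torch


# A metric: two tensors in, one scalar tensor out.
def mae(predicted: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    return (predicted - true).abs().mean()


# A (hypothetical) loss: it decides *what* to compare (here, energies)
# and delegates *how* to compare them to a metric.
def energy_loss(model, graph: dict, predictions: dict) -> torch.Tensor:
    # `model` is unused here, but a loss could inspect it, e.g. for a weight penalty
    return mae(predictions["energy"], graph["energy"])


graph = {"energy": torch.tensor([1.0, 2.0, 3.0])}  # stand-in for true labels
predictions = {"energy": torch.tensor([1.5, 2.0, 2.0])}
print(energy_loss(None, graph, predictions))  # tensor(0.5000)
```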
Losses
- class graph_pes.training.loss.Loss(weight)
A general base class for all loss functions in graph-pes.
Implementations must override:
- forward() to compute the loss value.
- name() to return the name of the loss function.
- required_properties() to return the properties that this loss function needs to have available in order to compute its value.
Additionally, implementations can optionally override:
- pre_fit() to perform any necessary operations before training commences.
- Parameters:
weight – A scalar multiplier for weighting the value returned by forward() as part of a TotalLoss.
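To show what a subclass has to provide, here is a sketch under stated assumptions: ToyLoss is a hypothetical stand-in for graph-pes's Loss (so the snippet runs without graph-pes installed), name and required_properties are written as properties, and a plain dict of true values stands in for an AtomicGraph. Check the real Loss source for the exact signatures before subclassing it.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class ToyLoss(nn.Module, ABC):
    """Hypothetical stand-in for graph_pes.training.loss.Loss."""

    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight  # consumed by a TotalLoss-style wrapper

    @abstractmethod
    def forward(self, model, graph, predictions) -> torch.Tensor:
        """Return the unweighted loss value."""

    @property
    @abstractmethod
    def name(self) -> str:
        """A human-readable name, e.g. for logging."""

    @property
    @abstractmethod
    def required_properties(self) -> list[str]:
        """Keys that must be present in the predictions and labels."""

    def pre_fit(self, training_data) -> None:
        """Optional hook: does nothing unless overridden."""


class ForceMSE(ToyLoss):
    """Mean squared error on forces (illustrative only)."""

    def forward(self, model, graph, predictions):
        return ((predictions["forces"] - graph["forces"]) ** 2).mean()

    @property
    def name(self):
        return "forces_mse"

    @property
    def required_properties(self):
        return ["forces"]


loss = ForceMSE(weight=2.0)
graph = {"forces": torch.zeros(2, 3)}  # true forces, stand-in for an AtomicGraph
predictions = {"forces": torch.ones(2, 3)}
print(loss(None, graph, predictions))  # tensor(1.)
```

Because ToyLoss is an nn.Module, calling the instance dispatches to forward(); the weight is stored but deliberately not applied inside forward(), matching the "unweighted loss value" contract described for forward().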
- abstract forward(model, graph, predictions)
Compute the unweighted loss value.
Losses can act on any of the following inputs:
- Parameters:
model (GraphPESModel) – The model being trained.
graph (AtomicGraph) – The graph (usually a batch) the model was applied to.
predictions (dict[Literal['local_energies', 'forces', 'energy', 'stress', 'virial'], torch.Tensor]) – The predictions from the model for the given graph.
- Return type:
- abstract property required_properties: list[Literal['local_energies', 'forces', 'energy', 'stress', 'virial']]
The properties that are required by this loss function.
- pre_fit(training_data)
Perform any necessary operations before training commences.
For example, this could be used to pre-compute a standard deviation of some property in the training data, which could then be used in forward().
- Parameters:
training_data (AtomicGraph) – The training data to pre-fit this loss function to.
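As a hypothetical example of the standard-deviation use case just described, the sketch below stores the training-set energy standard deviation in pre_fit() and divides the energy RMSE by it in forward(). ScaledEnergyRMSE is an illustrative name, not a graph-pes class; it subclasses torch.nn.Module directly rather than Loss so that it runs on its own, and plain dicts stand in for AtomicGraph.

```python
import torch
from torch import nn


class ScaledEnergyRMSE(nn.Module):
    """Hypothetical loss: energy RMSE divided by the training-set energy std."""

    def __init__(self):
        super().__init__()
        # placeholder until pre_fit() is called; a buffer moves with .to(device)
        # and is saved in the module's state dict
        self.register_buffer("energy_std", torch.tensor(1.0))

    def pre_fit(self, training_data: dict) -> None:
        # computed once, before training starts (torch's default, unbiased std)
        self.energy_std = training_data["energy"].std()

    def forward(self, model, graph: dict, predictions: dict) -> torch.Tensor:
        rmse = ((predictions["energy"] - graph["energy"]) ** 2).mean().sqrt()
        return rmse / self.energy_std


loss = ScaledEnergyRMSE()
loss.pre_fit({"energy": torch.tensor([1.0, 3.0])})  # std = sqrt(2)
value = loss(None, {"energy": torch.tensor([0.0, 0.0])}, {"energy": torch.tensor([1.0, 1.0])})
print(value)  # tensor(0.7071)
```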
- class graph_pes.training.loss.PropertyLoss(property, metric='RMSE', weight=1.0)
Bases: Loss
A PropertyLoss instance applies its Metric to compare a model’s predictions to the true values for a given property of an AtomicGraph.
- Parameters:
property (PropertyKey) – The property to apply the loss metric to.
metric (Metric | MetricName) – The loss metric to use. Defaults to RMSE.
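Examples
    from graph_pes.training.loss import RMSE, PropertyLoss

    energy_rmse_loss = PropertyLoss("energy", RMSE())
    # predictions: a dict mapping property keys (energy, forces, etc.)
    # to the model's predicted values for graph
    energy_rmse_value = energy_rmse_loss(model, graph, predictions)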
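Conceptually, a property loss looks up the same key in the model's predictions and in the graph's true properties and passes both to its metric. The sketch below shows that idea under stated assumptions: SimplePropertyLoss and rmse are hypothetical names, the graph's labels are a plain dict, and this is not graph-pes's actual implementation.

```python
from typing import Callable

import torch
from torch import nn

Metric = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def rmse(predicted: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    return ((predicted - true) ** 2).mean().sqrt()


class SimplePropertyLoss(nn.Module):
    """Hypothetical sketch: `key` plays the role of PropertyLoss's `property`."""

    def __init__(self, key: str, metric: Metric = rmse, weight: float = 1.0):
        super().__init__()
        self.key = key
        self.metric = metric
        self.weight = weight

    def forward(self, model, graph: dict, predictions: dict) -> torch.Tensor:
        # compare like with like: the same key from predictions and labels
        return self.metric(predictions[self.key], graph[self.key])

    @property
    def required_properties(self) -> list[str]:
        return [self.key]


loss = SimplePropertyLoss("energy")
graph = {"energy": torch.tensor([1.0, 2.0])}
predictions = {"energy": torch.tensor([1.0, 4.0])}
print(loss(None, graph, predictions))  # sqrt((0 + 4) / 2) = sqrt(2)
```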
- class graph_pes.training.loss.PerAtomEnergyLoss(metric='RMSE', weight=1.0)
A loss function that evaluates some metric on the total energy normalised by the number of atoms in the structure.
\[\mathcal{L} = \text{metric}\left( \bigoplus_i \frac{\hat{E}_i}{N_i}, \bigoplus_i\frac{E_i}{N_i} \right)\]
where \(\hat{E}_i\) is the predicted energy for structure \(i\), \(E_i\) is the true energy for structure \(i\), \(N_i\) is the number of atoms in structure \(i\), and \(\bigoplus_i\) denotes concatenation over all structures in the batch.
- Parameters:
metric (Metric | MetricName) – The loss metric to use. Defaults to RMSE.
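A minimal PyTorch sketch of this formula, assuming the per-structure predicted and true total energies and the atom counts are already available as 1-D tensors with one entry per structure in the batch (how graph-pes extracts these from a batched AtomicGraph is not shown). per_atom_energy_rmse is a hypothetical helper with the metric fixed to RMSE.

```python
import torch


def per_atom_energy_rmse(
    predicted_energy: torch.Tensor,  # shape (S,): one total energy per structure
    true_energy: torch.Tensor,  # shape (S,)
    n_atoms: torch.Tensor,  # shape (S,): number of atoms in each structure
) -> torch.Tensor:
    # normalise each structure's energy by its own atom count, then apply
    # the metric (here RMSE) across all structures in the batch
    pred_per_atom = predicted_energy / n_atoms
    true_per_atom = true_energy / n_atoms
    return ((pred_per_atom - true_per_atom) ** 2).mean().sqrt()


predicted = torch.tensor([-10.0, -42.0])
true = torch.tensor([-12.0, -40.0])
n_atoms = torch.tensor([2, 4])

per_atom = per_atom_energy_rmse(predicted, true, n_atoms)  # sqrt(0.625)
total = ((predicted - true) ** 2).mean().sqrt()  # RMSE on total energies: 2.0
```

Dividing by \(N_i\) puts structures of different sizes on a comparable footing: with these numbers both structures have a total-energy error of 2, but the per-atom errors are 1 and 0.5, so the larger structure contributes less to the per-atom loss.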
Metrics
- class graph_pes.training.loss.Metric
A type alias for any function that takes two input tensors and returns some scalar measure of the discrepancy between them.
Metric = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
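Any callable with this signature qualifies. For example, a hypothetical maximum-absolute-error metric (not one of the named metrics in MetricName) can be written as below. Since the metric parameters above are typed Metric | MetricName, a function like this should be accepted where a name such as "RMSE" would otherwise go.

```python
from typing import Callable

import torch

Metric = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def max_abs_error(predicted: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    # worst-case absolute deviation across all elements
    return (predicted - true).abs().max()


metric: Metric = max_abs_error
print(metric(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.5, 1.0])))  # tensor(2.)
```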
- class graph_pes.training.loss.RMSE
Root mean squared error metric:
\[\sqrt{ \frac{1}{N} \sum_i^N \left( \hat{P}_i - P_i \right)^2 }\]
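Here \(\hat{P}_i\) and \(P_i\) are the predicted and true values of the property and \(N\) is the number of values being compared. The following transcribes the formula directly in PyTorch with made-up numbers and checks it against the square root of torch.nn.functional.mse_loss; it illustrates the formula and is not graph-pes's own RMSE class.

```python
import torch
import torch.nn.functional as F

predicted = torch.tensor([1.0, 2.0, 3.0, 4.0])
true = torch.tensor([1.0, 3.0, 3.0, 2.0])

# direct transcription: square root of the mean squared difference
rmse = ((predicted - true) ** 2).mean().sqrt()

# equivalently, the square root of PyTorch's mean-squared-error loss
assert torch.isclose(rmse, F.mse_loss(predicted, true).sqrt())

# differences (0, -1, 0, 2) -> mean of squares 5/4 -> sqrt(1.25)
print(rmse)  # tensor(1.1180)
```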
Helpers
- class graph_pes.training.loss.TotalLoss(losses)
Bases: Module
A lightweight wrapper around a collection of losses.
\[\mathcal{L}_{\text{total}} = \sum_i w_i \mathcal{L}_i\]
where \(\mathcal{L}_i\) is the \(i\)-th loss and \(w_i\) is the corresponding weight.
graph-pes models are trained by minimising a TotalLoss value.
- Parameters:
losses (Sequence[Loss]) – The collection of losses to aggregate.
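A hypothetical sketch of the weighted sum, using minimal stand-in classes (PropertyMSE and SimpleTotalLoss are illustrative names, not graph-pes API) and plain dicts in place of AtomicGraph. The weight of 10 on the forces term is an arbitrary choice for the example.

```python
from typing import Sequence

import torch
from torch import nn


class PropertyMSE(nn.Module):
    # hypothetical minimal loss carrying a weight, mirroring Loss(weight)
    def __init__(self, key: str, weight: float = 1.0):
        super().__init__()
        self.key = key
        self.weight = weight

    def forward(self, model, graph: dict, predictions: dict) -> torch.Tensor:
        # unweighted value; the wrapper applies the weight
        return ((predictions[self.key] - graph[self.key]) ** 2).mean()


class SimpleTotalLoss(nn.Module):
    # hypothetical stand-in for TotalLoss: L_total = sum_i w_i * L_i
    def __init__(self, losses: Sequence[nn.Module]):
        super().__init__()
        # ModuleList registers each loss as a submodule (e.g. for .to(device))
        self.losses = nn.ModuleList(losses)

    def forward(self, model, graph: dict, predictions: dict) -> torch.Tensor:
        return sum(loss.weight * loss(model, graph, predictions) for loss in self.losses)


total = SimpleTotalLoss([PropertyMSE("energy", weight=1.0), PropertyMSE("forces", weight=10.0)])
graph = {"energy": torch.tensor([0.0]), "forces": torch.zeros(2, 3)}
predictions = {"energy": torch.tensor([2.0]), "forces": torch.full((2, 3), 0.5)}
value = total(None, graph, predictions)  # 1.0 * 4.0 + 10.0 * 0.25 = 6.5
```

Keeping each weight on its loss rather than inside forward() means a loss's raw value can still be logged on its own while the wrapper handles the weighting.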
- class graph_pes.training.loss.MetricName
A type alias for a Literal["RMSE", "MAE", "MSE"].
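When a loss is constructed with a name such as metric="RMSE", that name has to be turned into a Metric callable at some point. A minimal sketch of one way to do this is a lookup table; resolve_metric and _METRICS are hypothetical names, and graph-pes may resolve names differently.

```python
from typing import Callable, Literal, Union

import torch

Metric = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
MetricName = Literal["RMSE", "MAE", "MSE"]

_METRICS: dict[str, Metric] = {
    "RMSE": lambda p, t: ((p - t) ** 2).mean().sqrt(),
    "MAE": lambda p, t: (p - t).abs().mean(),
    "MSE": lambda p, t: ((p - t) ** 2).mean(),
}


def resolve_metric(metric: Union[Metric, MetricName]) -> Metric:
    # strings are looked up by name; callables are passed through unchanged
    if isinstance(metric, str):
        if metric not in _METRICS:
            raise ValueError(f"unknown metric {metric!r}; expected one of {list(_METRICS)}")
        return _METRICS[metric]
    return metric


p, t = torch.tensor([1.0, 3.0]), torch.tensor([0.0, 0.0])
print(resolve_metric("MAE")(p, t))  # tensor(2.)
```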