Source code for graph_pes.training.loss

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Literal, NamedTuple, Sequence

import torch
import torchmetrics
from torch import Tensor, nn
from torchmetrics import Metric as TorchMetric

from graph_pes.atomic_graph import AtomicGraph, PropertyKey, divide_per_atom
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.utils.misc import uniform_repr
from graph_pes.utils.nn import UniformModuleList

Metric = Callable[[Tensor, Tensor], Tensor]
MetricName = Literal["MAE", "RMSE", "MSE"]
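
# A ``Metric`` is any callable mapping a (predictions, targets) pair of
# tensors to a scalar tensor. Illustrative sketch (not part of the library):
# such a function can be passed anywhere a ``Metric`` is accepted, e.g. to
# ``PropertyLoss`` below:
#
#     def huber(predicted: Tensor, true: Tensor) -> Tensor:
#         return torch.nn.functional.huber_loss(predicted, true)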


class WeightedLoss:
    def __init__(self, *args, **kwargs):
        # this is now deprecated: pass weight directly to the loss object
        raise ImportError(
            "The WeightedLoss class has been removed from graph-pes "
            "as of version 0.0.22. Please now pass loss weights directly "
            "to the loss instances! See the docs for more information: "
            "https://jla-gardner.github.io/graph-pes/fitting/losses.html"
        )


class Loss(nn.Module, ABC):
    """
    A general base class for all loss functions in ``graph-pes``.

    Implementations **must** override:

    * :meth:`forward` to compute the loss value.
    * :meth:`name` to return the name of the loss function.
    * :meth:`required_properties` to return the properties that this loss
      function needs to have available in order to compute its value.

    Additionally, implementations can optionally override:

    * :meth:`pre_fit` to perform any necessary operations before training
      commences.

    Parameters
    ----------
    weight
        a scalar multiplier for weighting the value returned by
        :meth:`forward` as part of a :class:`TotalLoss`.
    is_per_atom
        whether this loss returns a value that is normalised per atom, or
        not. For instance, some metric that acts on ``"forces"`` is
        naturally per-atom, while a metric that acts on ``"energy"``, or
        the model etc., is not. Specifying this correctly ensures that the
        effective batch size is chosen correctly when averaging over
        batches.
    """

    def __init__(self, weight, is_per_atom: bool = False):
        super().__init__()
        self.weight = weight
        self.is_per_atom = is_per_atom

    @abstractmethod
    def forward(
        self,
        model: GraphPESModel,
        graph: AtomicGraph,
        predictions: dict[PropertyKey, torch.Tensor],
    ) -> torch.Tensor | TorchMetric:
        r"""
        Compute the unweighted loss value.

        Note that only :class:`Loss`\ es that return a tensor can be used
        for training: we reserve the use of :class:`TorchMetric`\ s for
        evaluation metrics only.

        :class:`Loss`\ es can act on any of:

        Parameters
        ----------
        model
            The model being trained.
        graph
            The graph (usually a batch) the ``model`` was applied to.
        predictions
            The predictions from the ``model`` for the given ``graph``.
        """

    @property
    @abstractmethod
    def required_properties(self) -> list[PropertyKey]:
        """The properties that are required by this loss function."""

    @property
    @abstractmethod
    def name(self) -> str:
        """The name of this loss function, for logging purposes."""

    def pre_fit(self, training_data: AtomicGraph):
        """
        Perform any necessary operations before training commences.

        For example, this could be used to pre-compute a standard deviation
        of some property in the training data, which could then be used in
        :meth:`forward`.

        Parameters
        ----------
        training_data
            The training data to pre-fit this loss function to.
        """

    # add type hints to play nicely with mypy
    def __call__(
        self,
        model: GraphPESModel,
        graph: AtomicGraph,
        predictions: dict[PropertyKey, torch.Tensor],
    ) -> torch.Tensor | TorchMetric:
        return super().__call__(model, graph, predictions)
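

# Minimal sketch (illustrative only, not part of the library) of a custom
# ``Loss`` subclass that satisfies the interface above by penalising the mean
# squared magnitude of the predicted forces:
#
#     class ForceMagnitudeRegulariser(Loss):
#         def __init__(self, weight: float = 0.01):
#             super().__init__(weight, is_per_atom=True)
#
#         def forward(self, model, graph, predictions):
#             return predictions["forces"].pow(2).sum(dim=-1).mean()
#
#         @property
#         def required_properties(self) -> list[PropertyKey]:
#             return ["forces"]
#
#         @property
#         def name(self) -> str:
#             return "force_magnitude_regulariser"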


class PropertyLoss(Loss):
    r"""
    A :class:`PropertyLoss` instance applies its :class:`Metric` to compare a
    model's predictions to the true values for a given property of a
    :class:`~graph_pes.AtomicGraph`.

    Parameters
    ----------
    property
        The property to apply the loss metric to.
    metric
        The loss metric to use. Defaults to :class:`RMSE`.

    Examples
    --------
    .. code-block:: python

        energy_rmse_loss = PropertyLoss("energy", RMSE())
        energy_rmse_value = energy_rmse_loss(
            model,
            graph,
            predictions,  # a dict of key (energy/force/etc.) to value
        )
    """

    def __init__(
        self,
        property: PropertyKey,
        metric: Metric | MetricName | TorchMetric = "RMSE",
        weight: float = 1.0,
    ):
        super().__init__(
            weight,
            is_per_atom=property in ("forces", "local_energies"),
        )
        self.property: PropertyKey = property
        self.metric = parse_metric(metric)

    def forward(
        self,
        model: GraphPESModel,
        graph: AtomicGraph,
        predictions: dict[PropertyKey, torch.Tensor],
    ) -> torch.Tensor | TorchMetric:
        """
        Compute the loss value.

        Parameters
        ----------
        model
            The model being trained.
        graph
            The graph (usually a batch) the ``model`` was applied to.
        predictions
            The predictions from the ``model`` for the given ``graph``.
        """
        if isinstance(self.metric, torchmetrics.Metric):
            self.metric.update(
                predictions[self.property],
                graph.properties[self.property],
            )
            return self.metric

        return self.metric(
            predictions[self.property],
            graph.properties[self.property],
        )

    @property
    def name(self) -> str:
        """Get the name of this loss for logging purposes."""
        return f"{self.property}_{_get_metric_name(self.metric)}"

    @property
    def required_properties(self) -> list[PropertyKey]:
        return [self.property]

    def __repr__(self) -> str:
        return uniform_repr(
            self.__class__.__name__,
            self.property,
            metric=self.metric,
        )
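

# Usage sketch: inside a training step, a ``PropertyLoss`` is called with the
# model, the (batched) graph and the model's predictions; ``model``, ``graph``
# and ``predictions`` below are assumed to already exist:
#
#     force_mae = PropertyLoss("forces", metric="MAE", weight=5.0)
#     value = force_mae(model, graph, predictions)  # unweighted scalar tensor
#     weighted = value * force_mae.weight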


class SubLossPair(NamedTuple):
    loss_value: torch.Tensor
    weighted_loss_value: torch.Tensor
    is_per_atom: bool


class TotalLossResult(NamedTuple):
    loss_value: torch.Tensor
    components: dict[str, SubLossPair]


class TotalLoss(torch.nn.Module):
    r"""
    A lightweight wrapper around a collection of losses.

    .. math::

        \mathcal{L}_{\text{total}} = \sum_i w_i \mathcal{L}_i

    where :math:`\mathcal{L}_i` is the :math:`i`-th loss and :math:`w_i` is
    the corresponding weight. ``graph-pes`` models are trained by minimising
    a :class:`TotalLoss` value.

    Parameters
    ----------
    losses
        The collection of losses to aggregate.
    """

    def __init__(self, losses: Sequence[Loss]):
        super().__init__()
        self.losses = UniformModuleList(losses)

    def forward(
        self,
        model: GraphPESModel,
        graph: AtomicGraph,
        predictions: dict[PropertyKey, torch.Tensor],
    ) -> TotalLossResult:
        """
        Compute the total loss value.

        Parameters
        ----------
        model
            The model being trained.
        graph
            The graph (usually a batch) the ``model`` was applied to.
        predictions
            The predictions from the ``model`` for the given ``graph``.
        """
        total_loss = torch.scalar_tensor(0.0, device=graph.Z.device)
        components: dict[str, SubLossPair] = {}

        for loss in self.losses:
            loss_value = loss(model, graph, predictions)
            if not isinstance(loss_value, torch.Tensor):
                raise ValueError(
                    f"Total losses can only be used alongside metrics that "
                    f"return a tensor. Instead, loss {loss.name} returned "
                    f"{type(loss_value)}."
                )
            weighted_loss_value = loss_value * loss.weight

            total_loss += weighted_loss_value
            components[loss.name] = SubLossPair(
                loss_value,
                weighted_loss_value,
                loss.is_per_atom,
            )

        return TotalLossResult(total_loss, components)

    # add type hints to appease mypy
    def __call__(
        self,
        model: GraphPESModel,
        graph: AtomicGraph,
        predictions: dict[PropertyKey, torch.Tensor],
    ) -> TotalLossResult:
        return super().__call__(model, graph, predictions)

    def __repr__(self) -> str:
        return "\n".join(
            ["TotalLoss:"]
            + [" ".join(str(loss).split("\n")) for loss in self.losses]
        )
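

# Typical composition (sketch; the weights shown are illustrative, not
# library defaults). ``model``, ``graph`` and ``predictions`` are assumed to
# come from the surrounding training step:
#
#     total = TotalLoss(
#         [
#             PerAtomEnergyLoss(metric="RMSE", weight=1.0),
#             PropertyLoss("forces", metric="RMSE", weight=5.0),
#         ]
#     )
#     result = total(model, graph, predictions)
#     result.loss_value   # weighted sum used for backpropagation
#     result.components   # per-loss SubLossPair breakdown for logging

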
############################# CUSTOM LOSSES #############################


class PerAtomEnergyLoss(PropertyLoss):
    r"""
    A loss function that evaluates some metric on the total energy normalised
    by the number of atoms in the structure.

    .. math::

        \mathcal{L} = \text{metric}\left(
            \bigoplus_i \frac{\hat{E}_i}{N_i},
            \bigoplus_i \frac{E_i}{N_i}
        \right)

    where :math:`\hat{E}_i` is the predicted energy for structure :math:`i`,
    :math:`E_i` is the true energy for structure :math:`i`, :math:`N_i` is
    the number of atoms in structure :math:`i` and :math:`\bigoplus_i`
    denotes the concatenation over all structures in the batch.

    Parameters
    ----------
    metric
        The loss metric to use. Defaults to :class:`RMSE`.
    """

    def __init__(
        self,
        metric: Metric | MetricName | TorchMetric = "RMSE",
        weight: float = 1.0,
    ):
        super().__init__("energy", metric, weight)

    def forward(
        self,
        model: GraphPESModel,
        graph: AtomicGraph,
        predictions: dict[PropertyKey, torch.Tensor],
    ) -> torch.Tensor | TorchMetric:
        if isinstance(self.metric, torchmetrics.Metric):
            self.metric.update(
                divide_per_atom(predictions["energy"], graph),
                divide_per_atom(graph.properties["energy"], graph),
            )
            return self.metric

        return self.metric(
            divide_per_atom(predictions["energy"], graph),
            divide_per_atom(graph.properties["energy"], graph),
        )

    @property
    def name(self) -> str:
        return f"per_atom_energy_{_get_metric_name(self.metric)}"
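

# Worked example (made-up numbers): for a batch of two structures with 2 and
# 8 atoms and total energies of -10 eV and -40 eV, the quantities compared by
# this loss are the per-atom energies:
#
#     energies = torch.tensor([-10.0, -40.0])  # eV per structure
#     n_atoms = torch.tensor([2.0, 8.0])
#     energies / n_atoms                        # tensor([-5., -5.]) eV/atom
#
# so large and small structures contribute on an equal footing.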


class ForceRMSE(PropertyLoss):
    """
    Alias for :class:`PropertyLoss` with ``property="forces"`` and
    ``metric=RMSE``.
    """

    def __init__(self, weight: float = 1.0):
        super().__init__("forces", RMSE(), weight)


## METRICS ##


def parse_metric(metric: Metric | MetricName | TorchMetric | None) -> Metric:
    if isinstance(metric, str):
        return {"MAE": MAE(), "RMSE": RMSE(), "MSE": MSE()}[metric]

    if metric is None:
        return RMSE()

    return metric
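

# ``parse_metric`` behaviour (sketch): string names map to the built-in
# metric classes, ``None`` falls back to ``RMSE``, and any other callable
# (``my_metric`` below is hypothetical) is returned unchanged:
#
#     parse_metric("MAE")      # -> MAE()
#     parse_metric(None)       # -> RMSE()
#     parse_metric(my_metric)  # -> my_metric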


class MSE(torch.nn.MSELoss):
    r"""
    Mean squared error metric:

    .. math::

        \frac{1}{N} \sum_i^N \left( \hat{P}_i - P_i \right)^2
    """


class RMSE(torch.nn.MSELoss):
    r"""
    Root mean squared error metric:

    .. math::

        \sqrt{ \frac{1}{N} \sum_i^N \left( \hat{P}_i - P_i \right)^2 }

    .. note::

        Metrics are computed per-batch in the Lightning Trainer. When
        aggregating this metric across multiple batches, we therefore get
        the mean of the per-batch RMSEs, which is not the same as the RMSE
        between all predictions and targets from all batches. ``graph-pes``
        avoids this issue for all RMSE-based metrics logged as
        ``"{valid|test}/metrics/..._rmse"`` by using a different,
        ``torchmetrics`` based implementation.
    """

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        return (super().forward(input, target)).sqrt()

    @property
    def name(self) -> str:
        return "rmse_batchwise"
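

# Worked illustration of the note above, with made-up numbers: two equally
# sized batches with per-batch RMSEs of 1.0 and 3.0 average to 2.0, whereas
# pooling all predictions first gives sqrt((1.0**2 + 3.0**2) / 2) ≈ 2.24.
# This is why the logged ``..._rmse`` metrics use a torchmetrics-based
# implementation instead of averaging this batchwise value.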


class MAE(torch.nn.L1Loss):
    r"""
    Mean absolute error metric:

    .. math::

        \frac{1}{N} \sum_i^N \left| \hat{P}_i - P_i \right|
    """


def _get_metric_name(metric: Metric) -> str:
    if hasattr(metric, "name"):
        return metric.name  # type: ignore

    if isinstance(metric, torchmetrics.MeanSquaredError):
        return "mse" if metric.squared else "rmse"

    # if metric is a function, we want the function's name, otherwise
    # we want the metric's class name, all lowercased
    # and without the word "loss" in it
    return (
        getattr(
            metric,
            "__name__",
            metric.__class__.__name__,
        )
        .lower()
        .replace("loss", "")
    )
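

# Name derivation (sketch of this helper's behaviour on the metrics above;
# ``huber`` is a hypothetical user-defined metric function):
#
#     _get_metric_name(RMSE())  # "rmse_batchwise" (via the ``name`` property)
#     _get_metric_name(MAE())   # "mae"            (class name, lowercased)
#     _get_metric_name(huber)   # "huber"          (function ``__name__``)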