from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Callable, Literal, NamedTuple, Sequence
import torch
from torch import Tensor, nn
from graph_pes.atomic_graph import AtomicGraph, PropertyKey, divide_per_atom
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.utils.misc import uniform_repr
from graph_pes.utils.nn import UniformModuleList
Metric = Callable[[Tensor, Tensor], Tensor]
MetricName = Literal["MAE", "RMSE", "MSE"]
class WeightedLoss:
def __init__(self, *args, **kwargs):
        # this is now deprecated: pass weight directly to the loss object
raise ImportError(
"The WeightedLoss class has been removed from graph-pes "
"as of version 0.0.22. Please now pass loss weights directly "
"to the loss instances! See the docs for more information: "
"https://jla-gardner.github.io/graph-pes/fitting/losses.html"
)
class Loss(nn.Module, ABC):
"""
A general base class for all loss functions in ``graph-pes``.
Implementations **must** override:
* :meth:`forward` to compute the loss value.
* :meth:`name` to return the name of the loss function.
* :meth:`required_properties` to return the properties that this loss
function needs to have available in order to compute its value.
Additionally, implementations can optionally override:
* :meth:`pre_fit` to perform any necessary operations before training
commences.
Parameters
----------
weight
A scalar multiplier for weighting the value returned by
:meth:`forward` as part of a :class:`TotalLoss`.
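    Example
    -------

    A minimal sketch of a custom subclass (the class is purely illustrative,
    not part of ``graph-pes``):

    .. code-block:: python

        class DummyLoss(Loss):
            # an illustrative loss that needs no properties and always
            # evaluates to zero
            def __init__(self, weight: float = 1.0):
                super().__init__(weight)

            def forward(self, model, graph, predictions):
                return torch.scalar_tensor(0.0)

            @property
            def required_properties(self) -> list[PropertyKey]:
                return []

            @property
            def name(self) -> str:
                return "dummy"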
"""
    def __init__(self, weight: float = 1.0):
super().__init__()
self.weight = weight
@abstractmethod
def forward(
self,
model: GraphPESModel,
graph: AtomicGraph,
predictions: dict[PropertyKey, torch.Tensor],
) -> torch.Tensor:
r"""
Compute the unweighted loss value.
        :class:`Loss`\ es can act on any combination of the ``model``,
        the ``graph`` and the ``predictions``.
Parameters
----------
model
The model being trained.
graph
The graph (usually a batch) the ``model`` was applied to.
predictions
The predictions from the ``model`` for the given ``graph``.
"""
@property
@abstractmethod
def required_properties(self) -> list[PropertyKey]:
"""The properties that are required by this loss function."""
@property
@abstractmethod
def name(self) -> str:
"""The name of this loss function, for logging purposes."""
def pre_fit(self, training_data: AtomicGraph):
"""
Perform any necessary operations before training commences.
For example, this could be used to pre-compute a standard deviation
of some property in the training data, which could then be used in
:meth:`forward`.
Parameters
----------
training_data
The training data to pre-fit this loss function to.
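        Example
        -------

        A minimal sketch (the subclass and attribute name are illustrative):
        store the standard deviation of the training energies for later use
        in :meth:`forward`, assuming the training graphs carry ``"energy"``
        labels.

        .. code-block:: python

            class NormalisedEnergyLoss(PropertyLoss):
                def pre_fit(self, training_data: AtomicGraph):
                    self.energy_std = training_data.properties["energy"].std()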
"""
# add type hints to play nicely with mypy
def __call__(
self,
model: GraphPESModel,
graph: AtomicGraph,
predictions: dict[PropertyKey, torch.Tensor],
) -> torch.Tensor:
return super().__call__(model, graph, predictions)
class PropertyLoss(Loss):
r"""
A :class:`PropertyLoss` instance applies its :class:`Metric` to compare a
model's predictions to the true values for a given property of a
:class:`~graph_pes.AtomicGraph`.
Parameters
----------
property
The property to apply the loss metric to.
metric
The loss metric to use. Defaults to :class:`RMSE`.
Examples
--------
    .. code-block:: python

        energy_rmse_loss = PropertyLoss("energy", RMSE())
        energy_rmse_value = energy_rmse_loss(
            model,        # the model being trained
            graph,        # the graph (usually a batch) the model was applied to
            predictions,  # dict of property key (energy/forces/etc.) to predicted value
        )
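    Losses can also be constructed from a metric name string, together with a
    weight for use as part of a :class:`TotalLoss` (the values here are
    illustrative):

    .. code-block:: python

        force_mae_loss = PropertyLoss("forces", "MAE", weight=10.0)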
"""
def __init__(
self,
property: PropertyKey,
metric: Metric | MetricName = "RMSE",
weight: float = 1.0,
):
super().__init__(weight)
self.property: PropertyKey = property
self.metric = parse_metric(metric)
def forward(
self,
model: GraphPESModel,
graph: AtomicGraph,
predictions: dict[PropertyKey, torch.Tensor],
) -> torch.Tensor:
"""
Computes the loss value.
        Parameters
        ----------
        model
            The model being trained.
        graph
            The graph (usually a batch) the ``model`` was applied to.
        predictions
            The predictions from the ``model`` for the given ``graph``.
"""
return self.metric(
predictions[self.property],
graph.properties[self.property],
)
@property
def name(self) -> str:
"""Get the name of this loss for logging purposes."""
return f"{self.property}_{_get_metric_name(self.metric)}"
@property
def required_properties(self) -> list[PropertyKey]:
return [self.property]
def __repr__(self) -> str:
return uniform_repr(
self.__class__.__name__,
self.property,
metric=self.metric,
)
class SubLossPair(NamedTuple):
loss_value: torch.Tensor
weighted_loss_value: torch.Tensor
class TotalLossResult(NamedTuple):
loss_value: torch.Tensor
components: dict[str, SubLossPair]
class TotalLoss(torch.nn.Module):
r"""
A lightweight wrapper around a collection of losses.
.. math::
\mathcal{L}_{\text{total}} = \sum_i w_i \mathcal{L}_i
where :math:`\mathcal{L}_i` is the :math:`i`-th loss and :math:`w_i` is the
corresponding weight.
``graph-pes`` models are trained by minimising a :class:`TotalLoss` value.
Parameters
----------
losses
The collection of losses to aggregate.
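    Example
    -------

    A typical combination (the weights shown are illustrative):

    .. code-block:: python

        total_loss = TotalLoss([
            PerAtomEnergyLoss(weight=1.0),
            PropertyLoss("forces", "RMSE", weight=3.0),
        ])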
"""
def __init__(self, losses: Sequence[Loss]):
super().__init__()
self.losses = UniformModuleList(losses)
def forward(
self,
model: GraphPESModel,
graph: AtomicGraph,
predictions: dict[PropertyKey, torch.Tensor],
) -> TotalLossResult:
"""
Computes the total loss value.
        Parameters
        ----------
        model
            The model being trained.
        graph
            The graph (usually a batch) the ``model`` was applied to.
        predictions
            The predictions from the ``model`` for the given ``graph``.
"""
total_loss = torch.scalar_tensor(0.0, device=graph.Z.device)
components: dict[str, SubLossPair] = {}
for loss in self.losses:
loss_value = loss(model, graph, predictions)
weighted_loss_value = loss_value * loss.weight
total_loss += weighted_loss_value
components[loss.name] = SubLossPair(loss_value, weighted_loss_value)
return TotalLossResult(total_loss, components)
# add type hints to appease mypy
def __call__(
self,
model: GraphPESModel,
graph: AtomicGraph,
predictions: dict[PropertyKey, torch.Tensor],
) -> TotalLossResult:
return super().__call__(model, graph, predictions)
def __repr__(self) -> str:
return "\n".join(
["TotalLoss:"]
+ [" ".join(str(loss).split("\n")) for loss in self.losses]
)
############################# CUSTOM LOSSES #############################
class PerAtomEnergyLoss(PropertyLoss):
r"""
A loss function that evaluates some metric on the total energy normalised
by the number of atoms in the structure.
.. math::
\mathcal{L} = \text{metric}\left(
\bigoplus_i \frac{\hat{E}_i}{N_i}, \bigoplus_i\frac{E_i}{N_i} \right)
where :math:`\hat{E}_i` is the predicted energy for structure :math:`i`,
:math:`E_i` is the true energy for structure :math:`i`, :math:`N_i`
is the number of atoms in structure :math:`i` and :math:`\bigoplus_i`
    denotes the concatenation over all structures in the batch.
Parameters
----------
metric
The loss metric to use. Defaults to :class:`RMSE`.
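    Example
    -------

    A minimal construction (the weight is illustrative):

    .. code-block:: python

        loss = PerAtomEnergyLoss("MAE", weight=5.0)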
"""
def __init__(
self,
metric: Metric | MetricName = "RMSE",
weight: float = 1.0,
):
super().__init__("energy", metric, weight)
def forward(
self,
model: GraphPESModel,
graph: AtomicGraph,
predictions: dict[PropertyKey, torch.Tensor],
) -> torch.Tensor:
return self.metric(
divide_per_atom(predictions["energy"], graph),
divide_per_atom(graph.properties["energy"], graph),
)
@property
def name(self) -> str:
return f"per_atom_energy_{_get_metric_name(self.metric)}"
class ForceRMSE(PropertyLoss):
"""
Alias for :class:`PropertyLoss` with ``property="forces"`` and
``metric=RMSE``.
"""
def __init__(self, weight: float = 1.0):
super().__init__("forces", RMSE(), weight)
## METRICS ##
def parse_metric(metric: Metric | MetricName | None) -> Metric:
if isinstance(metric, str):
return {"MAE": MAE(), "RMSE": RMSE(), "MSE": MSE()}[metric]
if metric is None:
return RMSE()
return metric
class MSE(torch.nn.MSELoss):
r"""
Mean squared error metric:
.. math::
\frac{1}{N} \sum_i^N \left( \hat{P}_i - P_i \right)^2
"""
class RMSE(torch.nn.MSELoss):
r"""
Root mean squared error metric:
.. math::
\sqrt{ \frac{1}{N} \sum_i^N \left( \hat{P}_i - P_i \right)^2 }
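    Example
    -------

    A quick sanity check (values are illustrative):

    .. code-block:: python

        RMSE()(torch.zeros(3), torch.ones(3))  # tensor(1.)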
"""
def forward(
self, input: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
return (super().forward(input, target)).sqrt()
class MAE(torch.nn.L1Loss):
r"""
Mean absolute error metric:
.. math::
\frac{1}{N} \sum_i^N \left| \hat{P}_i - P_i \right|
"""
def _get_metric_name(metric: Metric) -> str:
# if metric is a function, we want the function's name, otherwise
# we want the metric's class name, all lowercased
# and without the word "loss" in it
return (
getattr(
metric,
"__name__",
metric.__class__.__name__,
)
.lower()
.replace("loss", "")
)