Source code for graph_pes.training.opt

from __future__ import annotations

import torch

from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.models.offsets import LearnableOffset
from graph_pes.utils.misc import contains_tensor, uniform_repr


class Optimizer:
    """
    A factory class for delayed instantiation of
    :class:`torch.optim.Optimizer` objects.

    The generated optimizer splits the parameters of the model into two
    groups:

    - "non-decayable" parameters, which are all parameters returned by the
      :meth:`~graph_pes.GraphPESModel.non_decayable_parameters` method of
      the model.
    - "normal" parameters, corresponding to the remaining model parameters.

    Unsurprisingly, any specified weight decay is applied only to the normal
    model parameters. As an example, per-element energy offsets parameters of
    :class:`~graph_pes.models.offsets.LearnableOffset` models represent the
    arbitrary zero points of energies for different elements: it doesn't make
    sense to push these towards zero during training.

    .. note::

        We use delayed instantiation of optimizers when configuring our
        training runs to allow for arbitrary changes to the model and its
        parameters during the
        :meth:`~graph_pes.GraphPESModel.pre_fit_all_components` method.

    Parameters
    ----------
    name
        The name of the :class:`torch.optim.Optimizer` class to use, e.g.
        ``"Adam"`` or ``"SGD"``. Alternatively, provide the type of any
        subclass of :class:`torch.optim.Optimizer`.
    **kwargs
        Additional keyword arguments to pass to the specified optimizer's
        constructor.

    Raises
    ------
    ValueError
        If ``name`` is a string that is not an attribute of
        :mod:`torch.optim`, or if calling the resolved class does not
        produce a :class:`torch.optim.Optimizer` instance.

    Examples
    --------
    Pass a named optimiser:

    >>> from graph_pes.training.opt import Optimizer
    >>> optimizer_factory = Optimizer("AdamW", lr=1e-3)
    >>> optimizer_instance = optimizer_factory(model)

    Or pass the optimiser class directly:

    >>> from torch.optim import SGD
    >>> optimizer_factory = Optimizer(SGD, lr=1e-3)
    >>> optimizer_instance = optimizer_factory(model)

    Pseudo-code excerpt from ``graph-pes-train`` logic:

    >>> from graph_pes.training.opt import Optimizer
    >>> from graph_pes.models import LennardJones
    >>> ...
    >>> optimizer_factory = Optimizer("AdamW", lr=1e-3)
    >>> model = LennardJones()
    >>> model.pre_fit(train_loader)
    >>> optimizer_instance = optimizer_factory(model)
    """

    def __init__(
        self,
        name: str | type[torch.optim.Optimizer],
        **kwargs,
    ):
        self.kwargs = kwargs

        if isinstance(name, str):
            # Names are looked up as attributes of ``torch.optim``.
            optimizer_class: type[torch.optim.Optimizer] = getattr(
                torch.optim, name, None
            )  # type: ignore
            if optimizer_class is None:
                raise ValueError(
                    f"Could not find optimizer {name}. "
                    "Please provide the common PyTorch name, or a fully "
                    "qualified class name."
                )
        else:
            optimizer_class = name

        # Fail fast: build a throwaway optimizer over the parameters of a
        # fresh LearnableOffset so that a bad class or bad kwargs surface
        # here rather than when the factory is eventually called.
        dummy_model = LearnableOffset()
        dummy_opt = optimizer_class(dummy_model.parameters(), **kwargs)
        if not isinstance(dummy_opt, torch.optim.Optimizer):
            raise ValueError(
                "Expected the returned optimizer to be an instance of "
                "torch.optim.Optimizer, but got "
                f"{type(dummy_opt)}"
            )

        self.optimizer_class = optimizer_class

    def __call__(self, model: GraphPESModel) -> torch.optim.Optimizer:
        """
        Create an instance of the specified optimizer class, with the correct
        parameter groups for the model.

        Parameters are de-duplicated by object identity and kept in the
        order in which ``model.non_decayable_parameters()`` and
        ``model.parameters()`` yield them. Optimizer state is stored and
        restored by position within the parameter groups, so the order must
        not vary between processes (as it would when passing through a
        ``set``, whose iteration order follows identity-based hashes).

        Parameters
        ----------
        model
            The model whose parameters should be optimized.

        Returns
        -------
        torch.optim.Optimizer
            An optimizer with a ``"non-decayable"`` group (``weight_decay``
            forced to ``0.0``) and a ``"normal"`` group; empty groups are
            omitted.
        """
        # dicts preserve insertion order, so keying on id() removes repeats
        # without disturbing the order of first appearance.
        non_decayable_params = list(
            {id(p): p for p in model.non_decayable_parameters()}.values()
        )
        non_decayable_ids = {id(p) for p in non_decayable_params}
        normal_params = list(
            {
                id(p): p
                for p in model.parameters()
                if id(p) not in non_decayable_ids
            }.values()
        )

        param_groups = [
            {
                "name": "non-decayable",
                "params": non_decayable_params,
                "weight_decay": 0.0,
            },
            {
                "name": "normal",
                "params": normal_params,
            },
        ]
        # remove empty groups
        param_groups = [pg for pg in param_groups if pg["params"]]

        return self.optimizer_class(param_groups, **self.kwargs)

    def __repr__(self):
        return uniform_repr(
            self.__class__.__name__,
            name=self.optimizer_class.__name__,
            **self.kwargs,
        )
class LRScheduler:
    """
    A factory class for delayed instantiation of
    :class:`torch.optim.lr_scheduler.LRScheduler` objects.

    Parameters
    ----------
    name
        The name of the :class:`torch.optim.lr_scheduler.LRScheduler` class
        to use, e.g. ``"ReduceLROnPlateau"``. Alternatively, provide any
        subclass of :class:`torch.optim.lr_scheduler.LRScheduler`.
    **kwargs
        Additional keyword arguments to pass to the specified scheduler's
        constructor.

    Raises
    ------
    ValueError
        If ``name`` is a string that is not an attribute of
        :mod:`torch.optim.lr_scheduler`, or if calling the resolved class
        does not produce a scheduler instance.

    Examples
    --------
    >>> from graph_pes.training.opt import LRScheduler
    >>> ...
    >>> scheduler_factory = LRScheduler(
    ...     "LambdaLR", lr_lambda=lambda epoch: 0.95 ** epoch
    ... )
    >>> scheduler_instance = scheduler_factory(optimizer)
    """

    def __init__(
        self,
        name: str | type[torch.optim.lr_scheduler.LRScheduler],
        **kwargs,
    ):
        self.kwargs = kwargs

        if not isinstance(name, str):
            # Already a class (or other callable): use it as given.
            resolved = name
        else:
            # String names are looked up on ``torch.optim.lr_scheduler``.
            resolved = getattr(torch.optim.lr_scheduler, name, None)
            if resolved is None:
                raise ValueError(
                    f"Could not find scheduler {name}. "
                    "Please provide the common PyTorch name, or a fully "
                    "qualified class name."
                )

        # Trial construction against a throwaway Adam optimizer, so that a
        # bad class or bad kwargs are reported at configuration time.
        trial_optimizer = torch.optim.Adam(LearnableOffset().parameters())
        trial_scheduler = resolved(trial_optimizer, **kwargs)

        # NOTE(review): ReduceLROnPlateau is accepted explicitly, presumably
        # because it is not an LRScheduler subclass in every supported torch
        # release - confirm against the minimum supported version.
        accepted_types = (
            torch.optim.lr_scheduler.LRScheduler,
            torch.optim.lr_scheduler.ReduceLROnPlateau,
        )
        if not isinstance(trial_scheduler, accepted_types):
            raise ValueError(
                "Expected the returned scheduler to be an instance of "
                "torch.optim.lr_scheduler.LRScheduler, but got "
                f"{type(trial_scheduler)}"
            )

        self.scheduler_class = resolved

    def __call__(
        self, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler.LRScheduler:
        """
        Build the configured scheduler for ``optimizer``, forwarding the
        keyword arguments captured at construction time.
        """
        make_scheduler = self.scheduler_class
        return make_scheduler(optimizer, **self.kwargs)

    def __repr__(self):
        factory_name = self.__class__.__name__
        scheduler_name = self.scheduler_class.__name__
        return uniform_repr(factory_name, name=scheduler_name, **self.kwargs)