Source code for graph_pes.training.callbacks

from __future__ import annotations

import copy
import time
from abc import ABC
from pathlib import Path
from typing import Any, Literal, Mapping, cast

import pytorch_lightning as pl
import torch
from ase.data import chemical_symbols
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import ProgressBar, StochasticWeightAveraging
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers import WandbLogger as PTLWandbLogger
from typing_extensions import override

from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.models.addition import AdditionModel
from graph_pes.models.components.scaling import LocalEnergiesScaler
from graph_pes.models.offsets import LearnableOffset
from graph_pes.training.utils import VALIDATION_LOSS_KEY
from graph_pes.utils.lammps import deploy_model
from graph_pes.utils.logger import logger
from graph_pes.utils.misc import uniform_repr



class GraphPESCallback(Callback, ABC):
    """
    A base class for all callbacks that require access to useful information
    generated by the ``graph-pes-train`` command.
    """

    def __init__(self):
        self.root: Path = None  # type: ignore

    def _register_root(self, root: Path):
        # called by us before any training starts
        self.root = root

    def get_model(self, pl_module: LightningModule) -> GraphPESModel:
        return cast(GraphPESModel, pl_module.model)

    def get_model_on_cpu(self, pl_module: LightningModule) -> GraphPESModel:
        model = self.get_model(pl_module)
        model = copy.deepcopy(model)
        model.to("cpu")
        return model

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(root={self.root})"

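# A minimal sketch (not part of this module) of how GraphPESCallback can be
# subclassed: ``self.root`` is the output directory registered by
# ``graph-pes-train``, and ``get_model_on_cpu`` returns a detached copy that
# is safe to serialise. The class name and file name are illustrative
# assumptions, not part of the library:
#
#     class SaveFinalModel(GraphPESCallback):
#         def on_fit_end(self, trainer: Trainer, pl_module: LightningModule):
#             if not trainer.is_global_zero:
#                 return
#             torch.save(self.get_model_on_cpu(pl_module), self.root / "final.pt")
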
class DumpModel(GraphPESCallback):
    """
    Dump the model to ``<output_dir>/dumps/model_{epoch}.pt``
    at regular intervals.

    Parameters
    ----------
    every_n_val_checks: int
        The number of validation epochs between dumps.
    """

    def __init__(self, every_n_val_checks: int = 10):
        self.every_n_val_checks = every_n_val_checks

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ):
        if not trainer.is_global_zero:
            return

        epoch = trainer.current_epoch
        if epoch % self.every_n_val_checks != 0:
            return

        model_path = self.root / "dumps" / f"model_{epoch}.pt"
        model_path.parent.mkdir(exist_ok=True)
        torch.save(self.get_model_on_cpu(pl_module), model_path)

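# Usage sketch (assumption: manual wiring outside of ``graph-pes-train``,
# which normally calls ``_register_root`` itself before training starts):
#
#     from pathlib import Path
#
#     dump = DumpModel(every_n_val_checks=5)
#     dump._register_root(Path("results/my-run"))
#     trainer = Trainer(callbacks=[dump])
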
def log_offset(model: GraphPESModel, logger: Logger):
    if not isinstance(model, AdditionModel):
        return

    offsets = [
        c for c in model.models.values() if isinstance(c, LearnableOffset)
    ]
    if not offsets:
        return

    Zs = offsets[0]._offsets._accessed_Zs
    logger.log_metrics(
        {
            f"offset/{chemical_symbols[Z]}": offsets[0]._offsets[Z].item()
            for Z in Zs
        }
    )

class OffsetLogger(GraphPESCallback):
    """
    Log any learned, per-element offsets of the model at the end of
    each validation epoch.
    """

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ):
        if not trainer.is_global_zero:
            return

        if not trainer.logger:
            return

        log_offset(self.get_model(pl_module), trainer.logger)

def log_scales(model: GraphPESModel, logger: Logger, name: str | None = None):
    if isinstance(model, AdditionModel):
        for name, child in model.models.items():
            log_scales(child, logger, name)
        return

    prefix = f"scale/{name}" if name else "scale"
    scaler = next(
        (c for c in model.modules() if isinstance(c, LocalEnergiesScaler)),
        None,
    )
    if not scaler:
        return

    scaling = scaler.per_element_scaling
    logger.log_metrics(
        {
            f"{prefix}/{chemical_symbols[Z]}": scaling[Z].item()
            for Z in scaling._accessed_Zs
        }
    )

class ScalesLogger(GraphPESCallback):
    """
    Log any learned, per-element scaling factors of the model at the end of
    each validation epoch.
    """

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ):
        if not trainer.is_global_zero:
            return

        if not trainer.logger:
            return

        log_scales(self.get_model(pl_module), trainer.logger)

class EarlyStoppingWithLogging(EarlyStopping, GraphPESCallback):
    """
    Log various information relating to the early stopping process:

    * the number of validation checks since the best validation loss
      was observed
    * the "distance above" the best validation loss
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.state = {
            "best_val_loss": float("inf"),
            "best_val_loss_check": 0,
            "total_checks": 0,
        }

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ):
        super().on_validation_epoch_end(trainer, pl_module)

        if not trainer.is_global_zero or not trainer.logger:
            return

        self.state["total_checks"] += 1

        current_loss = trainer.callback_metrics.get(VALIDATION_LOSS_KEY, None)
        if current_loss is None:
            return
        current_loss = current_loss.item()

        if current_loss < self.state["best_val_loss"]:
            self.state["best_val_loss"] = current_loss
            self.state["best_val_loss_check"] = self.state["total_checks"]

        checks_since_best = (
            self.state["total_checks"] - self.state["best_val_loss_check"]
        )
        distance_above_best = current_loss - self.state["best_val_loss"]

        trainer.logger.log_metrics(
            {
                "early_stopping/checks_since_best": checks_since_best,
                "early_stopping/best_valid_loss": self.state["best_val_loss"],
                "early_stopping/distance_above_best": distance_above_best,
            }
        )

    def load_state_dict(self, state_dict: dict):
        self.state.update(state_dict)

    def state_dict(self) -> dict:
        return self.state


class SaveBestModel(GraphPESCallback):
    """
    Save the best model to ``<output_dir>/model.pt`` and deploy it to
    ``<output_dir>/lammps_model.pt``.
    """

    def __init__(self):
        super().__init__()
        self.best_val_loss = float("inf")
        self.try_to_deploy = True

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        if not trainer.is_global_zero or not self.root:
            return

        # get the current validation loss
        current_loss = trainer.callback_metrics.get(VALIDATION_LOSS_KEY, None)
        if current_loss is None:
            return

        # if the validation loss has improved, save the model
        if current_loss < self.best_val_loss:
            self.best_val_loss = current_loss
            cpu_model = self.get_model_on_cpu(pl_module)
            logger.debug(f"New best model: validation loss {current_loss}")

            # log the final paths to the trainer's logger
            model_path = self.root / "model.pt"
            lammps_model_path = self.root / "lammps_model.pt"
            assert trainer.logger is not None
            trainer.logger.log_hyperparams(
                {
                    "model_path": model_path,
                    "lammps_model_path": lammps_model_path,
                }
            )

            torch.save(cpu_model, model_path)
            logger.debug(f"Model saved to {model_path}")

            if self.try_to_deploy:
                try:
                    deploy_model(cpu_model, path=lammps_model_path)
                    logger.debug(
                        f"Deployed model for use with LAMMPS to "
                        f"{lammps_model_path}"
                    )
                except Exception as e:
                    logger.warning(
                        f"Failed to deploy model for use with LAMMPS: {e}"
                    )
                    self.try_to_deploy = False

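# The artefacts written by SaveBestModel can be reloaded directly; the run
# directory below is an illustrative assumption, and ``weights_only=False`` is
# needed on recent PyTorch versions because the full model object is pickled:
#
#     best = torch.load("results/my-run/model.pt", weights_only=False)
#
# ``lammps_model.pt`` is the ``deploy_model`` output intended for use from
# within LAMMPS.
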
class ModelTimer(pl.Callback):
    def __init__(self):
        super().__init__()
        self.tick_ms: float | None = None

    def start(self):
        self.tick_ms = time.time_ns() // 1_000_000

    def stop(
        self, pl_module: pl.LightningModule, stage: Literal["train", "valid"]
    ):
        assert self.tick_ms is not None
        duration_ms = max((time.time_ns() // 1_000_000) - self.tick_ms, 1)
        self.tick_ms = None

        for name, x in (
            ("step_duration_ms", duration_ms),
            ("its_per_s", 1_000 / duration_ms),
        ):
            pl_module.log(
                f"timer/{name}/{stage}",
                x,
                batch_size=1,
                on_epoch=stage == "valid",
                on_step=stage == "train",
                prog_bar=name == "its_per_s",
                sync_dist=stage == "valid",
            )

    @override
    def on_train_batch_start(self, *args, **kwargs):
        self.start()

    @override
    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ):
        self.stop(pl_module, "train")

    @override
    def on_validation_batch_start(self, *args, **kwargs):
        self.start()

    @override
    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        *args,
        **kwargs,
    ):
        self.stop(pl_module, "valid")

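# ModelTimer therefore logs four metric series: ``timer/step_duration_ms/train``
# and ``timer/its_per_s/train`` per training step, plus
# ``timer/step_duration_ms/valid`` and ``timer/its_per_s/valid`` aggregated
# per validation epoch.
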
class VerboseSWACallback(StochasticWeightAveraging):
    @override
    def on_train_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        if (not self._initialized) and (
            self.swa_start <= trainer.current_epoch <= self.swa_end
        ):
            logger.info("SWA: starting SWA")

        return super().on_train_epoch_start(trainer, pl_module)


class LoggedProgressBar(ProgressBar):
    """
    A progress bar that logs all metrics at the end of each validation epoch.
    """

    def __init__(self):
        super().__init__()
        self._enabled = True
        self._start_time = time.time()
        self._widths: dict[str, int] = {}

    @override
    def on_train_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        self._start_time = time.time()

    @override
    def disable(self):
        self._enabled = False

    @override
    def enable(self):
        self._enabled = True

    @override
    def on_validation_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ):
        if not self._enabled or trainer.sanity_checking:
            return

        metrics = self.get_metrics(trainer, pl_module)

        if not self._widths:
            # first time we're logging things:
            # calculate the widths of the columns, and print the headers

            # split headers with more than 1 slash over two rows
            def split_header(h: str) -> list[str]:
                if "/" in h:
                    parts = h.split("/")
                    if len(parts) > 2:
                        return [f"{parts[0]}/{parts[1]}", "/".join(parts[2:])]
                return [h]

            headers = list(metrics.keys())
            split_headers = [split_header(h) for h in headers]
            header_widths = {
                header: max(len(line) for line in lines)
                for header, lines in zip(headers, split_headers)
            }
            content_widths = {
                header: len(metrics[header]) for header in headers
            }
            self._widths = {
                header: max(header_widths[header], content_widths[header])
                + (6 if header == "time" else 3)
                for header in headers
            }

            # print the headers
            first_row = [
                "" if len(lines) == 1 else lines[0] for lines in split_headers
            ]
            second_row = [lines[-1] for lines in split_headers]
            print(
                "".join(
                    f"{part:>{self._widths[header]}}"
                    for part, header in zip(first_row, headers)
                ),
                flush=True,
            )
            print(
                "".join(
                    f"{part:>{self._widths[header]}}"
                    for part, header in zip(second_row, headers)
                ),
                flush=True,
            )

        # print the values for this epoch
        print(
            "".join(f"{v:>{self._widths[k]}}" for k, v in metrics.items()),
            flush=True,
        )

    @override
    def get_metrics(  # type: ignore
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> dict[str, str]:
        def logged_value(v: float | int | Any):
            return f"{v:.5f}" if isinstance(v, float) else str(v)

        metrics = {"epoch": f"{trainer.current_epoch + 1:>5}"}
        super_metrics = super().get_metrics(trainer, pl_module)
        super_metrics.pop("v_num", None)
        for k, v in super_metrics.items():
            metrics[k] = logged_value(v)

        # rearrange according to: epoch | time | valid/* | rest...
        sorted_metrics = {"epoch": metrics.pop("epoch")}
        sorted_metrics["time"] = f"{time.time() - self._start_time:.1f}"
        for k in list(metrics):
            if k.startswith("valid/"):
                sorted_metrics[k] = metrics.pop(k)
        sorted_metrics.update(metrics)

        return sorted_metrics

class WandbLogger(PTLWandbLogger):
    """A subclass of WandbLogger that automatically sets the id and save_dir."""

    def __init__(self, output_dir: Path, log_epoch: bool, **kwargs):
        if "id" not in kwargs:
            kwargs["id"] = output_dir.name
        if "save_dir" not in kwargs:
            kwargs["save_dir"] = str(output_dir.parent)

        super().__init__(**kwargs)
        self._kwargs = kwargs
        self._log_epoch = log_epoch

    def log_metrics(
        self, metrics: Mapping[str, float], step: int | None = None
    ):
        if not self._log_epoch:
            metrics = {k: v for k, v in metrics.items() if k != "epoch"}
        return super().log_metrics(metrics, step)

    def __repr__(self):
        return uniform_repr(self.__class__.__name__, **self._kwargs)
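

# Usage sketch (assumption: manual construction; extra keyword arguments such
# as ``project`` are forwarded to the underlying pytorch_lightning WandbLogger):
#
#     from pathlib import Path
#
#     wandb_logger = WandbLogger(
#         output_dir=Path("results/my-run"),
#         log_epoch=False,
#         project="graph-pes",
#     )
#     # the run id becomes "my-run" and files are saved under "results/"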