from __future__ import annotations
import copy
import time
from abc import ABC
from pathlib import Path
from typing import Any, Literal, Mapping, cast
import pytorch_lightning as pl
import torch
from ase.data import chemical_symbols
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import ProgressBar, StochasticWeightAveraging
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers import WandbLogger as PTLWandbLogger
from typing_extensions import override
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.models.addition import AdditionModel
from graph_pes.models.components.scaling import LocalEnergiesScaler
from graph_pes.models.offsets import LearnableOffset
from graph_pes.training.utils import VALIDATION_LOSS_KEY
from graph_pes.utils.lammps import deploy_model
from graph_pes.utils.logger import logger
from graph_pes.utils.misc import uniform_repr
class GraphPESCallback(Callback, ABC):
"""
A base class for all callbacks that require access to useful
information generated by the ``graph-pes-train`` command.
"""
def __init__(self):
self.root: Path = None # type: ignore
def _register_root(self, root: Path):
# called by the graph-pes-train command before any training starts
self.root = root
def get_model(self, pl_module: LightningModule) -> GraphPESModel:
return cast(GraphPESModel, pl_module.model)
def get_model_on_cpu(self, pl_module: LightningModule) -> GraphPESModel:
model = self.get_model(pl_module)
model = copy.deepcopy(model)
model.to("cpu")
return model
def __repr__(self) -> str:
return f"{self.__class__.__name__}(root={self.root})"
class DumpModel(GraphPESCallback):
"""
Dump the model to ``<output_dir>/dumps/model_{epoch}.pt``
at regular intervals.
Parameters
----------
every_n_val_checks: int
The number of validation epochs between dumps.
"""
def __init__(self, every_n_val_checks: int = 10):
self.every_n_val_checks = every_n_val_checks
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
):
if not trainer.is_global_zero:
return
epoch = trainer.current_epoch
if epoch % self.every_n_val_checks != 0:
return
model_path = self.root / "dumps" / f"model_{epoch}.pt"
model_path.parent.mkdir(exist_ok=True)
torch.save(self.get_model_on_cpu(pl_module), model_path)
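# Usage sketch (illustrative): ``graph-pes-train`` normally constructs this
# callback and registers the output directory itself.
#
#   callback = DumpModel(every_n_val_checks=5)
#   callback._register_root(Path("my-training-run"))
#   trainer = Trainer(callbacks=[callback])


# Log the per-element energy offsets of the first LearnableOffset component
# of an AdditionModel (a no-op for any other kind of model).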
def log_offset(model: GraphPESModel, logger: Logger):
if not isinstance(model, AdditionModel):
return
offsets = [
c for c in model.models.values() if isinstance(c, LearnableOffset)
]
if not offsets:
return
Zs = offsets[0]._offsets._accessed_Zs
logger.log_metrics(
{
f"offset/{chemical_symbols[Z]}": offsets[0]._offsets[Z].item()
for Z in Zs
}
)
class OffsetLogger(GraphPESCallback):
"""
Log any learned, per-element offsets of the model at the
end of each validation epoch.
"""
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
):
if not trainer.is_global_zero:
return
if not trainer.logger:
return
log_offset(self.get_model(pl_module), trainer.logger)
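# Log the per-element energy scales of any LocalEnergiesScaler found in the
# model, recursing into each named component of an AdditionModel.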
def log_scales(model: GraphPESModel, logger: Logger, name: str | None = None):
if isinstance(model, AdditionModel):
for name, child in model.models.items():
log_scales(child, logger, name)
return
prefix = f"scale/{name}" if name else "scale"
scaler = next(
(c for c in model.modules() if isinstance(c, LocalEnergiesScaler)), None
)
if not scaler:
return
scaling = scaler.per_element_scaling
logger.log_metrics(
{
f"{prefix}/{chemical_symbols[Z]}": scaling[Z].item()
for Z in scaling._accessed_Zs
}
)
class ScalesLogger(GraphPESCallback):
"""
Log any learned, per-element scaling factors of the model at the
end of each validation epoch.
"""
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
):
if not trainer.is_global_zero:
return
if not trainer.logger:
return
log_scales(self.get_model(pl_module), trainer.logger)
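# Usage sketch (illustrative): both loggers are passive observers, and can
# simply be appended to the trainer's callbacks.
#
#   trainer = Trainer(callbacks=[OffsetLogger(), ScalesLogger()])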
class EarlyStoppingWithLogging(EarlyStopping, GraphPESCallback):
"""
Log various information relating to the early stopping process:
* the number of validation checks since the best validation loss was observed
* the "distance" by which the current validation loss exceeds the best
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.state = {
"best_val_loss": float("inf"),
"best_val_loss_check": 0,
"total_checks": 0,
}
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
):
super().on_validation_epoch_end(trainer, pl_module)
if not trainer.is_global_zero or not trainer.logger:
return
self.state["total_checks"] += 1
current_loss = trainer.callback_metrics.get(VALIDATION_LOSS_KEY, None)
if current_loss is None:
return
current_loss = current_loss.item()
if current_loss < self.state["best_val_loss"]:
self.state["best_val_loss"] = current_loss
self.state["best_val_loss_check"] = self.state["total_checks"]
checks_since_best = (
self.state["total_checks"] - self.state["best_val_loss_check"]
)
distance_above_best = current_loss - self.state["best_val_loss"]
trainer.logger.log_metrics(
{
"early_stopping/checks_since_best": checks_since_best,
"early_stopping/best_valid_loss": self.state["best_val_loss"],
"early_stopping/distance_above_best": distance_above_best,
}
)
def load_state_dict(self, state_dict: dict):
self.state.update(state_dict)
def state_dict(self) -> dict:
return self.state
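# Usage sketch (illustrative): this accepts the usual ``EarlyStopping``
# arguments, e.g. monitoring the validation loss with some patience.
#
#   early_stopping = EarlyStoppingWithLogging(
#       monitor=VALIDATION_LOSS_KEY, patience=25, mode="min"
#   )
#   trainer = Trainer(callbacks=[early_stopping])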
class SaveBestModel(GraphPESCallback):
"""
Save the best model to ``<output_dir>/model.pt`` and deploy it to
``<output_dir>/lammps_model.pt``.
"""
def __init__(self):
super().__init__()
self.best_val_loss = float("inf")
self.try_to_deploy = True
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
if not trainer.is_global_zero or not self.root:
return
# get current validation loss
current_loss = trainer.callback_metrics.get(VALIDATION_LOSS_KEY, None)
if current_loss is None:
return
# if the validation loss has improved, save the model
if current_loss < self.best_val_loss:
self.best_val_loss = current_loss
cpu_model = self.get_model_on_cpu(pl_module)
logger.debug(f"New best model: validation loss {current_loss}")
# log the final path to the trainer.logger.summary
model_path = self.root / "model.pt"
lammps_model_path = self.root / "lammps_model.pt"
assert trainer.logger is not None
trainer.logger.log_hyperparams(
{
"model_path": model_path,
"lammps_model_path": lammps_model_path,
}
)
torch.save(cpu_model, model_path)
logger.debug(f"Model saved to {model_path}")
if self.try_to_deploy:
try:
deploy_model(cpu_model, path=lammps_model_path)
logger.debug(
f"Deployed model for use with LAMMPS to "
f"{lammps_model_path}"
)
except Exception as e:
logger.warning(
f"Failed to deploy model for use with LAMMPS: {e}"
)
self.try_to_deploy = False
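# The saved artifact can later be loaded directly (a hedged sketch: since the
# whole model object is pickled, recent torch versions need
# ``weights_only=False``). The ``lammps_model.pt`` file is the deployed
# TorchScript artifact for use with graph-pes's LAMMPS integration.
#
#   model = torch.load("<output_dir>/model.pt", weights_only=False)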
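# Times every training and validation step, logging the step duration (in ms)
# and the implied number of iterations per second.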
class ModelTimer(pl.Callback):
def __init__(self):
super().__init__()
self.tick_ms: int | None = None
def start(self):
self.tick_ms = time.time_ns() // 1_000_000
def stop(
self, pl_module: pl.LightningModule, stage: Literal["train", "valid"]
):
assert self.tick_ms is not None
duration_ms = max((time.time_ns() // 1_000_000) - self.tick_ms, 1)
self.tick_ms = None
for name, x in (
("step_duration_ms", duration_ms),
("its_per_s", 1_000 / duration_ms),
):
pl_module.log(
f"timer/{name}/{stage}",
x,
batch_size=1,
on_epoch=stage == "valid",
on_step=stage == "train",
prog_bar=name == "its_per_s",
sync_dist=stage == "valid",
)
@override
def on_train_batch_start(self, *args, **kwargs):
self.start()
@override
def on_train_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
*args,
**kwargs,
):
self.stop(pl_module, "train")
@override
def on_validation_batch_start(self, *args, **kwargs):
self.start()
@override
def on_validation_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
*args,
**kwargs,
):
self.stop(pl_module, "valid")
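# The metrics logged above therefore appear under:
#   timer/step_duration_ms/train and timer/its_per_s/train  (per step)
#   timer/step_duration_ms/valid and timer/its_per_s/valid  (aggregated
#   across the validation epoch)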
class VerboseSWACallback(StochasticWeightAveraging):
@override
def on_train_epoch_start(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
) -> None:
if (not self._initialized) and (
self.swa_start <= trainer.current_epoch <= self.swa_end
):
logger.info("SWA: starting SWA")
return super().on_train_epoch_start(trainer, pl_module)
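# Usage sketch (illustrative): configured exactly like the parent
# ``StochasticWeightAveraging`` callback from PyTorch Lightning.
#
#   swa = VerboseSWACallback(swa_lrs=1e-4, swa_epoch_start=0.8)
#   trainer = Trainer(callbacks=[swa])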
class LoggedProgressBar(ProgressBar):
"""
A progress bar that logs all metrics at the end of each validation epoch.
"""
def __init__(self):
super().__init__()
self._enabled = True
self._start_time = time.time()
self._widths: dict[str, int] = {}
@override
def on_train_start(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
):
self._start_time = time.time()
@override
def disable(self):
self._enabled = False
@override
def enable(self):
self._enabled = True
@override
def on_validation_epoch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
):
if not self._enabled or trainer.sanity_checking:
return
metrics = self.get_metrics(trainer, pl_module)
if not self._widths:
# first time we're logging things: calculate the column widths
# and print the headers, splitting any header containing more
# than one slash over two rows
def split_header(h: str) -> list[str]:
if "/" in h:
parts = h.split("/")
if len(parts) > 2:
return [f"{parts[0]}/{parts[1]}", "/".join(parts[2:])]
return [h]
headers = list(metrics.keys())
split_headers = [split_header(h) for h in headers]
header_widths = {
header: max(len(line) for line in lines)
for header, lines in zip(headers, split_headers)
}
content_widths = {
header: len(metrics[header]) for header in headers
}
self._widths = {
header: max(header_widths[header], content_widths[header])
+ (6 if header == "time" else 3)
for header in headers
}
# print the headers
first_row = [
"" if len(lines) == 1 else lines[0] for lines in split_headers
]
second_row = [lines[-1] for lines in split_headers]
print(
"".join(
f"{part:>{self._widths[header]}}"
for part, header in zip(first_row, headers)
),
flush=True,
)
print(
"".join(
f"{part:>{self._widths[header]}}"
for part, header in zip(second_row, headers)
),
flush=True,
)
# print the values for this epoch
print(
"".join(f"{v:>{self._widths[k]}}" for k, v in metrics.items()),
flush=True,
)
@override
def get_metrics( # type: ignore
self, trainer: pl.Trainer, pl_module: pl.LightningModule
) -> dict[str, str]:
def logged_value(v: float | int | Any):
return f"{v:.5f}" if isinstance(v, float) else str(v)
metrics = {"epoch": f"{trainer.current_epoch + 1:>5}"}
super_metrics = super().get_metrics(trainer, pl_module)
super_metrics.pop("v_num", None)
for k, v in super_metrics.items():
metrics[k] = logged_value(v)
# rearrange according to: epoch | time | valid/* | rest...
sorted_metrics = {"epoch": metrics.pop("epoch")}
sorted_metrics["time"] = f"{time.time() - self._start_time:.1f}"
for k in list(metrics):
if k.startswith("valid/"):
sorted_metrics[k] = metrics.pop(k)
sorted_metrics.update(metrics)
return sorted_metrics
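# Schematic output (values are made up): a header block is printed once,
# then one row per validation epoch, with columns ordered
# epoch | time | valid/* | rest.
#
#   epoch   time   valid/loss
#       1   12.3      0.12345
#       2   24.6      0.09876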
class WandbLogger(PTLWandbLogger):
"""A subclass of WandbLogger that automatically sets the id and save_dir."""
def __init__(self, output_dir: Path, log_epoch: bool, **kwargs):
if "id" not in kwargs:
kwargs["id"] = output_dir.name
if "save_dir" not in kwargs:
kwargs["save_dir"] = str(output_dir.parent)
super().__init__(**kwargs)
self._kwargs = kwargs
self._log_epoch = log_epoch
def log_metrics(
self, metrics: Mapping[str, float], step: int | None = None
):
if not self._log_epoch:
metrics = {k: v for k, v in metrics.items() if k != "epoch"}
return super().log_metrics(metrics, step)
def __repr__(self):
return uniform_repr(self.__class__.__name__, **self._kwargs)
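# Usage sketch (illustrative): extra keyword arguments are forwarded to the
# underlying pytorch_lightning ``WandbLogger``.
#
#   wandb_logger = WandbLogger(
#       output_dir=Path("results/run-42"),
#       log_epoch=False,
#       project="my-project",
#   )
#   trainer = Trainer(logger=wandb_logger)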