Source code for graph_pes.models

from __future__ import annotations

import re
import warnings
from typing import TypeVar

# Silence the ``torch.load`` warning coming from e3nn. This filter must be
# registered before the e3nn-based models are imported further down this
# module (presumably the warning is emitted while those imports run).
warnings.filterwarnings(
    action="ignore",
    message=".*you are using `torch.load`.*",
    module="e3nn",
)

import pathlib

import torch

from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.utils.logger import logger

from .addition import AdditionModel
from .e3nn.mace import MACE, ZEmbeddingMACE
from .e3nn.nequip import NequIP, ZEmbeddingNequIP
from .offsets import FixedOffset, LearnableOffset
from .painn import PaiNN
from .pairwise import (
    LennardJones,
    LennardJonesMixture,
    Morse,
    PairPotential,
    SmoothedPairPotential,
    ZBLCoreRepulsion,
)
from .schnet import SchNet
from .tensornet import TensorNet

# Public API of this module. Kept as a literal list so static tools can read
# it. ``ALL_MODELS`` below is derived from it: every name here that is not in
# ``MODEL_EXCLUSIONS``, in this order.
__all__ = [
    "AdditionModel",
    "FixedOffset",
    "LearnableOffset",
    "LennardJones",
    "LennardJonesMixture",
    "MACE",
    "Morse",
    "NequIP",
    "PaiNN",
    "PairPotential",
    "SchNet",
    "SmoothedPairPotential",
    "TensorNet",
    "ZBLCoreRepulsion",
    "ZEmbeddingMACE",
    "ZEmbeddingNequIP",
]

# Exported names that are left out of ALL_MODELS.
MODEL_EXCLUSIONS = {
    # container that combines other models
    "AdditionModel",
    # energy-offset models
    "FixedOffset",
    "LearnableOffset",
    # generic pair-potential classes (presumably meant to be subclassed
    # rather than used directly -- TODO confirm)
    "PairPotential",
    "SmoothedPairPotential",
}

# Every exported class not listed in MODEL_EXCLUSIONS, in ``__all__`` order.
ALL_MODELS: list[type[GraphPESModel]] = [
    globals()[exported]
    for exported in __all__
    if exported not in MODEL_EXCLUSIONS
]


def load_model(path: str | pathlib.Path) -> GraphPESModel:
    """
    Load a model from a file.

    Parameters
    ----------
    path
        The path to the file.

    Returns
    -------
    GraphPESModel
        The model.

    Raises
    ------
    FileNotFoundError
        If nothing exists at ``path``.
    ValueError
        If the loaded object is not a ``GraphPESModel``.

    Warns
    -----
    UserWarning
        If the model was saved with a different (or unknown) version of
        ``graph-pes`` than the one currently installed.

    Notes
    -----
    The file is read with :func:`torch.load` and ``weights_only=False``.
    This unpickles arbitrary Python objects and can therefore execute code,
    so only load model files from sources you trust.

    Examples
    --------
    Use this function to load an existing model for further training using
    ``graph-pes-train``:

    .. code-block:: yaml

        model:
            +load_model:
                path: path/to/model.pt

    To account for some new energy offset in your training data, you could
    do something like this (see also
    :func:`~graph_pes.models.load_model_component`):

    .. code-block:: yaml

        model:
            # add an offset to an existing model before fine-tuning
            offset: +LearnableOffset()
            many-body:
                +load_model:
                    path: path/to/model.pt
    """
    path = pathlib.Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Could not find model at {path}")

    # NOTE(security): the file holds a whole pickled model object (not just
    # tensors), hence weights_only=False. That also means a malicious file
    # can run arbitrary code here.
    model = torch.load(path, weights_only=False)
    if not isinstance(model, GraphPESModel):
        raise ValueError(
            "Expected the loaded object to be a GraphPESModel "
            f"but got {type(model)}"
        )

    # Imported locally, presumably to avoid a circular import with the
    # top-level package -- TODO confirm.
    import graph_pes

    # getattr: a model object without this attribute (presumably one pickled
    # by an older graph-pes release) should trigger the mismatch warning
    # below rather than crash with an AttributeError.
    saved_version = getattr(model, "_GRAPH_PES_VERSION", "unknown")
    if saved_version != graph_pes.__version__:
        warnings.warn(
            "You are attempting to load a model that was trained with "
            f"a different version of graph-pes ({saved_version}) "
            f"than what you are currently using ({graph_pes.__version__}). "
            "We won't stop you from doing this, but it may cause issues.",
            stacklevel=2,
        )

    return model
def load_model_component(
    path: str | pathlib.Path,
    key: str,
) -> GraphPESModel:
    """
    Load one named component of a saved
    :class:`~graph_pes.models.AdditionModel`.

    Parameters
    ----------
    path
        Location of the saved model file (read with
        :func:`~graph_pes.models.load_model`).
    key
        Name of the component to extract.

    Returns
    -------
    GraphPESModel
        The component stored under ``key``.

    Raises
    ------
    ValueError
        If the file does not hold an
        :class:`~graph_pes.models.AdditionModel`. Errors raised by
        :func:`~graph_pes.models.load_model` propagate unchanged.

    Examples
    --------
    Train on data with a new energy offset:

    .. code-block:: yaml

        model:
            offset: +LearnableOffset()
            many-body:
                +load_model_component:
                    path: path/to/model.pt
                    key: many-body
    """
    container = load_model(path)
    if isinstance(container, AdditionModel):
        return container[key]
    raise ValueError(
        f"Expected to load an AdditionModel, got {type(container)}"
    )
# Generic module type: lets the freeze helpers in this module return the same
# (sub)class of ``torch.nn.Module`` that they were given.
T = TypeVar("T", bound=torch.nn.Module)
def freeze(model: T) -> T:
    """
    Turn off gradient tracking for every parameter of ``model``.

    Parameters
    ----------
    model
        The module whose parameters should be frozen.

    Returns
    -------
    T
        The same ``model`` object, modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad = False
    return model
def freeze_matching(model: T, pattern: str) -> T:
    r"""
    Freeze all parameters that match the given pattern.

    Matching uses :func:`re.match`, so the pattern only has to match at the
    *start* of a parameter's name; it is not anchored at the end. For
    example, ``model\.interactions\.1`` also matches
    ``model.interactions.10.<name>``. Add a trailing ``\.`` (or ``$``) to be
    precise.

    Parameters
    ----------
    model
        The model to freeze.
    pattern
        The regular expression to match the names of the parameters to
        freeze.

    Returns
    -------
    T
        The model (modified in place).

    Examples
    --------
    Freeze all the parameters in the first layer of a MACE-MP0 model from
    :func:`~graph_pes.interfaces.mace_mp` (which have names of the form
    ``"model.interactions.0.<name>"``):

    .. code-block:: yaml

        model:
            +freeze_matching:
                model: +mace_mp()
                pattern: model\.interactions\.0\..*
    """
    # Compile once up front so that an invalid pattern is reported
    # immediately, even for a model with no parameters.
    regex = re.compile(pattern)
    for name, param in model.named_parameters():
        if regex.match(name):
            logger.info(f"Freezing {name}")
            param.requires_grad = False
    return model
def freeze_any_matching(model: T, patterns: str | list[str]) -> T:
    r"""
    Freeze all parameters that match any of the given patterns.

    Parameters
    ----------
    model
        The model to freeze.
    patterns
        The patterns to match. Each one is applied with
        :func:`~graph_pes.models.freeze_matching`. A single string is treated
        as one pattern.

    Returns
    -------
    T
        The model (modified in place).

    Examples
    --------
    .. code-block:: yaml

        model:
            +freeze_any_matching:
                model: +mace_mp()
                patterns:
                    - model\.interactions\.0\..*
                    - model\.readouts.*
    """
    # A bare string is itself iterable. Without this guard, each *character*
    # would be used as its own regex: e.g. "m" would freeze every parameter
    # whose name starts with "m".
    if isinstance(patterns, str):
        patterns = [patterns]
    for pattern in patterns:
        freeze_matching(model, pattern)
    return model
def freeze_all_except(model: T, pattern: str | list[str]) -> T:
    r"""
    Freeze every parameter of ``model`` except those whose names match
    ``pattern``.

    Parameters
    ----------
    model
        The model to freeze.
    pattern
        One regular expression, or a list of them. A parameter is left
        trainable if :func:`re.match` succeeds against its name for any of
        them (a match at the start of the name is enough).

    Returns
    -------
    T
        The model (modified in place).

    Examples
    --------
    Freeze all parameters in a MACE-MP0 model from
    :func:`~graph_pes.interfaces.mace_mp` except those in the read-out heads:

    .. code-block:: yaml

        model:
            +freeze_all_except:
                model: +mace_mp()
                pattern: model\.readouts.*
    """
    keep_patterns = [pattern] if isinstance(pattern, str) else pattern

    # Freeze everything first, then re-enable the matches. As a result,
    # matching parameters end up trainable even if they were frozen before
    # this call.
    freeze(model)
    for name, param in model.named_parameters():
        for candidate in keep_patterns:
            if re.match(candidate, name):
                param.requires_grad = True
                break
    return model