from __future__ import annotations
import re
import warnings
from typing import TypeVar
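# silence e3nn's warning about `torch.load` before the e3nn-based models
# below are imported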
warnings.filterwarnings(
    "ignore",
    module="e3nn",
    message=".*you are using `torch.load`.*",
)
import pathlib
import torch
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.utils.logger import logger
from .addition import AdditionModel
from .e3nn.mace import MACE, ZEmbeddingMACE
from .e3nn.nequip import NequIP, ZEmbeddingNequIP
from .offsets import FixedOffset, LearnableOffset
from .painn import PaiNN
from .pairwise import (
    LennardJones,
    LennardJonesMixture,
    Morse,
    PairPotential,
    SmoothedPairPotential,
    ZBLCoreRepulsion,
)
from .schnet import SchNet
from .tensornet import TensorNet
__all__ = [
    "AdditionModel",
    "FixedOffset",
    "LearnableOffset",
    "LennardJones",
    "LennardJonesMixture",
    "MACE",
    "Morse",
    "NequIP",
    "PaiNN",
    "PairPotential",
    "SchNet",
    "SmoothedPairPotential",
    "TensorNet",
    "ZBLCoreRepulsion",
    "ZEmbeddingMACE",
    "ZEmbeddingNequIP",
]
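
# exported names to skip when collecting ALL_MODELS below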
MODEL_EXCLUSIONS = {
    "FixedOffset",
    "LearnableOffset",
    "AdditionModel",
    "PairPotential",
    "SmoothedPairPotential",
}

ALL_MODELS: list[type[GraphPESModel]] = [
    globals()[model] for model in __all__ if model not in MODEL_EXCLUSIONS
]


def load_model(path: str | pathlib.Path) -> GraphPESModel:
    """
    Load a model from a file.

    Parameters
    ----------
    path
        The path to the file.

    Returns
    -------
    GraphPESModel
        The model.

    Examples
    --------
    Use this function to load an existing model for further training using
    ``graph-pes-train``:

    .. code-block:: yaml

        model:
            +load_model:
                path: path/to/model.pt

    To account for some new energy offset in your training data, you could do
    something like this (see also
    :func:`~graph_pes.models.load_model_component`):

    .. code-block:: yaml

        model:
            # add an offset to an existing model before fine-tuning
            offset: +LearnableOffset()
            many-body:
                +load_model:
                    path: path/to/model.pt
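
    You can also call this function directly from Python. A minimal sketch,
    assuming a :class:`~graph_pes.graph_pes_model.GraphPESModel` has
    previously been saved to ``path/to/model.pt``:

    .. code-block:: python

        from graph_pes.models import load_model

        model = load_model("path/to/model.pt")
        print(model)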
"""
path = pathlib.Path(path)
if not path.exists():
raise FileNotFoundError(f"Could not find model at {path}")
model = torch.load(path, weights_only=False)
if not isinstance(model, GraphPESModel):
raise ValueError(
"Expected the loaded object to be a GraphPESModel "
f"but got {type(model)}"
)
import graph_pes
if model._GRAPH_PES_VERSION != graph_pes.__version__:
warnings.warn(
"You are attempting to load a model that was trained with "
f"a different version of graph-pes ({model._GRAPH_PES_VERSION}) "
f"than what you are currently using ({graph_pes.__version__}). "
"We won't stop you from doing this, but it may cause issues.",
stacklevel=2,
)
return model


def load_model_component(
    path: str | pathlib.Path,
    key: str,
) -> GraphPESModel:
    """
    Load a single component from an :class:`~graph_pes.models.AdditionModel`.

    Parameters
    ----------
    path
        The path to the file.
    key
        The key of the component to load.

    Returns
    -------
    GraphPESModel
        The component.

    Examples
    --------
    Train on data with a new energy offset:

    .. code-block:: yaml

        model:
            offset: +LearnableOffset()
            many-body:
                +load_model_component:
                    path: path/to/model.pt
                    key: many-body
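
    The same thing from Python, as a minimal sketch (assuming the model saved
    at ``path/to/model.pt`` is an :class:`~graph_pes.models.AdditionModel`
    with a ``"many-body"`` component):

    .. code-block:: python

        from graph_pes.models import load_model_component

        many_body = load_model_component("path/to/model.pt", key="many-body")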
"""
base_model = load_model(path)
if not isinstance(base_model, AdditionModel):
raise ValueError(
f"Expected to load an AdditionModel, got {type(base_model)}"
)
return base_model[key]
T = TypeVar("T", bound=torch.nn.Module)


def freeze(model: T) -> T:
    """
    Freeze all parameters in a module.

    Parameters
    ----------
    model
        The model to freeze.

    Returns
    -------
    T
        The model, with all parameters frozen.
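
    Examples
    --------
    A minimal sketch using a plain :class:`torch.nn.Module`:

    .. code-block:: python

        import torch

        from graph_pes.models import freeze

        layer = freeze(torch.nn.Linear(3, 3))
        assert not any(p.requires_grad for p in layer.parameters())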
"""
for param in model.parameters():
param.requires_grad = False
return model


def freeze_matching(model: T, pattern: str) -> T:
    r"""
    Freeze all parameters whose names match the given pattern.

    Parameters
    ----------
    model
        The model to freeze.
    pattern
        The regular expression to match the names of the parameters to freeze.

    Returns
    -------
    T
        The model, with matching parameters frozen.

    Examples
    --------
    Freeze all the parameters in the first layer of a MACE-MP0 model from
    :func:`~graph_pes.interfaces.mace_mp` (which have names of the form
    ``"model.interactions.0.<name>"``):

    .. code-block:: yaml

        model:
            +freeze_matching:
                model: +mace_mp()
                pattern: model\.interactions\.0\..*
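
    A toy Python sketch on a plain :class:`torch.nn.Sequential`, whose
    parameter names are ``"0.weight"``, ``"0.bias"``, ``"1.weight"`` and
    ``"1.bias"``:

    .. code-block:: python

        import torch

        from graph_pes.models import freeze_matching

        model = torch.nn.Sequential(
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 1),
        )
        # freeze only the first linear layer
        freeze_matching(model, r"0\..*")
        assert not any(p.requires_grad for p in model[0].parameters())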
"""
for name, param in model.named_parameters():
if re.match(pattern, name):
logger.info(f"Freezing {name}")
param.requires_grad = False
return model


def freeze_any_matching(model: T, patterns: list[str]) -> T:
    r"""
    Freeze all parameters whose names match any of the given patterns.

    Parameters
    ----------
    model
        The model to freeze.
    patterns
        The regular expressions to match the names of the parameters to
        freeze.

    Returns
    -------
    T
        The model, with matching parameters frozen.
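
    Examples
    --------
    A toy Python sketch on a plain :class:`torch.nn.Sequential`:

    .. code-block:: python

        import torch

        from graph_pes.models import freeze_any_matching

        model = torch.nn.Sequential(
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 1),
        )
        # freeze every bias, plus the first layer's weight
        freeze_any_matching(model, [r".*\.bias", r"0\.weight"])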
"""
for pattern in patterns:
freeze_matching(model, pattern)
return model


def freeze_all_except(model: T, pattern: str | list[str]) -> T:
    r"""
    Freeze all parameters in a model except those whose names match the given
    pattern (or any of a list of patterns).

    Parameters
    ----------
    model
        The model to freeze.
    pattern
        The pattern, or list of patterns, to match.

    Returns
    -------
    T
        The model, with all non-matching parameters frozen.

    Examples
    --------
    Freeze all parameters in a MACE-MP0 model from
    :func:`~graph_pes.interfaces.mace_mp` except those in the
    read-out heads:

    .. code-block:: yaml

        model:
            +freeze_all_except:
                model: +mace_mp()
                pattern: model\.readouts.*
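
    A toy Python sketch on a plain :class:`torch.nn.Sequential`, keeping only
    the final layer trainable:

    .. code-block:: python

        import torch

        from graph_pes.models import freeze_all_except

        model = torch.nn.Sequential(
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 1),
        )
        freeze_all_except(model, r"1\..*")
        assert not any(p.requires_grad for p in model[0].parameters())
        assert all(p.requires_grad for p in model[1].parameters())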
"""
freeze(model)
if isinstance(pattern, str):
pattern = [pattern]
for name, param in model.named_parameters():
if any(re.match(p, name) for p in pattern):
param.requires_grad = True
return model