from __future__ import annotations
import re
import warnings
from typing import TypeVar
# Ignore the ``torch.load`` warning when it is attributed to the e3nn
# package. The filter is registered before the e3nn-based models are
# imported below — presumably so that the warning is also hidden if it is
# triggered during those imports (TODO confirm).
warnings.filterwarnings(
    "ignore",
    module="e3nn",
    message=".*you are using `torch.load`.*",
)
import pathlib
import torch
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.utils.logger import logger
from .addition import AdditionModel
from .e3nn.mace import MACE, ZEmbeddingMACE
from .e3nn.nequip import NequIP, ZEmbeddingNequIP
from .eddp import EDDP
from .offsets import FixedOffset, LearnableOffset
from .painn import PaiNN
from .pairwise import (
LennardJones,
LennardJonesMixture,
Morse,
PairPotential,
SmoothedPairPotential,
ZBLCoreRepulsion,
)
from .schnet import SchNet
from .tensornet import TensorNet
# Public API of this sub-package. ``ALL_MODELS`` below resolves each of these
# names via ``globals()``, so every entry must already be imported above.
# NOTE(review): the load/freeze helpers defined further down are not listed
# here; adding them would also require adding them to ``MODEL_EXCLUSIONS``,
# since they are not yet defined when ``ALL_MODELS`` is built.
__all__ = [
    "AdditionModel",
    "EDDP",
    "FixedOffset",
    "LearnableOffset",
    "LennardJones",
    "LennardJonesMixture",
    "MACE",
    "Morse",
    "NequIP",
    "PaiNN",
    "PairPotential",
    "SchNet",
    "SmoothedPairPotential",
    "TensorNet",
    "ZBLCoreRepulsion",
    "ZEmbeddingMACE",
    "ZEmbeddingNequIP",
]
# Exported names that are left out of ``ALL_MODELS`` (presumably offsets,
# the AdditionModel container and generic pair-potential wrappers/bases
# rather than standalone models — confirm).
MODEL_EXCLUSIONS = {
    "FixedOffset",
    "LearnableOffset",
    "AdditionModel",
    "PairPotential",
    "SmoothedPairPotential",
}
# The remaining exported classes, looked up in this module's namespace and
# kept in ``__all__`` order.
ALL_MODELS: list[type[GraphPESModel]] = [
    globals()[model] for model in __all__ if model not in MODEL_EXCLUSIONS
]
[docs]
def load_model(path: str | pathlib.Path) -> GraphPESModel:
    """
    Load a model from a file.

    The file must contain a whole pickled
    :class:`~graph_pes.GraphPESModel` object (not, e.g., a bare state dict).

    .. warning::

        Loading un-pickles arbitrary Python objects
        (``torch.load(..., weights_only=False)``), which can execute code.
        Only load files from sources you trust.

    Parameters
    ----------
    path
        The path to the file.

    Returns
    -------
    GraphPESModel
        The model.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not point to an existing file.
    ValueError
        If the loaded object is not a :class:`~graph_pes.GraphPESModel`.

    Examples
    --------
    Use this function to load an existing model for further training using
    ``graph-pes-train``:

    .. code-block:: yaml

        model:
            +load_model:
                path: path/to/model.pt

    To account for some new energy offset in your training data, you could
    do something like this (see also
    :func:`~graph_pes.models.load_model_component`):

    .. code-block:: yaml

        model:
            # add an offset to an existing model before fine-tuning
            offset: +LearnableOffset()
            many-body:
                +load_model:
                    path: path/to/model.pt
    """
    path = pathlib.Path(path)
    # ``is_file`` rather than ``exists``: a directory would otherwise pass
    # this check and then fail inside ``torch.load`` with an unrelated error.
    if not path.is_file():
        raise FileNotFoundError(f"Could not find model at {path}")

    # SECURITY: ``weights_only=False`` is needed to restore a full model
    # object, but un-pickling can run arbitrary code. Never call this on
    # untrusted files.
    model = torch.load(path, weights_only=False)
    if not isinstance(model, GraphPESModel):
        raise ValueError(
            "Expected the loaded object to be a GraphPESModel "
            f"but got {type(model)}"
        )

    # Deferred import — presumably to avoid a circular import between the
    # top-level package and this sub-package (TODO confirm).
    import graph_pes

    # Tolerate objects that carry no version tag instead of crashing with an
    # AttributeError; they are reported as "unknown" in the warning below.
    saved_version = getattr(model, "_GRAPH_PES_VERSION", "unknown")
    if saved_version != graph_pes.__version__:
        warnings.warn(
            "You are attempting to load a model that was trained with "
            f"a different version of graph-pes ({saved_version}) "
            f"than what you are currently using ({graph_pes.__version__}). "
            "We won't stop you from doing this, but it may cause issues.",
            stacklevel=2,
        )
    return model
[docs]
def load_model_component(
    path: str | pathlib.Path,
    key: str,
) -> GraphPESModel:
    """
    Load one named component of a saved
    :class:`~graph_pes.models.AdditionModel`.

    The file is loaded with :func:`~graph_pes.models.load_model`, so the
    same errors apply for missing files or non-model contents.

    Parameters
    ----------
    path
        The path to the file containing the saved
        :class:`~graph_pes.models.AdditionModel`.
    key
        The name of the component to extract.

    Returns
    -------
    GraphPESModel
        The component stored under ``key``.

    Raises
    ------
    ValueError
        If the loaded model is not an
        :class:`~graph_pes.models.AdditionModel`.

    Examples
    --------
    Train on data with a new energy offset:

    .. code-block:: yaml

        model:
            offset: +LearnableOffset()
            many-body:
                +load_model_component:
                    path: path/to/model.pt
                    key: many-body
    """
    loaded = load_model(path)
    if isinstance(loaded, AdditionModel):
        return loaded[key]
    raise ValueError(
        f"Expected to load an AdditionModel, got {type(loaded)}"
    )
# Generic ``torch.nn.Module`` type, so that the ``freeze*`` helpers below
# are typed as returning the same kind of module they were given.
T = TypeVar("T", bound=torch.nn.Module)
[docs]
def freeze(model: T) -> T:
    """
    Freeze every parameter of a module so that none of them require
    gradients.

    The module is modified in place.

    Parameters
    ----------
    model
        The model to freeze.

    Returns
    -------
    T
        The same ``model`` object, with all of its parameters frozen.
    """
    # ``Module.requires_grad_`` visits every parameter (including those of
    # sub-modules) and returns the module itself.
    return model.requires_grad_(False)
[docs]
def freeze_matching(model: T, pattern: str) -> T:
    r"""
    Freeze all parameters whose names match the given pattern.

    Matching uses :func:`re.match`: the pattern must match at the *start* of
    a parameter's dotted name (as yielded by ``named_parameters()``), but
    need not consume the whole name. Each frozen parameter is logged, and
    the model is modified in place.

    Parameters
    ----------
    model
        The model to freeze.
    pattern
        The regular expression to match the names of the parameters to
        freeze.

    Returns
    -------
    T
        The same ``model`` object.

    Examples
    --------
    Freeze all the parameters in the first layer of a MACE-MP0 model from
    :func:`~graph_pes.interfaces.mace_mp` (which have names of the form
    ``"model.interactions.0.<name>"``):

    .. code-block:: yaml

        model:
            +freeze_matching:
                model: +mace_mp()
                pattern: model\.interactions\.0\..*
    """
    # Compile once up front rather than re-resolving the pattern for every
    # parameter; this also surfaces an invalid pattern immediately.
    regex = re.compile(pattern)
    for name, param in model.named_parameters():
        if regex.match(name):
            logger.info(f"Freezing {name}")
            param.requires_grad = False
    return model
[docs]
def freeze_any_matching(model: T, patterns: str | list[str]) -> T:
    r"""
    Freeze all parameters that match any of the given patterns.

    Each pattern is applied in turn with
    :func:`~graph_pes.models.freeze_matching`, i.e. via :func:`re.match`
    against each parameter's name. The model is modified in place.

    Parameters
    ----------
    model
        The model to freeze.
    patterns
        The regular expression(s) to match. A single string is treated as
        one pattern, as in :func:`~graph_pes.models.freeze_all_except`.

    Returns
    -------
    T
        The same ``model`` object.

    Examples
    --------
    Freeze the parameters in the first two layers of a MACE-MP0 model from
    :func:`~graph_pes.interfaces.mace_mp`:

    .. code-block:: yaml

        model:
            +freeze_any_matching:
                model: +mace_mp()
                patterns:
                    - model\.interactions\.0\..*
                    - model\.interactions\.1\..*
    """
    # A bare string is itself iterable: without this, every *character*
    # would be used as its own regex and far too many parameters frozen.
    if isinstance(patterns, str):
        patterns = [patterns]
    for pattern in patterns:
        freeze_matching(model, pattern)
    return model
[docs]
def freeze_all_except(model: T, pattern: str | list[str]) -> T:
    r"""
    Freeze every parameter of a model *except* those matching a pattern.

    All parameters are first frozen; any parameter whose name is matched
    (via :func:`re.match`) by at least one of the given patterns is then
    set back to requiring gradients. The model is modified in place.

    Parameters
    ----------
    model
        The model to freeze.
    pattern
        A single regular expression, or a list of them, selecting the
        parameters that should end up trainable.

    Returns
    -------
    T
        The same ``model`` object.

    Examples
    --------
    Freeze all parameters in a MACE-MP0 model from
    :func:`~graph_pes.interfaces.mace_mp` except those in the
    read-out heads:

    .. code-block:: yaml

        model:
            +freeze_all_except:
                model: +mace_mp()
                pattern: model\.readouts.*
    """
    keep = [pattern] if isinstance(pattern, str) else pattern
    freeze(model)
    for name, param in model.named_parameters():
        trainable = any(re.match(p, name) for p in keep)
        if trainable:
            param.requires_grad = True
    return model