from __future__ import annotations
import warnings
from contextlib import redirect_stdout
from itertools import chain
from pathlib import Path
from typing import Literal
import requests
import torch
from graph_pes import AtomicGraph
from graph_pes.atomic_graph import PropertyKey, is_batch
from graph_pes.interfaces.base import InterfaceModel
from graph_pes.utils.misc import MAX_Z
class ZToOneHot(torch.nn.Module):
    """
    Map atomic numbers to one-hot vectors over a fixed list of elements.

    A lookup buffer of length ``MAX_Z + 1`` stores, for every atomic number,
    its position in ``elements``, or ``-1`` if that element is not present.

    Parameters
    ----------
    elements
        The atomic numbers known to the model. The position of each entry
        in this list is the index of its one-hot class.
    """

    def __init__(self, elements: list[int]):
        super().__init__()
        self.z_to_index: torch.Tensor
        # -1 marks "element not supported by this model"
        self.register_buffer("z_to_index", torch.full((MAX_Z + 1,), -1))
        for i, z in enumerate(elements):
            self.z_to_index[z] = i
        self.num_classes = len(elements)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        """
        One-hot encode the atomic numbers ``Z``.

        Parameters
        ----------
        Z
            Integer tensor of atomic numbers.

        Returns
        -------
        torch.Tensor
            The result of ``torch.nn.functional.one_hot`` applied to the
            class index of each entry of ``Z``, with ``num_classes`` classes.

        Raises
        ------
        ValueError
            If any entry of ``Z`` lies outside the range covered by the
            lookup table, or is not one of the model's elements.
        """
        # Take the upper bound from the lookup table itself so that it can
        # never disagree with the size the buffer was created with, and
        # reject negative values: these would otherwise index the table from
        # the end and could silently map to a valid class.
        max_z = self.z_to_index.shape[0] - 1
        out_of_range = (Z < 0) | (Z > max_z)
        if out_of_range.any():
            not_allowed: list[int] = torch.unique(Z[out_of_range]).tolist()
            raise ValueError(
                "ZToOneHot received an atomic number outside the supported "
                f"range [0, {max_z}]: {not_allowed}."
            )

        indices = self.z_to_index[Z]
        if (indices < 0).any():
            not_allowed: list[int] = torch.unique(Z[indices < 0]).tolist()
            raise ValueError(
                "ZToOneHot received an atomic number that is not in the model's"
                f" element list: {not_allowed}. Please ensure the model was "
                "trained with all elements present in the input graph."
            )

        return torch.nn.functional.one_hot(indices, self.num_classes)
def _atomic_graph_to_mace_input(
    graph: AtomicGraph,
    z_to_one_hot: ZToOneHot,
) -> dict[str, torch.Tensor]:
    """
    Build the dictionary of tensors handed to the wrapped MACE model.

    Parameters
    ----------
    graph
        The graph to convert. If it carries no ``batch`` / ``ptr``
        information, values describing a single structure are synthesised.
    z_to_one_hot
        Module used to one-hot encode ``graph.Z``.

    Returns
    -------
    dict[str, torch.Tensor]
        Tensors keyed by ``"node_attrs"``, ``"positions"``, ``"cell"``,
        ``"edge_index"``, ``"unit_shifts"``, ``"shifts"``, ``"batch"`` and
        ``"ptr"``, all moved to the device of ``graph.Z``.
    """
    # fall back to "every atom belongs to structure 0" when un-batched
    structure_index = graph.batch
    if structure_index is None:
        structure_index = torch.zeros_like(graph.Z)

    offsets = graph.ptr
    if offsets is None:
        offsets = torch.tensor([0, graph.Z.shape[0]])

    # give a lone structure's cell a leading dimension so that it can be
    # indexed by structure in the same way as a batch of cells
    if is_batch(graph):
        cells = graph.cell
    else:
        cells = graph.cell.unsqueeze(0)

    structure_of_edge = structure_index[graph.neighbour_list[0]]
    cell_of_edge = cells[structure_of_edge]  # (E, 3, 3)
    edge_shifts = torch.einsum(
        "kl,klm->km", graph.neighbour_cell_offsets, cell_of_edge
    )  # (E, 3)

    tensors = {
        "node_attrs": z_to_one_hot.forward(graph.Z).to(graph.R.dtype),
        "positions": graph.R.clone().detach(),
        "cell": cells.clone().detach(),
        "edge_index": graph.neighbour_list,
        "unit_shifts": graph.neighbour_cell_offsets,
        "shifts": edge_shifts,
        "batch": structure_index,
        "ptr": offsets,
    }

    device = graph.Z.device
    on_device: dict[str, torch.Tensor] = {}
    for name, tensor in tensors.items():
        on_device[name] = tensor.to(device)
    return on_device
[docs]
class MACEWrapper(InterfaceModel):
    """
    Converts any MACE model from the `mace-torch <https://github.com/ACEsuit/mace-torch>`__
    package into a :class:`~graph_pes.GraphPESModel`.

    You can use this to drive MD using LAMMPS, fine-tune MACE models,
    or any functionality that ``graph-pes`` provides.

    Parameters
    ----------
    model
        The MACE model to wrap.

    Examples
    --------
    >>> mace_torch_model = ...  # create your MACE model any-which way
    >>> from graph_pes.interfaces._mace import MACEWrapper
    >>> graph_pes_model = MACEWrapper(mace_torch_model)  # convert to graph-pes
    >>> graph_pes_model.predict_energy(graph)
    torch.Tensor([123.456])
    >>> from graph_pes.utils.calculator import GraphPESCalculator
    >>> calculator = GraphPESCalculator(graph_pes_model)
    >>> calculator.calculate(ase_atoms)
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__(
            # the wrapped model's ``r_max`` is forwarded to the base class
            model.r_max.item(),  # type: ignore
            implemented_properties=[
                "local_energies",
                "energy",
                "forces",
                "stress",
                "virial",
            ],
        )
        self.model = model
        # one-hot encoder over exactly the elements the MACE model declares
        self.z_to_one_hot = ZToOneHot(self.model.atomic_numbers.tolist())  # type: ignore

    def convert_to_underlying_input(
        self, graph: AtomicGraph
    ) -> dict[str, torch.Tensor]:
        """Convert ``graph`` into the tensor dictionary MACE consumes."""
        return _atomic_graph_to_mace_input(graph, self.z_to_one_hot)

    def raw_forward_pass(
        self,
        input: dict[str, torch.Tensor],
        is_batched: bool,
        properties: list[PropertyKey],
    ) -> dict[PropertyKey, torch.Tensor]:
        """
        Run the wrapped MACE model and translate its outputs.

        Parameters
        ----------
        input
            The tensor dictionary produced by
            :meth:`convert_to_underlying_input`.
        is_batched
            Whether ``input`` describes a batch of structures. If not, the
            ``"energy"``, ``"stress"`` and ``"virial"`` predictions are
            squeezed.
        properties
            The properties to return.

        Returns
        -------
        dict[PropertyKey, torch.Tensor]
            The requested properties for which the MACE model returned a
            value other than ``None``.
        """
        MACE_KEY_MAPPING: dict[str, PropertyKey] = {
            "node_energy": "local_energies",
            "energy": "energy",
            "forces": "forces",
            "stress": "stress",
            "virials": "virial",
        }

        raw_predictions = self.model.forward(
            input,
            training=self.training,
            compute_force="forces" in properties,
            compute_stress="stress" in properties,
            compute_virials="virial" in properties,
        )

        predictions: dict[PropertyKey, torch.Tensor] = {}
        for key, value in raw_predictions.items():
            if key in MACE_KEY_MAPPING:
                property_key = MACE_KEY_MAPPING[key]
                if property_key in properties and value is not None:
                    predictions[property_key] = value

        if not is_batched:
            for p in ["energy", "stress", "virial"]:
                # Only squeeze what was actually produced: a requested
                # property may have been skipped above (missing or ``None``
                # in the raw MACE output), and indexing it here would raise
                # a ``KeyError``.
                if p in predictions:
                    predictions[p] = predictions[p].squeeze()

        return predictions

    def predict(
        self,
        graph: AtomicGraph,
        properties: list[PropertyKey],
    ) -> dict[PropertyKey, torch.Tensor]:
        """
        Predict ``properties`` for ``graph``.

        This override delegates directly to the underlying mace-torch
        mechanisms for calculating forces etc.
        """
        input = self.convert_to_underlying_input(graph)
        is_batched = is_batch(graph)
        return self.raw_forward_pass(input, is_batched, properties)
def _fix_dtype(model: torch.nn.Module, dtype: torch.dtype) -> None:
    """
    Cast, in place, every floating-point parameter and buffer of ``model``
    to ``dtype``. Tensors of any other dtype are left untouched.
    """
    floating_point_tensors = [
        t
        for t in chain(model.parameters(), model.buffers())
        if t.dtype.is_floating_point
    ]
    for t in floating_point_tensors:
        # swap the underlying data so the Parameter / buffer object survives
        t.data = t.data.to(dtype)
def _get_dtype(
    precision: Literal["float32", "float64"] | None,
) -> torch.dtype:
    """
    Translate a precision name into a ``torch.dtype``.

    ``None`` selects torch's current default dtype. Any string other than
    ``"float32"`` or ``"float64"`` raises a ``KeyError``.
    """
    if precision is not None:
        named_dtypes = {"float32": torch.float32, "float64": torch.float64}
        return named_dtypes[precision]
    return torch.get_default_dtype()
[docs]
def mace_mp(
    model: str = "small",
    precision: Literal["float32", "float64"] | None = None,
) -> MACEWrapper:
    """
    Download a MACE-MP model and convert it for use with ``graph-pes``.

    Internally, we use the `foundation_models <https://mace-docs.readthedocs.io/en/latest/guide/foundation_models.html>`__
    functionality from the `mace-torch <https://github.com/ACEsuit/mace-torch>`__ package.

    Please cite the following if you use this model:

    - MACE-MP by Ilyes Batatia, Philipp Benner, Yuan Chiang, Alin M. Elena,
      Dávid P. Kovács, Janosh Riebesell, et al., 2023, arXiv:2401.00096
    - MACE-Universal by Yuan Chiang, 2023, Hugging Face, Revision e5ebd9b,
      DOI: 10.57967/hf/1202, URL: https://huggingface.co/cyrusyc/mace-universal
    - Matbench Discovery by Janosh Riebesell, Rhys EA Goodall, Philipp Benner,
      Yuan Chiang, Alpha A Lee, Anubhav Jain, Kristin A Persson, 2023,
      arXiv:2308.14920

    As of 10th April 2025, the following models are available: ``["small", "medium", "large", "medium-mpa-0", "small-0b", "medium-0b", "small-0b2", "medium-0b2", "medium-0b3", "large-0b2", "medium-omat-0"]``

    Parameters
    ----------
    model
        The size of the MACE-MP model to download.
    precision
        The precision of the model. If ``None``, the default precision
        of torch will be used (you can set this when using ``graph-pes-train``
        via ``general/torch/dtype``)
    """  # noqa: E501
    from mace.calculators.foundations_models import mace_mp as mace_mp_impl

    dtype = _get_dtype(precision)
    # mace-torch wants the dtype by name; only these two are supported
    dtype_names = {torch.float32: "float32", torch.float64: "float64"}
    precision_str = dtype_names[dtype]

    # hide FutureWarnings and anything printed to stdout while loading
    with warnings.catch_warnings(), redirect_stdout(None):
        warnings.simplefilter("ignore", category=FutureWarning)
        mace_torch_model = mace_mp_impl(
            model,
            device="cpu",
            default_dtype=precision_str,
            return_raw_model=True,
        )

    assert isinstance(mace_torch_model, torch.nn.Module)
    _fix_dtype(mace_torch_model, dtype)

    wrapped = MACEWrapper(mace_torch_model)
    return wrapped
[docs]
def mace_off(
    model: Literal["small", "medium", "large"],
    precision: Literal["float32", "float64"] | None = None,
) -> MACEWrapper:
    """
    Download a MACE-OFF model and convert it for use with ``graph-pes``.

    If you use this model, please cite the relevant paper by Kovacs et al., arXiv:2312.15211

    Parameters
    ----------
    model
        The size of the MACE-OFF model to download.
    precision
        The precision of the model. If ``None``, the default precision
        of torch will be used.
    """  # noqa: E501
    from mace.calculators.foundations_models import mace_off

    dtype = _get_dtype(precision)
    # mace-torch wants the dtype by name; only these two are supported
    dtype_names = {torch.float32: "float32", torch.float64: "float64"}
    precision_str = dtype_names[dtype]

    mace_torch_model = mace_off(
        model,
        device="cpu",
        default_dtype=precision_str,
        return_raw_model=True,
    )
    assert isinstance(mace_torch_model, torch.nn.Module)
    _fix_dtype(mace_torch_model, dtype)

    # un-freeze all parameters
    mace_torch_model.requires_grad_(True)

    wrapped = MACEWrapper(mace_torch_model)
    return wrapped
[docs]
def go_mace_23(
    precision: Literal["float32", "float64"] | None = None,
) -> MACEWrapper:
    """
    Download the `GO-MACE-23 model <https://doi.org/10.1002/anie.202410088>`__
    and convert it for use with ``graph-pes``.

    .. note::

        This model is only for use on structures containing Carbon, Hydrogen and
        Oxygen. Attempting to use on structures with other elements will raise
        an error.

    If you use this model, please cite the following:

    .. code-block:: bibtex

        @article{El-Machachi-24,
            title = {Accelerated {{First-Principles Exploration}} of {{Structure}} and {{Reactivity}} in {{Graphene Oxide}}},
            author = {{El-Machachi}, Zakariya and Frantzov, Damyan and Nijamudheen, A. and Zarrouk, Tigany and Caro, Miguel A. and Deringer, Volker L.},
            year = {2024},
            journal = {Angewandte Chemie International Edition},
            volume = {63},
            number = {52},
            pages = {e202410088},
            doi = {10.1002/anie.202410088},
        }

    Parameters
    ----------
    precision
        The precision of the model. If ``None``, the default precision
        of torch will be used.
    """  # noqa: E501
    weights_url = "https://github.com/zakmachachi/GO-MACE-23/raw/refs/heads/main/models/fitting/potential/iter-12-final-model/go-mace-23.pt"  # noqa: E501
    wrapped = _mace_from_url(weights_url, "GO-MACE-23", precision)
    return wrapped
[docs]
def egret(
    model: Literal["egret-1", "egret-1t", "egret-1e"] = "egret-1",
    precision: Literal["float32", "float64"] | None = None,
) -> MACEWrapper:
    """
    Download an `Egret <https://arxiv.org/abs/2504.20955>`__ model
    and convert it for use with ``graph-pes``.

    Use the ``egret-1`` model via the Python API:

    .. code-block:: python

        from graph_pes.interfaces._mace import egret
        model = egret("egret-1")

    or :doc:`fine-tune <../quickstart/fine-tuning>` it on your own data using
    the :doc:`graph-pes-train <../cli/graph-pes-train/root>` command:

    .. code-block:: yaml

        model:
            +egret: {model: "egret-1e"}

        data:
            ...

        # etc.

    If you use this model, please cite the following:

    .. code-block:: bibtex

        @misc{Mann-25-04,
            title = {
                Egret-1: {{Pretrained Neural Network
                Potentials For Efficient}} and
                {{Accurate Bioorganic Simulation}}
            },
            author = {
                Mann, Elias L. and Wagen, Corin C.
                and Vandezande, Jonathon E. and Wagen, Arien M.
                and Schneider, Spencer C.
            },
            year = {2025},
            number = {arXiv:2504.20955},
            doi = {10.48550/arXiv.2504.20955},
        }

    As of 1st May 2025, the following models are available: ``["egret-1", "egret-1t", "egret-1e"]``

    Parameters
    ----------
    model
        The model to download.
    precision
        The precision of the model. If ``None``, the default precision
        of torch will be used.

    Raises
    ------
    KeyError
        If ``model`` is not one of the available model names.
    """  # noqa: E501
    urls = {
        "egret-1": "https://github.com/rowansci/egret-public/raw/b1b4c1261315b38f0dd3f6ec0fce891a9119ffe0/compiled_models/EGRET_1.model",
        "egret-1t": "https://github.com/rowansci/egret-public/raw/b1b4c1261315b38f0dd3f6ec0fce891a9119ffe0/compiled_models/EGRET_1T.model",
        "egret-1e": "https://github.com/rowansci/egret-public/raw/b1b4c1261315b38f0dd3f6ec0fce891a9119ffe0/compiled_models/EGRET_1E.model",
    }
    # forward the requested precision, matching the other model loaders
    return _mace_from_url(urls[model], model, precision)
def _mace_from_url(
    url: str,
    model_name: str,
    precision: Literal["float32", "float64"] | None = None,
) -> MACEWrapper:
    """
    Load a pickled MACE model from ``url`` and wrap it for ``graph-pes``.

    The file is cached at ``~/.graph-pes/<model_name>.pt`` (spaces in
    ``model_name`` are replaced by hyphens) and only downloaded if that file
    does not already exist.

    Parameters
    ----------
    url
        Where to download the model from.
    model_name
        Name used for the cache file and in progress messages.
    precision
        The precision to convert the model to. If ``None``, the default
        precision of torch is used.

    Returns
    -------
    MACEWrapper
        The wrapped model, loaded onto the CPU.

    Raises
    ------
    requests.RequestException
        If the download fails, times out, or returns a bad status code.
    """
    dtype = _get_dtype(precision)
    file_name = f"{model_name.replace(' ', '-')}.pt"
    save_path = Path.home() / ".graph-pes" / file_name
    save_path.parent.mkdir(parents=True, exist_ok=True)

    if not save_path.exists():
        print(f"Downloading {model_name} model to {save_path}")
        # without a timeout, a stalled connection would block forever
        response = requests.get(url, timeout=60)
        response.raise_for_status()  # Raise an exception for bad status codes

        # Write to a temporary sibling file and rename it into place, so that
        # an interrupted write can never leave a truncated file at the cache
        # location (which would then be re-used on every subsequent call).
        partial_path = save_path.with_name(f"{save_path.name}.part")
        try:
            with open(partial_path, "wb") as file:
                file.write(response.content)
            partial_path.replace(save_path)
        finally:
            if partial_path.exists():
                partial_path.unlink()

    print(f"Loading {model_name} model from {save_path}")
    # NOTE(security): ``weights_only=False`` unpickles arbitrary Python
    # objects, so this must only ever be pointed at trusted URLs.
    mace_torch_model = torch.load(
        save_path, weights_only=False, map_location=torch.device("cpu")
    )

    # convert floating-point parameters *and* buffers, as is done for the
    # other model loaders in this file, so the model has a single precision
    _fix_dtype(mace_torch_model, dtype)

    return MACEWrapper(mace_torch_model)