from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from ase.geometry.cell import cell_to_cellpar
from graph_pes.atomic_graph import (
AtomicGraph,
PropertyKey,
edges_per_graph,
is_batch,
neighbour_distances,
neighbour_vectors,
number_of_atoms,
structure_sizes,
to_batch,
trim_edges,
)
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.utils.misc import voigt_6_to_full_3x3
if TYPE_CHECKING:
from orb_models.forcefield.conservative_regressor import (
ConservativeForcefieldRegressor,
)
from orb_models.forcefield.direct_regressor import (
DirectForcefieldRegressor,
)
def from_graph_pes_to_orb_batch(
graph: AtomicGraph,
cutoff: float,
max_neighbours: int,
):
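    """
    Convert an :class:`~graph_pes.AtomicGraph` into an ``orb-models``
    ``AtomGraphs`` batch.

    Edges are trimmed to ``cutoff``, and each atom keeps at most its
    ``max_neighbours`` closest neighbours, before the node, edge and
    system features that ``orb-models`` expects are assembled.
    """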
from orb_models.forcefield.base import AtomGraphs as OrbGraph
if not is_batch(graph):
graph = to_batch([graph])
graph = trim_edges(graph, cutoff)
distances = neighbour_distances(graph)
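    # cap the neighbour list: for every central atom, keep only its
    # max_neighbours closest neighbours (selected via topk on distance)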
new_nl, new_offsets = [], []
for i in range(number_of_atoms(graph)):
mask = graph.neighbour_list[0] == i
d = distances[mask]
if d.numel() == 0:
continue
elif d.numel() < max_neighbours:
new_nl.append(graph.neighbour_list[:, mask])
new_offsets.append(graph.neighbour_cell_offsets[mask])
else:
topk = torch.topk(d, k=max_neighbours, largest=False)
new_nl.append(graph.neighbour_list[:, mask][:, topk.indices])
new_offsets.append(graph.neighbour_cell_offsets[mask][topk.indices])
graph = graph._replace(
neighbour_list=torch.hstack(new_nl),
neighbour_cell_offsets=torch.vstack(new_offsets),
)
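    # assemble the per-atom ("node"), per-edge and per-structure ("system")
    # features in the layout that orb's AtomGraphs container expects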
node_features = {
"atomic_numbers": graph.Z.long(),
"positions": graph.R,
"atomic_numbers_embedding": torch.nn.functional.one_hot(
graph.Z, num_classes=118
),
"atom_identity": torch.arange(number_of_atoms(graph)).long(),
}
edge_features = {
"vectors": neighbour_vectors(graph),
"unit_shifts": graph.neighbour_cell_offsets,
}
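    # convert each (3, 3) cell into the 6 cell parameters
    # (a, b, c, alpha, beta, gamma) via ASE, one row per structure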
lattices = []
for cell in graph.cell.clone().detach():
lattices.append(
torch.from_numpy(cell_to_cellpar(cell.cpu().numpy())).float()
)
lattice = torch.vstack(lattices).to(graph.R.device)
graph_features = {
"cell": graph.cell,
"pbc": torch.Tensor([False, False, False])
if not graph.pbc
else graph.pbc,
"lattice": lattice,
}
return OrbGraph(
senders=graph.neighbour_list[0],
receivers=graph.neighbour_list[1],
n_node=structure_sizes(graph),
n_edge=edges_per_graph(graph),
node_features=node_features,
edge_features=edge_features,
system_features=graph_features,
node_targets={},
edge_targets={},
system_targets={},
fix_atoms=None,
tags=torch.zeros(number_of_atoms(graph)),
radius=cutoff,
max_num_neighbors=torch.tensor([max_neighbours]),
system_id=None,
).to(device=graph.R.device, dtype=graph.R.dtype)
class OrbWrapper(GraphPESModel):
"""
A wrapper around an ``orb-models`` model that converts it into a
:class:`~graph_pes.GraphPESModel`.
Parameters
----------
orb
The ``orb-models`` model to wrap.
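
    Examples
    --------
    A minimal sketch, assuming ``orb-models`` is installed (the model name
    is taken from ``orb-models``' registry of pre-trained models):

    >>> from orb_models.forcefield import pretrained
    >>> orb = pretrained.ORB_PRETRAINED_MODELS["orb-v3-direct-20-omat"](
    ...     device="cpu"
    ... )
    >>> model = OrbWrapper(orb)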
"""
def __init__(self, orb: torch.nn.Module):
from orb_models.forcefield.conservative_regressor import (
ConservativeForcefieldRegressor,
)
from orb_models.forcefield.direct_regressor import (
DirectForcefieldRegressor,
)
assert isinstance(
orb, (DirectForcefieldRegressor, ConservativeForcefieldRegressor)
)
super().__init__(
cutoff=orb.system_config.radius,
implemented_properties=[
"local_energies",
"energy",
"forces",
"stress",
],
)
self._orb = orb
def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]:
orb_graph = from_graph_pes_to_orb_batch(
graph,
self._orb.system_config.radius,
self._orb.system_config.max_num_neighbors,
)
preds: dict[PropertyKey, torch.Tensor] = self._orb.predict(orb_graph) # type: ignore
if "grad_forces" in preds:
preds["forces"] = preds.pop("grad_forces") # type: ignore
if "grad_stress" in preds:
preds["stress"] = preds.pop("grad_stress") # type: ignore
if "stress" in preds:
preds["stress"] = voigt_6_to_full_3x3(preds["stress"])
        # the underlying orb model always returns batched predictions:
        # de-batch them if only a single graph was provided
if not is_batch(graph):
preds["energy"] = preds["energy"][0]
preds["local_energies"] = torch.zeros(number_of_atoms(graph)).to(
graph.Z.device
)
return preds
@property
def orb_model(
self,
) -> "DirectForcefieldRegressor | ConservativeForcefieldRegressor":
r"""
Access the underlying ``orb-models`` model.
        One use case is to use ``graph-pes``\ 's fine-tuning functionality
        to adapt an existing ``orb-models`` model to a new dataset. The
        fine-tuned ``orb-models`` model can then be extracted via this
        property and used in other ``orb-models`` workflows.
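
        Examples
        --------
        A minimal sketch (the fine-tuning step itself is elided):

        >>> wrapper = orb_model("orb-v3-direct-20-omat")
        >>> # ... fine-tune ``wrapper`` using graph-pes ...
        >>> underlying = wrapper.orb_model
        >>> # ``underlying`` can now be used in other orb-models workflows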
"""
return self._orb
def orb_model(name: str = "orb-v3-direct-20-omat") -> OrbWrapper:
"""
Load a pre-trained Orb model, and convert it into a
:class:`~graph_pes.GraphPESModel`.
See the `orb-models <https://github.com/orbital-materials/orb-models>`_
repository for more information on the available models.
As of 2025-04-11, the following are available:
* ``"orb-v3-conservative-20-omat"``
* ``"orb-v3-conservative-inf-omat"``
* ``"orb-v3-direct-20-omat"``
* ``"orb-v3-direct-inf-omat"``
* ``"orb-v3-conservative-20-mpa"``
* ``"orb-v3-conservative-inf-mpa"``
* ``"orb-v3-direct-20-mpa"``
* ``"orb-v3-direct-inf-mpa"``
* ``"orb-v2"``
* ``"orb-d3-v2"``
* ``"orb-d3-sm-v2"``
* ``"orb-d3-xs-v2"``
* ``"orb-mptraj-only-v2"``
Parameters
----------
name: str
The name of the model to load.
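
    Examples
    --------
    A minimal sketch (requires ``orb-models`` to be installed):

    >>> model = orb_model("orb-v3-direct-20-omat")
    >>> isinstance(model, OrbWrapper)
    True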
"""
import torch._functorch.config
from orb_models.forcefield import pretrained
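    # functorch's donated-buffer optimisation frees intermediate buffers
    # after a backward pass, which is incompatible with differentiating
    # through the model more than once; turn it off before loading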
torch._functorch.config.donated_buffer = False
orb = pretrained.ORB_PRETRAINED_MODELS[name](device="cpu")
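    # ensure every parameter is trainable so that the wrapped model can be
    # fine-tuned with graph-pes if desired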
for param in orb.parameters():
param.requires_grad = True
return OrbWrapper(orb)