from __future__ import annotations

import pathlib

import e3nn.util.jit
import torch
from graph_pes.atomic_graph import AtomicGraph, PropertyKey
from graph_pes.graph_pes_model import GraphPESModel
from graph_pes.utils.misc import full_3x3_to_voigt_6


def as_lammps_data(
    graph: AtomicGraph,
    compute_virial: bool = False,
    debug: bool = False,
) -> dict[str, torch.Tensor]:
    """
    Pack a graph into the dict-of-tensors layout read by
    :meth:`LAMMPSModel.forward`.

    Parameters
    ----------
    graph
        The graph whose ``Z``, ``R``, ``cell``, ``neighbour_list`` and
        ``neighbour_cell_offsets`` attributes are placed in the result.
    compute_virial
        Flag stored (wrapped in a tensor) under ``"compute_virial"``.
    debug
        Flag stored (wrapped in a tensor) under ``"debug"``.

    Returns
    -------
    dict[str, torch.Tensor]
        A mapping with the keys ``"atomic_numbers"``, ``"positions"``,
        ``"cell"``, ``"neighbour_list"``, ``"neighbour_cell_offsets"``,
        ``"compute_virial"`` and ``"debug"``.
    """
    return {
        "atomic_numbers": graph.Z,
        "positions": graph.R,
        "cell": graph.cell,
        "neighbour_list": graph.neighbour_list,
        "neighbour_cell_offsets": graph.neighbour_cell_offsets,
        # the boolean flags are wrapped in tensors so that every value in
        # the dict has the same type
        "compute_virial": torch.tensor(compute_virial),
        "debug": torch.tensor(debug),
    }


class LAMMPSModel(torch.nn.Module):
    """
    Wrap a :class:`GraphPESModel` behind a dict-of-tensors interface.

    The wrapped model is stored as ``self.model``. :meth:`forward` accepts
    a plain ``dict[str, torch.Tensor]`` (see :func:`as_lammps_data` for the
    expected keys), rebuilds an :class:`AtomicGraph` from it and returns
    the model's predictions cast to ``float64``.

    Parameters
    ----------
    model
        The model to wrap.
    """

    def __init__(self, model: GraphPESModel):
        super().__init__()
        self.model = model

    @torch.jit.export  # type: ignore
    def get_cutoff(self) -> torch.Tensor:
        """Return the ``cutoff`` of the wrapped model."""
        return self.model.cutoff

    def forward(
        self, graph_data: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """
        Run the wrapped model on a graph supplied as a dict of tensors.

        Parameters
        ----------
        graph_data
            Must contain ``"atomic_numbers"``, ``"positions"``, ``"cell"``,
            ``"neighbour_list"``, ``"neighbour_cell_offsets"`` and
            ``"compute_virial"``. ``"debug"`` is optional and is treated
            as ``False`` when absent.

        Returns
        -------
        dict[str, torch.Tensor]
            Predictions for ``"energy"``, ``"forces"`` and
            ``"local_energies"``, each cast to ``float64``. When
            ``"compute_virial"`` is truthy, a ``"virial"`` entry is also
            requested and passed through ``full_3x3_to_voigt_6`` before
            being returned.
        """
        debug = graph_data.get("debug", torch.tensor(False)).item()

        if debug:
            print("Received graph:")
            for key, value in graph_data.items():
                print(f"{key}: {value}")

        compute_virial = graph_data["compute_virial"].item()
        properties: list[PropertyKey] = ["energy", "forces", "local_energies"]
        if compute_virial:
            properties.append("virial")

        # graph_data is a dict, so we need to convert it to an AtomicGraph
        graph = AtomicGraph(
            Z=graph_data["atomic_numbers"],
            R=graph_data["positions"],
            cell=graph_data["cell"],
            neighbour_list=graph_data["neighbour_list"],
            neighbour_cell_offsets=graph_data["neighbour_cell_offsets"],
            properties={},
            other={},
            cutoff=self.model.cutoff.item(),
        )

        preds = self.model.predict(graph, properties=properties)

        # cast to float64
        for key in preds:
            preds[key] = preds[key].double()

        # add virial output if required
        if compute_virial:
            # LAMMPS expects the **virial** in Voigt notation
            # we provide the **stress** in full 3x3 matrix notation
            # therefore, convert:
            preds["virial"] = full_3x3_to_voigt_6(preds["virial"])

        return preds  # type: ignore

    def __call__(
        self, graph_data: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Delegate to ``torch.nn.Module.__call__`` with a typed signature."""
        return super().__call__(graph_data)
def deploy_model(model: GraphPESModel, path: str | pathlib.Path):
    """
    Deploy a :class:`~graph_pes.GraphPESModel` for use with LAMMPS.

    Use the resulting model with LAMMPS according to:

    .. code-block:: bash

        pair_style graph_pes <cpu>
        pair_coeff * * path/to/model.pt <element-of-type-1> <element-of-type-2> ...

    Parameters
    ----------
    model
        The model to deploy.
    path
        The path to save the deployed model to.
    """  # noqa: E501
    # wrap the model so that it exposes the dict-of-tensors interface,
    # then compile the wrapper to TorchScript
    lammps_model = LAMMPSModel(model)
    scripted_model = e3nn.util.jit.script(lammps_model)

    # torch.jit.save documents its target as a file-name string (or a
    # file-like object), so hand it a plain string when given a Path;
    # any other kind of input is forwarded untouched
    target = str(path) if isinstance(path, pathlib.Path) else path
    torch.jit.save(scripted_model, target)