class LocalEnergiesScaler(nn.Module):
    """
    Multiply each atom's local energy by a learnable, element-specific scale.

    The scale factors can be initialised from training data via
    :meth:`pre_fit`; see
    :func:`~graph_pes.utils.shift_and_scale.guess_per_element_mean_and_var`
    for how they are estimated.
    """

    def __init__(self):
        super().__init__()

        # one trainable value per element, starting at 1.0 so that an
        # un-fitted scaler leaves the local energies unchanged
        self.per_element_scaling = PerElementParameter.of_length(
            1, default_value=1.0, requires_grad=True
        )
        """
        The per-element scaling factors.
        (:class:`~graph_pes.utils.nn.PerElementParameter`)
        """
[docs]defforward(self,local_energies:torch.Tensor,graph:AtomicGraph,)->torch.Tensor:""" Scale the local energies by the per-element scaling factor. """scales=self.per_element_scaling[graph.Z].squeeze()returnlocal_energies.squeeze()*scales
# add typing for mypy etcdef__call__(self,local_energies:torch.Tensor,graph:AtomicGraph)->torch.Tensor:returnsuper().__call__(local_energies,graph)
[docs]@torch.no_grad()defpre_fit(self,graphs:AtomicGraph):""" Pre-fit the per-element scaling factors. Parameters ---------- graphs The training data. """if"energy"notingraphs.properties:warnings.warn("No energy data found in training data: can't estimate ""per-element scaling factors for local energies.",stacklevel=2,)returnmeans,variances=guess_per_element_mean_and_var(graphs.properties["energy"],graphs)forZ,varinvariances.items():self.per_element_scaling[Z]=torch.sqrt(torch.tensor(var))
[docs]defnon_decayable_parameters(self)->list[torch.nn.Parameter]:"""The ``per_element_scaling`` parameter should not be decayed."""return[self.per_element_scaling]