Aggregation
Aggregating some value over an atom's neighbours is a common operation in graph-based ML models. graph-pes provides a base class for such operations, together with a few common implementations. The usual way to choose the aggregation used in a model is to supply a NeighbourAggregationMode string, which is passed internally to parse().
Base Class
- class graph_pes.models.components.aggregation.NeighbourAggregation(*args, **kwargs)
An abstract base class for aggregating values over neighbours:
\[X_i^\prime = \text{Agg}_{j \in \mathcal{N}_i} \left[X_j\right]\]
where \(\mathcal{N}_i\) is the set of neighbours of atom \(i\), \(X\) has shape (E, ...), \(X^\prime\) has shape (N, ...), and E and N are the number of edges and atoms in the graph, respectively.
- pre_fit(graphs)
Calculate any quantities that depend on the graph structure and should be fixed before prediction.
The default implementation does nothing.
- Parameters:
graphs (AtomicGraph) – A batch of graphs to pre-fit to.
- static parse(mode)
Evaluates the following map:
Mode                  Aggregation
"sum"                 SumNeighbours()
"mean"                MeanNeighbours()
"constant_fixed"      ScaledSumNeighbours(learnable=False)
"constant_learnable"  ScaledSumNeighbours(learnable=True)
"sqrt"                VariancePreservingSumNeighbours()
- Parameters:
mode (Literal['sum', 'mean', 'constant_fixed', 'constant_learnable', 'sqrt']) – The neighbour aggregation mode to parse.
- Returns:
The parsed neighbour aggregation mode.
- Return type:
NeighbourAggregation
- class graph_pes.models.components.aggregation.NeighbourAggregationMode
Type alias for Literal["sum", "mean", "constant_fixed", "constant_learnable", "sqrt"].
Implementations
- class graph_pes.models.components.aggregation.SumNeighbours(*args, **kwargs)
Sum over neighbours:
\[X_i^\prime = \sum_{j \in \mathcal{N}_i} X_j\]
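To make the change of shape from (E, ...) to (N, ...) concrete, the following is a minimal sketch of a sum over neighbours using plain Python lists rather than graph-pes tensors. The function name `sum_over_neighbours` and the `central_atoms` list (holding, for each edge, the index of the atom that receives that edge's value) are assumptions made for this illustration and are not part of the graph-pes API.

```python
def sum_over_neighbours(edge_values, central_atoms, n_atoms):
    """Scatter-add E per-edge values into N per-atom totals."""
    totals = [0.0] * n_atoms
    for value, i in zip(edge_values, central_atoms):
        totals[i] += value  # edge contributes to its central atom i
    return totals


# 3 atoms and 4 directed edges: edge k contributes to atom central_atoms[k]
edge_values = [1.0, 2.0, 3.0, 4.0]
central_atoms = [0, 0, 1, 2]

per_atom = sum_over_neighbours(edge_values, central_atoms, n_atoms=3)
print(per_atom)  # [3.0, 3.0, 4.0]
```

An atom with no incoming edges simply keeps a total of zero in this sketch.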
- class graph_pes.models.components.aggregation.MeanNeighbours(*args, **kwargs)
Take an average over neighbours:
\[X_i^\prime = \frac{1}{|\mathcal{N}_i|} \sum_{j \in \mathcal{N}_i} X_j\]
where \(|\mathcal{N}_i|\) is the number of neighbours of atom \(i\) (including the central atom).
Note
This aggregation can lead to un-physical discontinuities in the PES as neighbours enter or leave the radial cutoff.
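The discontinuity mentioned in the note can be seen in a small numerical sketch. It assumes, purely for illustration, a radial cutoff of 5.0 and a cosine envelope that takes each neighbour's contribution smoothly to zero at the cutoff; neither the envelope nor the helper names below come from graph-pes. As a neighbour crosses the cutoff, its contribution to the sum is vanishingly small, but the neighbour count in the denominator changes by one, so the mean jumps.

```python
import math

CUTOFF = 5.0  # hypothetical radial cutoff


def envelope(r):
    """Illustrative cosine envelope that decays smoothly to zero at the cutoff."""
    if r >= CUTOFF:
        return 0.0
    return 0.5 * (1.0 + math.cos(math.pi * r / CUTOFF))


def neighbours_within_cutoff(distances):
    return [r for r in distances if r < CUTOFF]


def sum_aggregate(distances):
    return sum(envelope(r) for r in neighbours_within_cutoff(distances))


def mean_aggregate(distances):
    inside = neighbours_within_cutoff(distances)
    # the count includes the central atom, following the convention stated above
    return sum(envelope(r) for r in inside) / (len(inside) + 1)


fixed = [1.0, 1.0]                      # two neighbours well inside the cutoff
just_outside = fixed + [CUTOFF + 1e-6]  # third neighbour not yet counted
just_inside = fixed + [CUTOFF - 1e-6]   # third neighbour has just entered

sum_jump = abs(sum_aggregate(just_inside) - sum_aggregate(just_outside))
mean_jump = abs(mean_aggregate(just_inside) - mean_aggregate(just_outside))
print(f"jump in sum:  {sum_jump:.2e}")
print(f"jump in mean: {mean_jump:.2e}")
```

In this sketch the sum changes by a negligible amount while the mean changes by a finite amount, even though the third neighbour has moved by only 2e-6.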
- class graph_pes.models.components.aggregation.ScaledSumNeighbours(learnable=False)
Scale the sum over neighbours by a learnable or fixed constant, \(s\):
\[X_i^\prime = \frac{1}{s} \sum_{j \in \mathcal{N}_i} X_j\]
\(s\) defaults to 1.0, but is set to the average number of neighbours of each atom in the training set passed to pre_fit().
- Parameters:
learnable (bool) – If True, the scale is a learnable parameter. If False, the scale is a fixed constant.
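The interplay between the default scale and pre-fitting can be sketched as follows. `ScaledSumSketch` is a hypothetical stand-in written in plain Python, not the graph-pes class, and its `pre_fit` takes a simplified input (one list of per-atom neighbour counts per training graph) instead of a batch of graphs.

```python
class ScaledSumSketch:
    """Illustrative stand-in for a scaled sum; not the graph-pes class."""

    def __init__(self):
        self.scale = 1.0  # default value of s before any pre-fitting

    def pre_fit(self, neighbour_counts_per_graph):
        """Set s to the average number of neighbours per atom.

        `neighbour_counts_per_graph` is a hypothetical input: one list per
        training graph, holding the number of neighbours of each atom.
        """
        counts = [c for graph in neighbour_counts_per_graph for c in graph]
        self.scale = sum(counts) / len(counts)

    def __call__(self, neighbour_values):
        """Aggregate one atom's neighbour values as (1 / s) * sum_j X_j."""
        return sum(neighbour_values) / self.scale


agg = ScaledSumSketch()
before = agg([2.0, 4.0, 6.0])     # s is still 1.0, so this is a plain sum
agg.pre_fit([[2, 4], [3, 3, 3]])  # 15 neighbours over 5 atoms, so s = 3.0
after = agg([2.0, 4.0, 6.0])
print(before, agg.scale, after)   # 12.0 3.0 4.0
```

Because \(s\) does not depend on how many neighbours a particular atom has, dividing by it rescales the sum without introducing a dependence on the local neighbour count.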
- class graph_pes.models.components.aggregation.VariancePreservingSumNeighbours(*args, **kwargs)
Scale the sum over neighbours by the square root of the number of neighbours:
\[X_i^\prime = \frac{1}{\sqrt{|\mathcal{N}_i|}} \sum_{j \in \mathcal{N}_i} X_j\]
where \(|\mathcal{N}_i|\) is the number of neighbours of atom \(i\) (including the central atom).
Note
This aggregation can lead to un-physical discontinuities in the PES as neighbours enter or leave the radial cutoff.
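The "variance preserving" name can be motivated with a short numerical experiment. The sketch below assumes that the values being aggregated are independent with zero mean and unit variance, and for simplicity it divides by the square root of the number of summed values only; both are assumptions of this illustration rather than statements about graph-pes internals. Under those assumptions a plain sum of n values has variance close to n, whereas the square-root-scaled sum has variance close to 1.

```python
import math
import random
import statistics

random.seed(0)  # fixed seed so the sketch is reproducible

N_NEIGHBOURS = 16  # hypothetical number of neighbours per atom
N_ATOMS = 20_000   # number of independent trials


def sqrt_scaled_sum(values):
    """Sum the values and divide by the square root of their count."""
    return sum(values) / math.sqrt(len(values))


plain, scaled = [], []
for _ in range(N_ATOMS):
    # independent, zero-mean, unit-variance stand-ins for neighbour values
    messages = [random.gauss(0.0, 1.0) for _ in range(N_NEIGHBOURS)]
    plain.append(sum(messages))
    scaled.append(sqrt_scaled_sum(messages))

plain_variance = statistics.pvariance(plain)
scaled_variance = statistics.pvariance(scaled)
print(f"plain sum: {plain_variance:.2f}, sqrt-scaled sum: {scaled_variance:.2f}")
```

Real neighbour values are generally correlated, so this should be read as a rough motivation for the scaling rather than a guarantee.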