
Aggregating some value over each atom's neighbours is a common operation in graph-based ML models. graph-pes provides a base class for such operations, together with a few common implementations. Models typically specify which aggregation to use via a NeighbourAggregationMode string, which is passed internally to NeighbourAggregation.parse().

Base Class

class graph_pes.models.components.aggregation.NeighbourAggregation(*args, **kwargs)

An abstract base class for aggregating values over neighbours:

\[X_i^\prime = \text{Agg}_{j \in \mathcal{N}_i} \left[X_j\right]\]

where \(\mathcal{N}_i\) is the set of neighbours of atom \(i\), \(X\) has shape (E, ...), \(X^\prime\) has shape (N, ...) and E and N are the number of edges and atoms in the graph, respectively.
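The following plain-Python sketch makes the shape change from (E, ...) to (N, ...) concrete. It assumes a hypothetical edge-list representation: edges[k] = (i, j) marks atom j as a neighbour of central atom i, and x[k] holds the value \(X_j\) carried by that edge. The helper name aggregate_over_neighbours is illustrative and is not part of graph-pes.

```python
from typing import Callable, List, Sequence, Tuple

Edge = Tuple[int, int]  # (central atom i, neighbour j); hypothetical representation


def aggregate_over_neighbours(
    x: Sequence[float],
    edges: Sequence[Edge],
    n_atoms: int,
    agg: Callable[[List[float]], float],
) -> List[float]:
    """Map per-edge values (length E) to per-atom values (length N)."""
    if len(x) != len(edges):
        raise ValueError("x must have exactly one entry per edge")
    # Collect the edge values that belong to each central atom i.
    groups: List[List[float]] = [[] for _ in range(n_atoms)]
    for value, (i, _j) in zip(x, edges):
        groups[i].append(value)
    # Reduce each atom's group; an atom without neighbours receives agg([]).
    return [agg(group) for group in groups]


# Three atoms: atom 0 neighbours atoms 1 and 2; atoms 1 and 2 each neighbour atom 0.
edges = [(0, 1), (0, 2), (1, 0), (2, 0)]
x = [1.0, 2.0, 3.0, 4.0]  # shape (E,) with E = 4
print(aggregate_over_neighbours(x, edges, n_atoms=3, agg=sum))  # [3.0, 3.0, 4.0]
```

Passing max instead of sum would fail for an atom with no neighbours, because max([]) raises ValueError. A concrete aggregation therefore has to decide what value an isolated atom receives.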

abstract forward(x, graph)

Aggregate x over neighbours.

Return type:

Tensor

pre_fit(graphs)

Calculate any quantities that depend on the graph structure and should be fixed before prediction.

Default implementation does nothing.

Parameters:

graphs (AtomicGraph) – A batch of graphs to pre-fit to.

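static parse(mode)

Evaluates the following map:

"sum" → SumNeighbours
"mean" → MeanNeighbours
"constant_fixed" → ScaledSumNeighbours(learnable=False)
"constant_learnable" → ScaledSumNeighbours(learnable=True)
"sqrt" → VariancePreservingSumNeighbours

Parameters:

mode (Literal['sum', 'mean', 'constant_fixed', 'constant_learnable', 'sqrt']) – The neighbour aggregation mode to parse.

Returns:

The NeighbourAggregation instance corresponding to mode.

Return type:

NeighbourAggregation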

class graph_pes.models.components.aggregation.NeighbourAggregationMode

Type alias for Literal["sum", "mean", "constant_fixed", "constant_learnable", "sqrt"].
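NeighbourAggregationMode is a typing.Literal alias, so its allowed values can be recovered at runtime with typing.get_args. The sketch below redefines the alias locally so that it runs on its own. It also adds check_mode, a hypothetical helper that is not part of graph-pes, which rejects invalid strings early (for example, when reading a config file).

```python
from typing import Literal, Tuple, cast, get_args

# Local copy of the documented alias so this sketch is self-contained.
NeighbourAggregationMode = Literal[
    "sum", "mean", "constant_fixed", "constant_learnable", "sqrt"
]

# get_args on a Literal returns its allowed values, in declaration order.
VALID_MODES: Tuple[str, ...] = get_args(NeighbourAggregationMode)


def check_mode(mode: str) -> NeighbourAggregationMode:
    """Hypothetical helper: validate a user-supplied mode string."""
    if mode not in VALID_MODES:
        raise ValueError(
            f"unknown aggregation mode {mode!r}; expected one of {VALID_MODES}"
        )
    # cast only narrows the static type; it does nothing at runtime.
    return cast(NeighbourAggregationMode, mode)


print(VALID_MODES)
print(check_mode("sqrt"))  # sqrt
```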


class graph_pes.models.components.aggregation.SumNeighbours(*args, **kwargs)

Sum over neighbours:

\[X_i^\prime = \sum_{j \in \mathcal{N}_i} X_j\]

forward(x, graph)

Aggregate x over neighbours.

Return type:

Tensor

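In plain Python, the sum is a scatter-add from each edge onto its central atom. The sketch below assumes a hypothetical edge list in which edges[k] = (i, j) marks atom j as a neighbour of atom i and x[k] is \(X_j\) for that edge. The function name is illustrative. In a tensor library, this loop would usually be a single scatter operation such as torch.Tensor.index_add_; which call graph-pes itself uses is not stated here.

```python
from typing import List, Sequence, Tuple


def sum_neighbours(
    x: Sequence[float], edges: Sequence[Tuple[int, int]], n_atoms: int
) -> List[float]:
    """X'_i = sum of X_j over the neighbours j of atom i (a scatter-add)."""
    out = [0.0] * n_atoms  # atoms with no neighbours stay at 0.0
    for value, (i, _j) in zip(x, edges):
        out[i] += value  # add this edge's value onto its central atom
    return out


edges = [(0, 1), (0, 2), (1, 0), (2, 0)]
print(sum_neighbours([1.0, 2.0, 3.0, 4.0], edges, n_atoms=3))  # [3.0, 3.0, 4.0]
```

Every neighbour enters the sum with weight one. A neighbour whose contribution \(X_j\) decays smoothly to zero at the cutoff therefore leaves the sum continuous as it crosses the cutoff.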

class graph_pes.models.components.aggregation.MeanNeighbours(*args, **kwargs)

Take an average over neighbours:

\[X_i^\prime = \frac{1}{|\mathcal{N}_i|} \sum_{j \in \mathcal{N}_i} X_j\]

where \(|\mathcal{N}_i|\) is the number of neighbours of atom \(i\) (including the central atom).


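This aggregation can lead to unphysical discontinuities in the PES as neighbours enter or leave the radial cutoff. The normalisation factor \(|\mathcal{N}_i|\) changes by one whenever a neighbour crosses the cutoff, even if that neighbour's contribution \(X_j\) decays smoothly to zero there.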
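The sketch below illustrates both the definition and the discontinuity. It interprets "including the central atom" to mean that the sum runs over the edge values only, while the count adds one for atom \(i\) itself. That interpretation, the edge-list format and the function name are assumptions for illustration.

```python
from typing import List, Sequence, Tuple


def mean_neighbours(
    x: Sequence[float], edges: Sequence[Tuple[int, int]], n_atoms: int
) -> List[float]:
    """X'_i = (sum of X_j over neighbours j) / |N_i|, with |N_i| counting atom i too."""
    totals = [0.0] * n_atoms
    counts = [1] * n_atoms  # start at 1: the central atom is included in |N_i|
    for value, (i, _j) in zip(x, edges):
        totals[i] += value
        counts[i] += 1
    return [total / count for total, count in zip(totals, counts)]


# Atom 2 sits right at the cutoff, so its smoothly decaying contribution is ~0.
inside = mean_neighbours([2.0, 1e-12], [(0, 1), (0, 2)], n_atoms=3)
outside = mean_neighbours([2.0], [(0, 1)], n_atoms=3)
print(inside[0], outside[0])  # ~0.667 vs 1.0: a jump, not a smooth change
```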

forward(x, graph)

Aggregate x over neighbours.

Return type:

Tensor


class graph_pes.models.components.aggregation.ScaledSumNeighbours(learnable=False)

Scale the sum over neighbours by a learnable or fixed constant, \(s\):

\[X_i^\prime = \frac{1}{s} \sum_{j \in \mathcal{N}_i} X_j\]

\(s\) defaults to 1.0, but is set to the average number of neighbours per atom in the training set passed to pre_fit().

Parameters:

learnable (bool) – If True, the scale is a learnable parameter. If False, the scale is a fixed constant.

forward(x, graph)

Aggregate x over neighbours.

Return type:

Tensor



pre_fit(graphs)

Set the scale equal to the average number of neighbours in the training set.
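Below is a plain-Python sketch of the pre-fit-then-predict pattern. It summarises each training graph as a hypothetical (n_atoms, edges) pair and takes "average number of neighbours" to mean total edges divided by total atoms across the training set. Both choices are assumptions. The class is an illustrative stand-in, not the graph-pes implementation. It omits the learnable option because plain Python has no autograd.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # (central atom i, neighbour j); hypothetical representation


class ScaledSumSketch:
    """Hypothetical stand-in for ScaledSumNeighbours with a fixed scale."""

    def __init__(self) -> None:
        self.scale = 1.0  # s defaults to 1.0 until pre_fit is called

    def pre_fit(self, graphs: Sequence[Tuple[int, Sequence[Edge]]]) -> None:
        # Assumed definition: average neighbours per atom = total edges / total atoms.
        total_atoms = sum(n_atoms for n_atoms, _ in graphs)
        total_edges = sum(len(edges) for _, edges in graphs)
        self.scale = total_edges / total_atoms

    def forward(
        self, x: Sequence[float], edges: Sequence[Edge], n_atoms: int
    ) -> List[float]:
        out = [0.0] * n_atoms
        for value, (i, _j) in zip(x, edges):
            out[i] += value
        # One shared s divides every atom, unlike the mean's per-atom divisor.
        return [v / self.scale for v in out]


edges = [(0, 1), (0, 2), (1, 0), (2, 0)]
agg = ScaledSumSketch()
agg.pre_fit([(3, edges)])  # 4 edges / 3 atoms, so s = 4/3
print(agg.forward([1.0, 2.0, 3.0, 4.0], edges, n_atoms=3))
```

After pre_fit, \(s\) does not depend on the current geometry. It changes only through training when learnable=True. The output therefore inherits the continuity of the plain sum when a neighbour crosses the cutoff.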

class graph_pes.models.components.aggregation.VariancePreservingSumNeighbours(*args, **kwargs)

Scale the sum over neighbours by the square root of the number of neighbours:

\[X_i^\prime = \frac{1}{\sqrt{|\mathcal{N}_i|}} \sum_{j \in \mathcal{N}_i} X_j\]

where \(|\mathcal{N}_i|\) is the number of neighbours of atom \(i\) (including the central atom).


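This aggregation can lead to unphysical discontinuities in the PES as neighbours enter or leave the radial cutoff. The normalisation factor \(\sqrt{|\mathcal{N}_i|}\) changes abruptly whenever a neighbour crosses the cutoff, even if that neighbour's contribution \(X_j\) decays smoothly to zero there.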
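The name appears to refer to a standard scaling argument. If \(n\) summed terms are independent with unit variance, their sum has variance \(n\), and dividing by \(\sqrt{n}\) restores unit variance. The plain-Python sketch below applies the documented divisor and interprets "including the central atom" as the edge count plus one. That interpretation, the edge-list format and the function name are assumptions for illustration.

```python
import math
from typing import List, Sequence, Tuple


def variance_preserving_sum(
    x: Sequence[float], edges: Sequence[Tuple[int, int]], n_atoms: int
) -> List[float]:
    """X'_i = (sum of X_j over neighbours j) / sqrt(|N_i|), |N_i| counting atom i too."""
    totals = [0.0] * n_atoms
    counts = [1] * n_atoms  # the central atom contributes 1 to |N_i|
    for value, (i, _j) in zip(x, edges):
        totals[i] += value
        counts[i] += 1
    return [total / math.sqrt(count) for total, count in zip(totals, counts)]


# Atom 0 has three neighbours, so |N_0| = 4 and the divisor is sqrt(4) = 2.
edges = [(0, 1), (0, 2), (0, 3)]
print(variance_preserving_sum([1.0, 1.0, 1.0], edges, n_atoms=4))  # [1.5, 0.0, 0.0, 0.0]
```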

forward(x, graph)

Aggregate x over neighbours.

Return type:

Tensor
