# Source code for graph_pes.utils.threebody

from __future__ import annotations

import torch

from graph_pes.atomic_graph import (
    AtomicGraph,
    get_vectors,
    neighbour_distances,
    neighbour_vectors,
    number_of_atoms,
    number_of_edges,
)


def angle_spanned_by(v1: torch.Tensor, v2: torch.Tensor):
    """
    Calculate angles between corresponding vectors in two batches.

    Parameters
    ----------
    v1
        First batch of vectors, shape (N, 3). More generally any shape
        ``(..., D)`` is accepted: the last dimension is the vector dimension.
    v2
        Second batch of vectors, same shape as ``v1``.

    Returns
    -------
    torch.Tensor
        Angles in radians, shape (N,) (in general, the input shape with the
        last dimension removed). If either vector of a pair has zero length
        the angle is undefined; ``pi / 2`` is returned for that pair instead
        of ``nan``. ``nan`` inputs still propagate to ``nan`` outputs.
    """
    # Reduce over the last (vector) dimension so that arbitrary leading batch
    # dimensions work; for (N, 3) inputs this is identical to dim=1.
    dot_product = torch.sum(v1 * v2, dim=-1)
    v1_mag = torch.linalg.vector_norm(v1, dim=-1)
    v2_mag = torch.linalg.vector_norm(v2, dim=-1)

    # Avoid 0 / 0 = nan for zero-length vectors. A "double where" is used
    # (rather than one torch.where around the division) so the unselected
    # branch never divides by zero; otherwise its nan gradient would leak
    # into the backward pass. `!= 0` (not `> 0`) keeps nan inputs visible.
    # Non-zero denominators are left completely untouched.
    denominator = v1_mag * v2_mag
    nonzero = denominator != 0
    safe_denominator = torch.where(
        nonzero, denominator, torch.ones_like(denominator)
    )
    cos_angle = torch.where(
        nonzero, dot_product / safe_denominator, torch.zeros_like(dot_product)
    )

    # Keep cos strictly inside [-1, 1]: rounding can push |cos| marginally
    # above 1 (arccos -> nan), and arccos has an infinite derivative at +-1.
    cos_angle = torch.clamp(cos_angle, min=-1.0 + 1e-7, max=1.0 - 1e-7)

    return torch.arccos(cos_angle)


def triplet_bond_descriptors(
    graph: AtomicGraph,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    For each triplet :math:`(i, j, k)`, get the bond angle :math:`\theta_{jik}`
    (in radians) and the two bond lengths :math:`r_{ij}` and :math:`r_{ik}`.

    Triplets are found with :func:`triplet_edge_pairs`, using the graph's own
    cutoff, ``graph.cutoff``.

    Parameters
    ----------
    graph
        The atomic graph to analyse.

    Returns
    -------
    triplet_idxs
        The triplet indices, :math:`(i, j, k)`, of shape ``(Y, 3)``.
    angle
        The bond angle :math:`\theta_{jik}`, shape ``(Y,)``.
    r_ij
        The bond length :math:`r_{ij}`, shape ``(Y,)``.
    r_ik
        The bond length :math:`r_{ik}`, shape ``(Y,)``.

    Examples
    --------
    >>> graph = AtomicGraph.from_ase(molecule("H2O"))
    >>> triplet_idxs, angle, r_ij, r_ik = triplet_bond_descriptors(graph)
    >>> torch.rad2deg(angle)
    tensor([103.9999, 103.9999,  38.0001,  38.0001,  38.0001,  38.0001])
    """
    edge_pairs = triplet_edge_pairs(graph, graph.cutoff)  # (Y, 2)

    # edge a = (i, j) supplies the central atom i and the first neighbour j ...
    ij = graph.neighbour_list[:, edge_pairs[:, 0]]  # (2, Y)
    # ... and row 1 of edge b = (i, k) supplies the second neighbour k
    k = graph.neighbour_list[1, edge_pairs[:, 1]].unsqueeze(0)  # (1, Y)
    triplet_idxs = torch.cat([ij, k], dim=0).transpose(0, 1)  # (Y, 3)

    if triplet_idxs.shape[0] == 0:
        # No triplets: return empty descriptors on the positions' device and
        # with the positions' floating dtype, rather than hard-coding float32,
        # so that float64 graphs get float64 results here too (the non-empty
        # branch follows the dtype of the edge vectors, presumably R's dtype).
        return (
            triplet_idxs,
            torch.zeros(0, device=graph.R.device, dtype=graph.R.dtype),
            torch.zeros(0, device=graph.R.device, dtype=graph.R.dtype),
            torch.zeros(0, device=graph.R.device, dtype=graph.R.dtype),
        )

    # per-edge vectors: v1 belongs to edge a = (i, j), v2 to edge b = (i, k)
    v = neighbour_vectors(graph)
    v1 = v[edge_pairs[:, 0]]
    v2 = v[edge_pairs[:, 1]]

    return (
        triplet_idxs,
        angle_spanned_by(v1, v2),
        torch.linalg.vector_norm(v1, dim=-1),
        torch.linalg.vector_norm(v2, dim=-1),
    )
def triplet_edge_pairs(
    graph: AtomicGraph,
    three_body_cutoff: float,
) -> torch.Tensor:
    r"""
    Find all the pairs of edges, :math:`a = (i, j), b = (i, k)`, such that:

    * :math:`i, j, k \in \{0, 1, \dots, N-1\}` are indices of distinct
      (images of) atoms within the graph
    * :math:`j \neq k` (i.e. :math:`a \neq b`)
    * :math:`r_{ij} <` ``three_body_cutoff``
    * :math:`r_{ik} <` ``three_body_cutoff``

    Parameters
    ----------
    graph
        The graph to search for edge pairs.
    three_body_cutoff
        Only edges strictly shorter than this are considered. Must not exceed
        ``graph.cutoff`` (a tolerance of ``1e-6`` is allowed).

    Returns
    -------
    edge_pairs: torch.Tensor
        A ``(Y, 2)`` shaped integer tensor indicating the edges, such that

        .. code-block:: python

            a, b = edge_pairs[y]
            i, j = graph.neighbour_list[:,a]
            i, k = graph.neighbour_list[:,b]

        Both orderings :math:`(a, b)` and :math:`(b, a)` are included, grouped
        by central atom in ascending order. If there are no such pairs
        (including for a graph with no atoms) a ``(0, 2)`` tensor is returned.

    Raises
    ------
    ValueError
        If ``three_body_cutoff`` is greater than ``graph.cutoff``.
    """
    if three_body_cutoff > graph.cutoff + 1e-6:
        raise ValueError(
            "Three-body cutoff is greater than the graph cutoff. "
            "This is not currently supported."
        )

    # check if already cached, using old .format to be torchscript compatible
    # NB this gets added in the to_batch function, which is called on the worker
    # threads. Since this function is slow, this speeds up training, but
    # should not be used for MD/inference. Hence we don't cache any results
    # to the graph within this function.
    key = "__threebody-{:.3f}".format(three_body_cutoff)  # noqa: UP032
    if key in graph.other:
        v = graph.other.get(key)
        if v is not None:
            return v

    device = graph.R.device

    with torch.no_grad():
        edge_indexes = torch.arange(number_of_edges(graph), device=device)
        # strict inequality: edges exactly at the cutoff are excluded
        three_body_mask = neighbour_distances(graph) < three_body_cutoff
        relevant_edge_indexes = edge_indexes[three_body_mask]
        relevant_central_atoms = graph.neighbour_list[0][relevant_edge_indexes]

        edge_pairs = []
        for i in range(number_of_atoms(graph)):
            mask = relevant_central_atoms == i
            masked_edge_indexes = relevant_edge_indexes[mask]

            # number of edges shorter than three_body_cutoff
            # that have i as a central atom
            N = masked_edge_indexes.shape[0]

            # all ordered pairs of local edge positions, on the graph's device
            # (avoids building CPU index tensors for GPU graphs)
            local = torch.arange(N, device=device)
            _idx = torch.cartesian_prod(local, local)  # (N**2, 2)
            _idx = _idx[_idx[:, 0] != _idx[:, 1]]  # (N**2 - N, 2)

            pairs_for_i = masked_edge_indexes[_idx]
            edge_pairs.append(pairs_for_i)

        if len(edge_pairs) == 0:
            # graph with no atoms: torch.cat would raise on an empty list
            return torch.zeros((0, 2), dtype=torch.long, device=device)

        return torch.cat(edge_pairs)
def triplet_edges(
    graph: AtomicGraph,
    three_body_cutoff: float,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """
    Finds all ``Y`` triplets ``(i, j, k)`` such that:

    * ``i, j, k`` are indices of distinct (images of) atoms within the graph
    * ``r_{ij} <`` ``three_body_cutoff``
    * ``r_{ik} <`` ``three_body_cutoff``

    The triplets are exactly those returned by :func:`triplet_edge_pairs`, in
    the same order.

    Parameters
    ----------
    graph
        The graph to search for triplets.
    three_body_cutoff
        Only edges strictly shorter than this are used to form triplets.

    Returns
    -------
    i: torch.Tensor
        The central atom indices, shape ``(Y,)``.
    j: torch.Tensor
        The first paired atom indices, shape ``(Y,)``.
    k: torch.Tensor
        The second paired atom indices, shape ``(Y,)``.
    r_ij: torch.Tensor
        The bond length :math:`r_{ij}`, shape ``(Y,)``.
    r_ik: torch.Tensor
        The bond length :math:`r_{ik}`, shape ``(Y,)``.
    r_jk: torch.Tensor
        The bond length :math:`r_{jk}`, shape ``(Y,)``.

    Raises
    ------
    ValueError
        If ``three_body_cutoff`` is greater than ``graph.cutoff`` (raised by
        :func:`triplet_edge_pairs`).
    """
    edge_pairs = triplet_edge_pairs(graph, three_body_cutoff)  # (Y, 2)
    a = edge_pairs[:, 0]  # edges (i, j)
    b = edge_pairs[:, 1]  # edges (i, k)

    ij = graph.neighbour_list[:, a]  # (2, Y)
    ik = graph.neighbour_list[:, b]  # (2, Y)

    # vectors for exactly the edges taking part in triplets, including each
    # edge's periodic cell offset
    v_ij = get_vectors(
        graph, i=ij[0, :], j=ij[1, :], shifts=graph.neighbour_cell_offsets[a]
    )
    v_ik = get_vectors(
        graph, i=ik[0, :], j=ik[1, :], shifts=graph.neighbour_cell_offsets[b]
    )
    # both vectors share the origin atom i, so their difference is the vector
    # between j and k; its norm is r_jk whatever get_vectors' sign convention
    v_jk = v_ik - v_ij

    # torch.linalg.vector_norm replaces the deprecated torch.norm
    r_ij = torch.linalg.vector_norm(v_ij, dim=-1)
    r_ik = torch.linalg.vector_norm(v_ik, dim=-1)
    r_jk = torch.linalg.vector_norm(v_jk, dim=-1)

    return (
        ij[0, :],
        ij[1, :],
        ik[1, :],
        r_ij,
        r_ik,
        r_jk,
    )