Welcome

SymTorch is a library for fast, optimisable, symbolic expressions in vanilla PyTorch.

Quickstart

Use the symtorchify() function to create a PyTorch model for the function \(f(x, y) = x^2 + 2y\).

>>> import torch
>>> from symtorch import symtorchify
>>> # create a symbolic expression
>>> expr = symtorchify("x**2 + 2*y")
>>> # evaluate the expression at some input values
>>> inputs = {
...     "x": torch.tensor([1.0, 2.0]),
...     "y": torch.tensor([3.0, 4.0])
... }
>>> output = expr(inputs)
>>> output
tensor([ 7., 12.])
>>> # the expression has a single torch.nn.Parameter:
>>> next(expr.parameters())
Parameter containing:
tensor(2)
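
To illustrate what it means for the constant to be a torch.nn.Parameter, here is a minimal sketch in plain PyTorch. The HandWrittenExpr class is a hypothetical, hand-written stand-in for the expression above, not SymTorch's implementation. Its coefficient starts at 2 and is fitted by gradient descent to data generated with a coefficient of 5.

```python
import torch


class HandWrittenExpr(torch.nn.Module):
    """Hypothetical stand-in for x**2 + c*y (not SymTorch's own code)."""

    def __init__(self, c: float = 2.0):
        super().__init__()
        # the numeric constant is the module's only trainable parameter
        self.c = torch.nn.Parameter(torch.tensor(c))

    def forward(self, inputs):
        return inputs["x"] ** 2 + self.c * inputs["y"]


inputs = {"x": torch.tensor([1.0, 2.0]), "y": torch.tensor([3.0, 4.0])}
# synthetic targets, generated with a coefficient of 5 instead of 2
target = inputs["x"] ** 2 + 5.0 * inputs["y"]

module = HandWrittenExpr(c=2.0)
optimiser = torch.optim.SGD(module.parameters(), lr=0.01)
for _ in range(200):
    optimiser.zero_grad()
    loss = torch.mean((module(inputs) - target) ** 2)
    loss.backward()
    optimiser.step()

fitted_c = module.c.item()  # close to 5.0 after training
```

Because the constant is an ordinary parameter, any standard PyTorch optimiser and training loop should apply in the same way.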

Use the SymbolAssignment class to map the columns of an input tensor onto the symbols of the expression, so that fully symbolic models can act as drop-in replacements for existing PyTorch layers.

>>> from symtorch import SymbolAssignment
>>> assignment = SymbolAssignment(["x", "y"])
>>> model = torch.nn.Sequential(assignment, expr)
>>> # evaluate the model at the same input values as before,
>>> # in a manner compatible with e.g. a torch.nn.Linear layer:
>>> # each row is one sample, each column is one symbol
>>> model(torch.tensor([[1.0, 3.0], [2.0, 4.0]]))
tensor([ 7., 12.])
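
The layout convention can be sketched in plain PyTorch. The classes ColumnsToSymbols and HandWrittenExpr below are hypothetical stand-ins written for illustration; they are not SymTorch's implementation. The sketch assumes that each row of the input is one sample and each column feeds one symbol, as the `3*a + b` example further down suggests.

```python
import torch


class ColumnsToSymbols(torch.nn.Module):
    """Hypothetical stand-in for SymbolAssignment: one column per symbol."""

    def __init__(self, names):
        super().__init__()
        self.names = list(names)

    def forward(self, batch):
        # batch has shape (n_samples, n_symbols); column i feeds names[i]
        return {name: batch[:, i] for i, name in enumerate(self.names)}


class HandWrittenExpr(torch.nn.Module):
    """Hand-written version of x**2 + 2*y operating on a dict of symbols."""

    def forward(self, symbols):
        return symbols["x"] ** 2 + 2 * symbols["y"]


model = torch.nn.Sequential(ColumnsToSymbols(["x", "y"]), HandWrittenExpr())

# rows are the samples (x=1, y=3) and (x=2, y=4)
batch = torch.tensor([[1.0, 3.0], [2.0, 4.0]])
result = model(batch)  # values 7 and 12
```

With this layout, the model accepts the same (n_samples, n_features) tensors that a torch.nn.Linear layer would.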

For more examples, including TorchScript compilation and model de/serialization, see the Usage section.

Installation

You can install SymTorch using pip:

pip install symtorch

SymTorch is compatible with Python 3.8+, and all tested versions of PyTorch and SymPy.

What about SymPyTorch?

This package aims to supersede the amazing Patrick Kidger’s original SymPyTorch. Notable improvements include:

  • Implementations of state_dict and load_state_dict for all SymTorch objects, allowing for automated saving and loading via the native PyTorch mechanisms

  • Compatibility with TorchScript, allowing for integration into C++ code

  • A SymbolAssignment helper class to enable “drag-and-drop” replacement of existing NN components with symbolic ones:

>>> model = torch.nn.Sequential(
...     SymbolAssignment(["a", "b"]),
...     symtorchify("3*a + b")
... )
>>> model(torch.tensor([[1, 2], [3, 4]]))
tensor([[ 5.],
        [13.]], grad_fn=<AddBackward0>)
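
As a rough illustration of the native PyTorch saving and loading mechanism mentioned in the first bullet, the sketch below round-trips a state_dict through an in-memory buffer. The Coefficient class is a hypothetical, hand-written module standing in for a symbolic expression such as `3*a + b`; the same torch.save and load_state_dict calls are assumed to apply to SymTorch objects, as the list above states.

```python
import io

import torch


class Coefficient(torch.nn.Module):
    """Hypothetical stand-in for c*a + b with a single trainable constant."""

    def __init__(self, value: float):
        super().__init__()
        self.c = torch.nn.Parameter(torch.tensor(float(value)))

    def forward(self, a, b):
        return self.c * a + b


trained = Coefficient(3.0)

# save the parameters to an in-memory buffer (a file path works the same way)
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# load them into a freshly constructed module with a different initial value
restored = Coefficient(0.0)
restored.load_state_dict(torch.load(buffer))
restored_c = restored.c.item()  # 3.0, recovered from the saved state
```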