Welcome
SymTorch is a library for fast, optimisable, symbolic expressions in vanilla PyTorch.
Quickstart
Use the symtorchify() function to create a PyTorch model for the function \(f(x, y) = x^2 + 2y\).
>>> import torch
>>> from symtorch import symtorchify
>>> # Create a symbolic expression
>>> expr = symtorchify("x**2 + 2*y")
>>> # evaluate the expression at some input values
>>> input = {
... "x": torch.tensor([1.0, 2.0]),
... "y": torch.tensor([3.0, 4.0])
... }
>>> output = expr(input)
>>> output
tensor([ 7., 12.])
>>> # the expression has a single torch.nn.Parameter:
>>> next(expr.parameters())
Parameter containing:
tensor(2., requires_grad=True)
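Conceptually, the constant 2 in the expression is exposed as a trainable value that the rest of the expression tree closes over. The following pure-Python sketch illustrates this idea only; it is not SymTorch's actual implementation, and ToyExpression is a hypothetical name:

```python
# A toy analogue of the expression above: f(x, y) = x**2 + c*y,
# where the constant c plays the role of the single learnable parameter
# (c = 2 here). Conceptual sketch only, NOT SymTorch internals.
class ToyExpression:
    def __init__(self, c=2.0):
        self.c = c  # analogous to the torch.nn.Parameter shown above

    def __call__(self, inputs):
        # evaluate element-wise over the "x" and "y" input arrays
        return [x**2 + self.c * y for x, y in zip(inputs["x"], inputs["y"])]

expr = ToyExpression()
print(expr({"x": [1.0, 2.0], "y": [3.0, 4.0]}))  # [7.0, 12.0]
```

Because the constant is an ordinary attribute of the module, gradient-based optimisers can update it just like any other weight.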
Use the SymbolAssignment class to assign inputs to the parameters of the expression,
allowing for drop-in replacement of existing PyTorch layers with fully symbolic counterparts.
>>> from symtorch import SymbolAssignment
>>> assignment = SymbolAssignment(["x", "y"])
>>> model = torch.nn.Sequential(assignment, expr)
>>> # evaluate the model at the same input values as before,
>>> # in a manner compatible with e.g. a torch.nn.Linear layer
>>> model(torch.tensor([[1.0, 3.0], [2.0, 4.0]]))
tensor([ 7., 12.])
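The role of SymbolAssignment here is to map each feature column of a (batch, features) input to a named symbol, so the symbolic expression can consume a plain tensor. A stdlib-only sketch of that column-to-symbol mapping (illustrative, and assign_symbols is a hypothetical helper, not the library's code):

```python
# Mimics SymbolAssignment(["x", "y"]): pair feature column i of a
# (batch, features) input with the i-th symbol name.
def assign_symbols(symbol_order, batch):
    # transpose the batch so each entry of `columns` is one feature column
    columns = list(zip(*batch))
    return {name: list(col) for name, col in zip(symbol_order, columns)}

print(assign_symbols(["x", "y"], [[1.0, 3.0], [2.0, 4.0]]))
# {'x': [1.0, 2.0], 'y': [3.0, 4.0]}
```

This column-wise convention is what makes the assignment a drop-in replacement for layers like torch.nn.Linear, which also treat the last dimension as the feature dimension.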
For more examples, including TorchScript compilation and model de/serialization, see the Usage section.
Installation
You can install SymTorch using pip:
pip install symtorch
SymTorch is compatible with Python 3.8+ and all tested versions of PyTorch and SymPy.
What about SymPyTorch?
This package attempts to supersede Patrick Kidger's amazing original SymPyTorch. Useful feature improvements here are:
- Implementations of state_dict and load_state_dict for all SymTorch objects, allowing for automated saving and loading via the native PyTorch mechanisms
- Plays nicely with TorchScript, allowing for integration into C++ code
- A SymbolAssignment helper class to enable "drag-and-drop" replacement of existing NN components with symbolic ones:
>>> model = torch.nn.Sequential(
... SymbolAssignment(["a", "b"]),
... symtorchify("3*a + b")
... )
>>> model(torch.tensor([[1, 2], [3, 4]]))
tensor([[ 5.],
[13.]], grad_fn=<AddBackward0>)
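The state_dict / load_state_dict support listed above follows the usual PyTorch round-trip pattern: extract a plain mapping of parameter values, then restore it into a freshly constructed module. A generic sketch of that pattern (ToyModule is illustrative, not SymTorch internals):

```python
# Generic save/load round trip in the style of PyTorch's
# state_dict / load_state_dict. Conceptual sketch only.
class ToyModule:
    def __init__(self, c=2.0):
        self.c = c  # stands in for a learnable parameter

    def state_dict(self):
        return {"c": self.c}

    def load_state_dict(self, state):
        self.c = state["c"]

saved = ToyModule(c=5.0).state_dict()  # e.g. after training
fresh = ToyModule()                    # rebuilt from the expression string
fresh.load_state_dict(saved)
print(fresh.c)  # 5.0
```

In real usage the saved mapping would be written to disk with torch.save and restored with torch.load, exactly as for any other torch.nn.Module.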