Usage

This guide provides detailed examples of how to use SymTorch in various scenarios.

Note

For a complete API reference, see the API documentation.

Basic Usage

The core function in SymTorch is symtorchify(), which converts a string or SymPy expression into a PyTorch module:

>>> from symtorch import symtorchify
>>> import torch
>>> # Create a simple expression
>>> expr = symtorchify("x**2 + 2*y")
>>> # Evaluate the expression
>>> input_data = {"x": torch.tensor([1.0, 2.0]), "y": torch.tensor([3.0, 4.0])}
>>> expr(input_data)
tensor([ 7., 12.])
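In short, symtorchify() turns a string into a callable that takes a dictionary mapping each symbol name to a tensor and applies the operations element-wise. The sketch below imitates that contract in plain Python, with lists standing in for tensors. compile_expression is a hypothetical helper for illustration only. SymTorch does not work this way; it represents expressions as trees of SymTorch objects rather than calling eval.

```python
import math
from typing import Callable, Dict, List


def compile_expression(source: str) -> Callable[[Dict[str, List[float]]], List[float]]:
    """Hypothetical helper: turn an expression string into an element-wise evaluator."""
    code = compile(source, "<expression>", "eval")
    # Besides the inputs, only these functions are visible to the expression
    allowed = {"sin": math.sin, "cos": math.cos, "exp": math.exp, "max": max}

    def evaluate(inputs: Dict[str, List[float]]) -> List[float]:
        length = len(next(iter(inputs.values())))
        results = []
        for i in range(length):
            # Bind each symbol to its i-th value, mimicking element-wise tensor ops
            namespace = {name: values[i] for name, values in inputs.items()}
            results.append(float(eval(code, {"__builtins__": {}, **allowed}, namespace)))
        return results

    return evaluate


expr = compile_expression("x**2 + 2*y")
print(expr({"x": [1.0, 2.0], "y": [3.0, 4.0]}))  # [7.0, 12.0]
```

Restricting `__builtins__` keeps the sketch small, but eval should still never be used on untrusted strings.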

By default, numeric values in the expression become trainable parameters:

>>> expr = symtorchify("2.5 * x + 1.7")
>>> list(expr.parameters())
[Parameter containing:
tensor(2.5000, requires_grad=True), Parameter containing:
tensor(1.7000, requires_grad=True)]
>>> # Make parameters non-trainable
>>> expr = symtorchify("2.5 * x + 1.7", trainable=False)
>>> list(expr.parameters())
[]
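Trainable constants can then be fitted by gradient descent like any other parameters. As a plain-Python sketch of what an optimizer would do with the two parameters of 2.5 * x + 1.7, the loop below fits them to data from the hypothetical line y = 3x + 1. The mean-squared-error gradients are written out by hand instead of being computed by autograd.

```python
# Data generated by the hypothetical "true" line y = 3x + 1
xs = [0.0, 1.0, 2.0, 3.0]
ys = [3.0 * x + 1.0 for x in xs]

# Start from the constants in "2.5 * x + 1.7"
a, b = 2.5, 1.7
learning_rate = 0.05

for _ in range(1000):
    errors = [a * x + b - y for x, y in zip(xs, ys)]
    # Gradients of the mean squared error with respect to a and b
    grad_a = 2.0 * sum(e * x for e, x in zip(errors, xs)) / len(xs)
    grad_b = 2.0 * sum(errors) / len(xs)
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b

print(round(a, 3), round(b, 3))  # 3.0 1.0
```

With trainable=False there would be nothing to update, which is why the parameter list above is empty.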

SymTorch supports a wide range of mathematical operations:

>>> expr = symtorchify("sin(x)**2 + cos(y)**2 + exp(z) / max(x, y)")
>>> input_data = {"x": torch.tensor(0.5), "y": torch.tensor(1.0), "z": torch.tensor(0.0)}
>>> result = expr(input_data)
>>> print(result)
tensor([1.5218])
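The printed value can be checked by hand with Python's math module, using the same inputs as above:

```python
import math

# Same inputs as the SymTorch example
x, y, z = 0.5, 1.0, 0.0
value = math.sin(x) ** 2 + math.cos(y) ** 2 + math.exp(z) / max(x, y)
print(round(value, 4))  # 1.5218
```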

Under the hood, SymTorch represents expressions as trees of other SymTorch objects:

>>> expr = symtorchify("sin(x)**2 + cos(y)**2 + exp(z) / max(x, y)")
>>> expr
sin(x)²+cos(y)²+exp(z)/max(x,y)
>>> expr.long_hand_representation()
Add(Pow(cos(Id(y)), Id(2)), Pow(sin(Id(x)), Id(2)), Mul(Pow(max(Id(x), Id(y)), Id(-1)), exp(Id(z))))
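To see how such a tree is evaluated, here is a minimal sketch with hypothetical node classes named after the long-hand representation above. These are plain dataclasses, not SymTorch's actual classes: each node evaluates its children recursively, and division appears as a power of -1, just as in the long-hand output.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, Tuple, Union

Env = Dict[str, float]


@dataclass
class Id:
    """Leaf: a symbol name (str) or a numeric constant."""
    value: Union[str, float]

    def evaluate(self, env: Env) -> float:
        if isinstance(self.value, str):
            return env[self.value]  # look up the symbol's value
        return float(self.value)


@dataclass
class Add:
    args: Tuple["Node", ...]

    def evaluate(self, env: Env) -> float:
        return sum(arg.evaluate(env) for arg in self.args)


@dataclass
class Mul:
    args: Tuple["Node", ...]

    def evaluate(self, env: Env) -> float:
        return math.prod(arg.evaluate(env) for arg in self.args)


@dataclass
class Pow:
    base: "Node"
    exponent: "Node"

    def evaluate(self, env: Env) -> float:
        return self.base.evaluate(env) ** self.exponent.evaluate(env)


@dataclass
class Func:
    fn: Callable[..., float]
    args: Tuple["Node", ...]

    def evaluate(self, env: Env) -> float:
        return self.fn(*(arg.evaluate(env) for arg in self.args))


Node = Union[Id, Add, Mul, Pow, Func]

# Same shape as the long-hand representation; division is Pow(..., Id(-1))
tree = Add((
    Pow(Func(math.cos, (Id("y"),)), Id(2)),
    Pow(Func(math.sin, (Id("x"),)), Id(2)),
    Mul((Pow(Func(max, (Id("x"), Id("y"))), Id(-1)), Func(math.exp, (Id("z"),)))),
))
print(round(tree.evaluate({"x": 0.5, "y": 1.0, "z": 0.0}), 4))  # 1.5218
```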

SymbolAssignment

The SymbolAssignment class maps the columns of an input tensor to named symbols. This lets symbolic expressions serve as drop-in replacements for standard PyTorch layers without building dictionaries of named torch.Tensor objects by hand.

>>> from symtorch import SymbolAssignment
>>> import torch.nn as nn

>>> model = nn.Sequential(
...     SymbolAssignment(["x", "y"]),
...     symtorchify("x**2 + 2*y")
... )

>>> # Use like any other PyTorch model
>>> input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> output = model(input_tensor)
>>> print(output)
tensor([[ 5.],
        [17.]])
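Conceptually, SymbolAssignment(["x", "y"]) turns a batch of rows into one named column per symbol. The plain-Python sketch below shows that mapping with lists in place of tensors. assign_symbols is a hypothetical helper, not SymTorch's implementation.

```python
from typing import Dict, List, Sequence


def assign_symbols(names: Sequence[str], rows: List[List[float]]) -> Dict[str, List[float]]:
    """Hypothetical helper: map column j of a batch to the j-th symbol name."""
    if any(len(row) != len(names) for row in rows):
        raise ValueError("each row needs exactly one value per symbol")
    return {name: [row[j] for row in rows] for j, name in enumerate(names)}


batch = [[1.0, 2.0], [3.0, 4.0]]
symbols = assign_symbols(["x", "y"], batch)
print(symbols)  # {'x': [1.0, 3.0], 'y': [2.0, 4.0]}

# Evaluate x**2 + 2*y per row, keeping a trailing dimension of size 1
output = [[x ** 2 + 2 * y] for x, y in zip(symbols["x"], symbols["y"])]
print(output)  # [[5.0], [17.0]]
```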

Model Serialization

SymTorch models can be saved and loaded using PyTorch’s built-in mechanisms:

>>> import torch

>>> # Create and save a model
>>> model = symtorchify("x**3 + 2*x**2 + 3*x + 4")
>>> torch.save(model, "symbolic_model.pt")
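>>> # Load the model; since PyTorch 2.6, torch.load defaults to
>>> # weights_only=True, which cannot unpickle a full module
>>> loaded_model = torch.load("symbolic_model.pt", weights_only=False)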
>>> # Verify that the loaded model works the same
>>> input_data = {"x": torch.tensor([2.0])}
>>> original_output = model(input_data)
>>> loaded_output = loaded_model(input_data)
>>> torch.allclose(original_output, loaded_output)
True
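torch.save pickles the entire module object. The stdlib-only sketch below illustrates why that round trip works well for an expression: the object persists only its source string and rebuilds everything derived from it when unpickled. StringExpression is a hypothetical class, not part of SymTorch.

```python
import pickle
from typing import Dict


class StringExpression:
    """Hypothetical expression object whose only real state is its source string."""

    def __init__(self, source: str) -> None:
        self.source = source
        self._code = compile(source, "<expression>", "eval")  # derived, not picklable

    def __call__(self, inputs: Dict[str, float]) -> float:
        return eval(self._code, {"__builtins__": {}}, dict(inputs))

    def __getstate__(self) -> Dict[str, str]:
        # Code objects cannot be pickled, so persist only the source string
        return {"source": self.source}

    def __setstate__(self, state: Dict[str, str]) -> None:
        # Rebuild the derived compiled code on load
        self.source = state["source"]
        self._code = compile(self.source, "<expression>", "eval")


model = StringExpression("x**3 + 2*x**2 + 3*x + 4")
loaded = pickle.loads(pickle.dumps(model))
print(model({"x": 2.0}), loaded({"x": 2.0}))  # 26.0 26.0
```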

The state_dict of a SymTorch expression holds its full state in a single _extra_state entry:

>>> model = symtorchify("x**3 + 2*x**2 + 3*x + 4")
>>> model.state_dict()
OrderedDict([('_extra_state', ExpressionState(
    expression='x**3 + 2*x**2 + 3*x + 4',
    trainable=True,
    trainable_ints=False
))])
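Because this state consists only of a string and two flags, it is plain, portable data. To illustrate, the sketch below mirrors the printed fields in a hypothetical dataclass (ExpressionStateSketch, not SymTorch's ExpressionState) and round-trips it through JSON.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ExpressionStateSketch:
    """Hypothetical mirror of the fields shown in the state_dict output above."""
    expression: str
    trainable: bool = True
    trainable_ints: bool = False


state = ExpressionStateSketch("x**3 + 2*x**2 + 3*x + 4")
payload = json.dumps(asdict(state))  # only strings and booleans
restored = ExpressionStateSketch(**json.loads(payload))
print(restored == state)  # True
```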

To restore expressions from a saved state_dict, build a model from empty SymbolAssignment() and SymbolicExpression() placeholders and load the state into it:

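>>> from symtorch import SymbolAssignment, SymbolicExpression
>>> import torch.nn as nn
>>> # Save the state_dict of a model
>>> saved_model = nn.Sequential(
...     SymbolAssignment(["x"]),
...     symtorchify("x**3 + 2*x**2 + 3*x + 4")
... )
>>> torch.save(saved_model.state_dict(), "symbolic_model_state_dict.pt")
>>> # Some time in the future, define a general model with placeholders
>>> empty_model = nn.Sequential(
...     SymbolAssignment(),
...     SymbolicExpression()
... )
>>> # Load the saved state into the placeholder model; weights_only=False
>>> # is needed because the state_dict holds custom ExpressionState objects
>>> state_dict = torch.load("symbolic_model_state_dict.pt", weights_only=False)
>>> empty_model.load_state_dict(state_dict)
<All keys matched successfully>
>>> empty_model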
Sequential(
    (0): SymbolAssignment(names=['x'])
    (1): x**3 + 2*x**2 + 3*x + 4
)
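The _extra_state key comes from PyTorch's nn.Module.get_extra_state() and set_extra_state() hooks: state_dict() stores whatever get_extra_state() returns, and load_state_dict() passes it back to set_extra_state(). The sketch below shows how a placeholder module can be filled in this way. PlaceholderExpr is hypothetical and assumes the expression string is the only state; it is not SymTorch's SymbolicExpression.

```python
from typing import Any, Dict, Optional

import torch
from torch import nn


class PlaceholderExpr(nn.Module):
    """Hypothetical module whose only state is an expression string."""

    def __init__(self, expression: Optional[str] = None) -> None:
        super().__init__()
        self.expression = expression  # None means "placeholder, not loaded yet"

    def get_extra_state(self) -> Dict[str, Any]:
        # Stored by state_dict() under the "<prefix>_extra_state" key
        return {"expression": self.expression}

    def set_extra_state(self, state: Dict[str, Any]) -> None:
        # Called by load_state_dict() with the saved "_extra_state" value
        self.expression = state["expression"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.expression is None:
            raise RuntimeError("no expression loaded")
        # eval is acceptable for a sketch but unsafe for untrusted strings
        return eval(self.expression, {"__builtins__": {}}, {"x": x})


saved = nn.Sequential(PlaceholderExpr("x**3 + 2*x**2 + 3*x + 4"))
print(list(saved.state_dict().keys()))  # ['0._extra_state']

empty = nn.Sequential(PlaceholderExpr())
empty.load_state_dict(saved.state_dict())
print(empty(torch.tensor([2.0])))  # tensor([26.])
```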

TorchScript Compilation

SymTorch models can be compiled with TorchScript, which produces a serialized model that can be loaded without the original Python source, including from C++ via LibTorch:

>>> import torch
>>> # Create a symbolic model
>>> model = symtorchify("x**2 + 2*y")
>>> # Compile the model
>>> scripted_model = torch.jit.script(model)
>>> # Save the compiled model
>>> scripted_model.save("compiled_model.pt")
>>> # Load and use the compiled model
>>> loaded_model = torch.jit.load("compiled_model.pt")
>>> input_data = {"x": torch.tensor([3.0]), "y": torch.tensor([4.0])}
>>> result = loaded_model(input_data)
>>> print(result)
tensor([17.])
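TorchScript assumes unannotated arguments are tensors, so a module that takes a dictionary of named inputs must annotate it as Dict[str, torch.Tensor]. The sketch below scripts a hand-written stand-in for x**2 + 2*y and round-trips it through an in-memory buffer instead of a file. HandWrittenExpr is hypothetical and is not SymTorch code.

```python
import io
from typing import Dict

import torch
from torch import nn


class HandWrittenExpr(nn.Module):
    """Hypothetical hand-written stand-in for x**2 + 2*y (not SymTorch code)."""

    def __init__(self) -> None:
        super().__init__()
        # Trainable constant playing the role of the 2 in "2*y"
        self.coeff = nn.Parameter(torch.tensor(2.0))

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Without the Dict annotation, TorchScript would expect a Tensor here
        return inputs["x"] ** 2 + self.coeff * inputs["y"]


scripted = torch.jit.script(HandWrittenExpr())

# Round-trip through an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

inputs = {"x": torch.tensor([3.0]), "y": torch.tensor([4.0])}
print(loaded(inputs).tolist())  # [17.0]
```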