Usage
This guide provides detailed examples of how to use SymTorch in various scenarios.
Note
For a complete API reference, see the API documentation.
Basic Usage
The core function in SymTorch is symtorchify(), which converts a string or SymPy expression into a PyTorch module:
>>> from symtorch import symtorchify
>>> import torch
>>> # Create a simple expression
>>> expr = symtorchify("x**2 + 2*y")
>>> # Evaluate the expression
>>> input_data = {"x": torch.tensor([1.0, 2.0]), "y": torch.tensor([3.0, 4.0])}
>>> expr(input_data)
tensor([ 7., 12.])
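Conceptually, evaluation walks the expression, looks up each symbol by name in the input dictionary, and applies the operations element-wise. The sketch below illustrates that idea in plain Python. It is not SymTorch's implementation: it parses the string with the standard-library `ast` module instead of SymPy, works on lists of floats instead of tensors, supports only `+ - * / **`, and `evaluate` is a hypothetical helper.

```python
import ast
import operator

# Map AST binary-operator node types to the matching Python functions
_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Pow: operator.pow,
}


def evaluate(expr: str, inputs: dict[str, list[float]]) -> list[float]:
    """Evaluate `expr` element-wise over equally sized lists of named inputs."""
    tree = ast.parse(expr, mode="eval").body
    n = len(next(iter(inputs.values())))

    def ev(node: ast.AST, i: int) -> float:
        if isinstance(node, ast.Constant):  # numeric literal, e.g. 2
            return float(node.value)
        if isinstance(node, ast.Name):      # symbol: look up inputs[name][i]
            return inputs[node.id][i]
        if isinstance(node, ast.BinOp):     # binary operation, e.g. x**2
            return _BINOPS[type(node.op)](ev(node.left, i), ev(node.right, i))
        raise ValueError(f"unsupported node: {ast.dump(node)}")

    return [ev(tree, i) for i in range(n)]


print(evaluate("x**2 + 2*y", {"x": [1.0, 2.0], "y": [3.0, 4.0]}))  # [7.0, 12.0]
```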
By default, numeric values in the expression become trainable parameters:
>>> expr = symtorchify("2.5 * x + 1.7")
>>> list(expr.parameters())
[Parameter containing:
tensor(2.5000, requires_grad=True), Parameter containing:
tensor(1.7000, requires_grad=True)]
>>> # Make parameters non-trainable
>>> expr = symtorchify("2.5 * x + 1.7", trainable=False)
>>> list(expr.parameters())
[]
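One way to picture what `trainable=True` does is to lift each numeric literal out of the expression into a list of values that an optimizer could update, leaving a placeholder in its place. The stdlib-only sketch below does this with Python's `ast` module; it is an illustration under stated assumptions, not SymTorch's implementation. `lift_constants` and `_LiftFloats` are hypothetical names, plain floats stand in for `torch.nn.Parameter` objects, and only float literals are lifted (integers such as the exponent in `x**2` stay fixed), which is an assumption of the sketch. It requires Python 3.9 or later.

```python
import ast


class _LiftFloats(ast.NodeTransformer):
    """Replace each float literal with p[k], recording its value in self.params."""

    def __init__(self) -> None:
        self.params: list[float] = []

    def visit_Constant(self, node: ast.Constant) -> ast.AST:
        if isinstance(node.value, float):
            self.params.append(node.value)
            index = ast.Constant(len(self.params) - 1)
            return ast.Subscript(value=ast.Name(id="p", ctx=ast.Load()),
                                 slice=index, ctx=ast.Load())
        return node  # ints (e.g. the 2 in x**2) stay fixed in this sketch


def lift_constants(expr: str, trainable: bool = True):
    """Return (params, fn), where fn(params, **symbols) evaluates the expression."""
    tree = ast.parse(expr, mode="eval")
    params: list[float] = []
    if trainable:
        lifter = _LiftFloats()
        tree = ast.fix_missing_locations(lifter.visit(tree))
        params = lifter.params
    code = compile(tree, "<expr>", "eval")

    def fn(p: list[float], **symbols: float) -> float:
        # Lifted constants are read from p, so editing p changes the result
        return eval(code, {"__builtins__": {}}, {"p": p, **symbols})

    return params, fn


params, fn = lift_constants("2.5 * x + 1.7")
print(params, fn(params, x=2.0))  # [2.5, 1.7] 6.7
frozen, _ = lift_constants("2.5 * x + 1.7", trainable=False)
print(frozen)                     # []
```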
SymTorch supports a wide range of mathematical operations, including trigonometric functions, exponentials, and max:
>>> expr = symtorchify("sin(x)**2 + cos(y)**2 + exp(z) / max(x, y)")
>>> input_data = {"x": torch.tensor(0.5), "y": torch.tensor(1.0), "z": torch.tensor(0.0)}
>>> result = expr(input_data)
>>> print(result)
tensor([1.5218])
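As a quick check on the printed value, the same arithmetic with Python's `math` module gives sin(0.5)² ≈ 0.2298, cos(1)² ≈ 0.2919 and exp(0)/max(0.5, 1) = 1, for a total of about 1.5218:

```python
import math

# Same inputs as above, as plain Python floats
x, y, z = 0.5, 1.0, 0.0

value = math.sin(x) ** 2 + math.cos(y) ** 2 + math.exp(z) / max(x, y)
print(round(value, 4))  # 1.5218
```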
Under the hood, SymTorch represents each expression as a tree of other SymTorch objects; long_hand_representation() prints that tree explicitly:
>>> expr = symtorchify("sin(x)**2 + cos(y)**2 + exp(z) / max(x, y)")
>>> expr
sin(x)²+cos(y)²+exp(z)/max(x,y)
>>> expr.long_hand_representation()
Add(Pow(cos(Id(y)), Id(2)), Pow(sin(Id(x)), Id(2)), Mul(Pow(max(Id(x), Id(y)), Id(-1)), exp(Id(z))))
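The long-hand form shows that division is not a separate node: exp(z) / max(x, y) appears as a Mul whose first factor is Pow(..., Id(-1)), which is the convention SymPy uses. The sketch below rebuilds that subtree with a hypothetical `Node` class (not part of SymTorch) to show how such a tree prints in long-hand form and evaluates on scalar inputs.

```python
import math
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Node:
    """Hypothetical expression-tree node: an operator name plus child nodes."""
    name: str
    children: list["Node"] = field(default_factory=list)

    def long_hand(self) -> str:
        if not self.children:  # leaf: a symbol or a number
            return f"Id({self.name})"
        return f"{self.name}({', '.join(c.long_hand() for c in self.children)})"

    def __call__(self, env: dict[str, float]) -> float:
        if not self.children:  # symbols come from env, numbers are parsed
            return env[self.name] if self.name in env else float(self.name)
        return OPS[self.name]([c(env) for c in self.children])


# Each operator receives the list of its evaluated children
OPS: dict[str, Callable[[list[float]], float]] = {
    "Add": sum,
    "Mul": math.prod,
    "Pow": lambda a: a[0] ** a[1],
    "exp": lambda a: math.exp(a[0]),
    "max": max,
}


def Id(name: str) -> Node:
    return Node(name)


# exp(z) / max(x, y) stored as max(x, y)**-1 * exp(z)
tree = Node("Mul", [Node("Pow", [Node("max", [Id("x"), Id("y")]), Id("-1")]),
                    Node("exp", [Id("z")])])
print(tree.long_hand())  # Mul(Pow(max(Id(x), Id(y)), Id(-1)), exp(Id(z)))
print(tree({"x": 0.5, "y": 1.0, "z": 0.0}))  # 1.0
```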
SymbolAssignment
The SymbolAssignment class lets you use symbolic expressions as drop-in replacements for standard PyTorch layers: it assigns the columns of an input tensor to named symbols, so there is no need to expand inputs into dictionaries of named torch.Tensor objects by hand.
>>> from symtorch import SymbolAssignment
>>> import torch.nn as nn
>>> model = nn.Sequential(
...     SymbolAssignment(["x", "y"]),
...     symtorchify("x**2 + 2*y")
... )
>>> # Use like any other PyTorch model: each row is one sample (x, y)
>>> input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> output = model(input_tensor)
>>> print(output)
tensor([[ 5.],
        [17.]])
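The first row of the example, [1.0, 2.0] giving 5 (1² + 2·2), is consistent with SymbolAssignment binding column j of a (batch, n) input to the j-th name. The plain-Python sketch below mimics that mapping under this assumption; `assign_symbols` and `x2_plus_2y` are hypothetical helpers, and nested lists stand in for tensors.

```python
def assign_symbols(names: list[str], batch: list[list[float]]) -> dict[str, list[float]]:
    """Split a (batch, len(names)) input into one named column per symbol."""
    if any(len(row) != len(names) for row in batch):
        raise ValueError("each row must have one value per symbol")
    return {name: [row[j] for row in batch] for j, name in enumerate(names)}


def x2_plus_2y(sym: dict[str, list[float]]) -> list[list[float]]:
    # Row-wise x**2 + 2*y, kept as a (batch, 1) column like the model output
    return [[x**2 + 2 * y] for x, y in zip(sym["x"], sym["y"])]


symbols = assign_symbols(["x", "y"], [[1.0, 2.0], [3.0, 4.0]])
print(symbols)              # {'x': [1.0, 3.0], 'y': [2.0, 4.0]}
print(x2_plus_2y(symbols))  # [[5.0], [17.0]]
```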
Model Serialization
SymTorch models can be saved and loaded using PyTorch’s built-in mechanisms:
>>> import torch
>>> # Create and save a model
>>> model = symtorchify("x**3 + 2*x**2 + 3*x + 4")
>>> torch.save(model, "symbolic_model.pt")
>>> # Load the model; the whole object is unpickled, so PyTorch 2.6+ needs weights_only=False
>>> loaded_model = torch.load("symbolic_model.pt", weights_only=False)
>>> # Verify that the loaded model works the same
>>> input_data = {"x": torch.tensor([2.0])}
>>> original_output = model(input_data)
>>> loaded_output = loaded_model(input_data)
>>> torch.allclose(original_output, loaded_output)
True
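Saving a whole model with torch.save relies on Python's `pickle` module, which records the object's data together with a reference (module and class name) to its class, not the class's code. The stdlib-only sketch below uses a hypothetical `Cubic` class in place of a SymTorch expression to show the round trip and that only the class name ends up in the bytes; in practice this means symtorch must be importable wherever the file is loaded.

```python
import pickle
from dataclasses import dataclass


@dataclass
class Cubic:
    """Hypothetical stand-in for a saved expression: a*x**3 + b*x**2 + c*x + d."""
    coeffs: tuple[float, float, float, float] = (1.0, 2.0, 3.0, 4.0)

    def __call__(self, x: float) -> float:
        a, b, c, d = self.coeffs
        return a * x**3 + b * x**2 + c * x + d


model = Cubic()
blob = pickle.dumps(model)   # full-object save: data plus a reference to the class
loaded = pickle.loads(blob)  # looks Cubic up by name, then restores the data

print(model(2.0), loaded(2.0))  # 26.0 26.0
print(b"Cubic" in blob)         # True: the class is stored by name, not as code
```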
The state_dict of a SymTorch expression captures its full state, namely the expression string and its trainability flags:
>>> model = symtorchify("x**3 + 2*x**2 + 3*x + 4")
>>> model.state_dict()
OrderedDict([('_extra_state', ExpressionState(
expression='x**3 + 2*x**2 + 3*x + 4',
trainable=True,
trainable_ints=False
))])
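Because the printed state holds only the expression string and two flags, rebuilding an equivalent object from it amounts to re-parsing the string. The sketch below illustrates that round trip with hypothetical `ExpressionStateSketch` and `SketchExpression` classes (not SymTorch's ExpressionState), using Python's `compile` and `eval` in place of SymPy. It ignores trainable parameters, so how SymTorch records trained constant values is not modeled here.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ExpressionStateSketch:
    """Hypothetical mirror of the fields printed in the state_dict above."""
    expression: str
    trainable: bool = True
    trainable_ints: bool = False


class SketchExpression:
    """Hypothetical expression object that is fully rebuilt from its state."""

    def __init__(self, state: ExpressionStateSketch) -> None:
        self.state = state
        self._code = compile(state.expression, "<expr>", "eval")

    def __call__(self, **symbols: float) -> float:
        return eval(self._code, {"__builtins__": {}}, symbols)


original = SketchExpression(ExpressionStateSketch("x**3 + 2*x**2 + 3*x + 4"))
saved = asdict(original.state)  # plain dict holding a str and two bools
restored = SketchExpression(ExpressionStateSketch(**saved))

print(saved)
print(restored(x=2.0), original(x=2.0))  # 26.0 26.0
```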
To load expressions from a state_dict into a model, use SymbolicExpression() and SymbolAssignment() as placeholders:
>>> from symtorch import SymbolicExpression  # assumed top-level export, like SymbolAssignment
>>> # Build a model and save its state_dict
>>> saved_model = nn.Sequential(
...     SymbolAssignment(["x"]),
...     symtorchify("x**3 + 2*x**2 + 3*x + 4")
... )
>>> torch.save(saved_model.state_dict(), "symbolic_model_state_dict.pt")
>>> # Some time later, define a generic model from empty placeholders
>>> empty_model = nn.Sequential(
...     SymbolAssignment(),
...     SymbolicExpression()
... )
>>> # Load the saved state into the placeholders; the state holds custom
>>> # (non-tensor) objects, so PyTorch 2.6+ needs weights_only=False
>>> empty_model.load_state_dict(torch.load("symbolic_model_state_dict.pt", weights_only=False))
>>> empty_model
Sequential(
  (0): SymbolAssignment(names=['x'])
  (1): x**3 + 2*x**2 + 3*x + 4
)
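The `_extra_state` key in the printed state_dict is the key under which torch.nn.Module stores whatever a module returns from its get_extra_state() hook; load_state_dict() hands that value back through set_extra_state(). Assuming SymTorch's placeholders rely on these hooks, the stdlib-only sketch below mimics the pattern with a hypothetical `PlaceholderExpression` class that starts empty and becomes a concrete expression once its state is loaded. Inside nn.Sequential the same key is prefixed with the child's index (for example `1._extra_state`), which is why the placeholders must appear in the same order as the saved layers.

```python
from typing import Any, Optional


class PlaceholderExpression:
    """Hypothetical module-like object that is empty until its state is loaded."""

    def __init__(self, expression: Optional[str] = None) -> None:
        self.expression = expression

    # Same hook names that torch.nn.Module uses for non-tensor state
    def get_extra_state(self) -> Any:
        return {"expression": self.expression}

    def set_extra_state(self, state: Any) -> None:
        self.expression = state["expression"]

    def state_dict(self) -> dict[str, Any]:
        return {"_extra_state": self.get_extra_state()}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.set_extra_state(state_dict["_extra_state"])

    def __repr__(self) -> str:
        return self.expression if self.expression is not None else "<empty>"


source = PlaceholderExpression("x**3 + 2*x**2 + 3*x + 4")
target = PlaceholderExpression()
print(target)  # <empty>
target.load_state_dict(source.state_dict())
print(target)  # x**3 + 2*x**2 + 3*x + 4
```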
TorchScript Compilation
SymTorch models can be compiled with TorchScript, which allows them to be deployed without a Python interpreter (for example, loaded from C++ via LibTorch) and, in some cases, to run faster:
>>> import torch
>>> # Create a symbolic model
>>> model = symtorchify("x**2 + 2*y")
>>> # Compile the model
>>> scripted_model = torch.jit.script(model)
>>> # Save the compiled model
>>> scripted_model.save("compiled_model.pt")
>>> # Load and use the compiled model
>>> loaded_model = torch.jit.load("compiled_model.pt")
>>> input_data = {"x": torch.tensor([3.0]), "y": torch.tensor([4.0])}
>>> result = loaded_model(input_data)
>>> print(result)
tensor([17.])