Usage
=====

This guide provides detailed examples of how to use SymTorch in various scenarios.

.. note::

   For a complete API reference, see the :doc:`API documentation `.

Basic Usage
-----------

The core function in SymTorch is :func:`~symtorch.symtorchify`, which converts a string or SymPy expression into a PyTorch module:

.. code-block:: pycon

   >>> from symtorch import symtorchify
   >>> import torch
   >>> # Create a simple expression
   >>> expr = symtorchify("x**2 + 2*y")
   >>> # Evaluate the expression
   >>> input_data = {"x": torch.tensor([1.0, 2.0]), "y": torch.tensor([3.0, 4.0])}
   >>> expr(input_data)
   tensor([ 7., 12.])

By default, floating-point constants in the expression become trainable parameters:

.. code-block:: pycon

   >>> expr = symtorchify("2.5 * x + 1.7")
   >>> list(expr.parameters())
   [Parameter containing:
   tensor(2.5000, requires_grad=True), Parameter containing:
   tensor(1.7000, requires_grad=True)]
   >>> # Make parameters non-trainable
   >>> expr = symtorchify("2.5 * x + 1.7", trainable=False)
   >>> list(expr.parameters())
   []

SymTorch supports a wide range of mathematical operations, including trigonometric functions, exponentials and ``max``:

.. code-block:: pycon

   >>> expr = symtorchify("sin(x)**2 + cos(y)**2 + exp(z) / max(x, y)")
   >>> input_data = {"x": torch.tensor(0.5), "y": torch.tensor(1.0), "z": torch.tensor(0.0)}
   >>> result = expr(input_data)
   >>> print(result)
   tensor([1.5218])

Under the hood, SymTorch represents each :class:`~symtorch.Expression` as a tree of other SymTorch objects:

.. code-block:: pycon

   >>> expr = symtorchify("sin(x)**2 + cos(y)**2 + exp(z) / max(x, y)")
   >>> expr
   sin(x)²+cos(y)²+exp(z)/max(x,y)
   >>> expr.long_hand_representation()
   Add(Pow(cos(Id(y)), Id(2)), Pow(sin(Id(x)), Id(2)), Mul(Pow(max(Id(x), Id(y)), Id(-1)), exp(Id(z))))

:class:`~symtorch.SymbolAssignment`
-----------------------------------

The :class:`~symtorch.SymbolAssignment` class lets you use symbolic expressions as drop-in replacements for standard PyTorch layers. It assigns the columns of a plain input tensor to named symbols, so you do not need to expand the input into a dictionary of named :class:`torch.Tensor` objects yourself.

.. code-block:: pycon

   >>> import torch
   >>> import torch.nn as nn
   >>> from symtorch import SymbolAssignment, symtorchify
   >>> model = nn.Sequential(
   ...     SymbolAssignment(["x", "y"]),
   ...     symtorchify("x**2 + 2*y")
   ... )
   >>> # Use like any other PyTorch model
   >>> input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
   >>> output = model(input_tensor)
   >>> print(output)
   tensor([[ 5.],
           [17.]])

Model Serialization
-------------------

SymTorch models can be saved and loaded using PyTorch's built-in mechanisms:

.. code-block:: pycon

   >>> import torch
   >>> from symtorch import symtorchify
   >>> # Create and save a model
   >>> model = symtorchify("x**3 + 2*x**2 + 3*x + 4")
   >>> torch.save(model, "symbolic_model.pt")
   >>> # Load the model. Since PyTorch 2.6, torch.load defaults to
   >>> # weights_only=True, which refuses to unpickle whole modules.
   >>> loaded_model = torch.load("symbolic_model.pt", weights_only=False)
   >>> # Verify that the loaded model works the same
   >>> input_data = {"x": torch.tensor([2.0])}
   >>> original_output = model(input_data)
   >>> loaded_output = loaded_model(input_data)
   >>> torch.allclose(original_output, loaded_output)
   True

The ``state_dict`` of a SymTorch expression captures its full state, including the expression itself:

.. code-block:: pycon

   >>> model = symtorchify("x**3 + 2*x**2 + 3*x + 4")
   >>> model.state_dict()
   OrderedDict([('_extra_state', ExpressionState(
       expression='x**3 + 2*x**2 + 3*x + 4',
       trainable=True,
       trainable_ints=False
   ))])

To load expressions from a saved ``state_dict`` into a model, use :class:`~symtorch.SymbolicExpression` and :class:`~symtorch.SymbolAssignment` as placeholders:

.. code-block:: pycon

   >>> import torch
   >>> import torch.nn as nn
   >>> from symtorch import SymbolAssignment, SymbolicExpression, symtorchify
   >>> # Save a model's state_dict
   >>> saved_model = nn.Sequential(
   ...     SymbolAssignment(["x"]),
   ...     symtorchify("x**3 + 2*x**2 + 3*x + 4")
   ... )
   >>> torch.save(saved_model.state_dict(), "symbolic_model_state_dict.pt")
   >>> # Some time later, define a general model with empty placeholders
   >>> empty_model = nn.Sequential(
   ...     SymbolAssignment(),
   ...     SymbolicExpression()
   ... )
   >>> # Load the saved state into the empty model. The state holds SymTorch
   >>> # objects rather than only tensors, so recent PyTorch versions may
   >>> # need weights_only=False here.
   >>> empty_model.load_state_dict(torch.load("symbolic_model_state_dict.pt", weights_only=False))
   <All keys matched successfully>
   >>> empty_model
   Sequential(
     (0): SymbolAssignment(names=['x'])
     (1): x**3 + 2*x**2 + 3*x + 4
   )

TorchScript Compilation
-----------------------

SymTorch models can be compiled with TorchScript, which can improve performance and lets the compiled model be loaded from C++ (via LibTorch) without a Python interpreter:

.. code-block:: pycon

   >>> import torch
   >>> from symtorch import symtorchify
   >>> # Create a symbolic model
   >>> model = symtorchify("x**2 + 2*y")
   >>> # Compile the model
   >>> scripted_model = torch.jit.script(model)
   >>> # Save the compiled model
   >>> scripted_model.save("compiled_model.pt")
   >>> # Load and use the compiled model
   >>> loaded_model = torch.jit.load("compiled_model.pt")
   >>> input_data = {"x": torch.tensor([3.0]), "y": torch.tensor([4.0])}
   >>> result = loaded_model(input_data)
   >>> print(result)
   tensor([17.])
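
The compiled SymTorch module above takes a dictionary of named tensors as input. As a rough point of comparison, the following sketch hand-writes the same expression as a plain PyTorch module with a ``Dict[str, torch.Tensor]`` input, scripts it, and round-trips it through an in-memory buffer instead of a file. ``NamedInputExpression`` is a hypothetical class written only for illustration. It is not part of SymTorch and does not reflect how SymTorch builds its modules internally.

.. code-block:: python

   import io
   from typing import Dict

   import torch
   import torch.nn as nn


   class NamedInputExpression(nn.Module):
       """Hypothetical hand-written stand-in for symtorchify("x**2 + 2*y").

       Not part of SymTorch: it only mimics the dictionary-in, tensor-out
       calling convention used in the examples above.
       """

       def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
           # Look up each named symbol in the input dictionary.
           x = inputs["x"]
           y = inputs["y"]
           return torch.pow(x, 2) + y * 2


   model = NamedInputExpression()
   scripted = torch.jit.script(model)

   # Save to and load from an in-memory buffer instead of a file on disk.
   buffer = io.BytesIO()
   torch.jit.save(scripted, buffer)
   buffer.seek(0)
   loaded = torch.jit.load(buffer)

   input_data = {"x": torch.tensor([3.0]), "y": torch.tensor([4.0])}
   eager_out = model(input_data)
   loaded_out = loaded(input_data)
   print(loaded_out)  # tensor([17.])

Using an in-memory buffer keeps the sketch free of leftover files; in practice you would save to a path, as in the SymTorch example above.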