PyTensor graph rewrites from scratch
Manipulating nodes directly
This section walks through the low-level details of PyTensor graph manipulation. Users are not expected to work with, or even be aware of, these details, but they may be helpful for developers. We start with very bad practices and move towards the right way of doing rewrites.
Reading Graph structures is a required precursor to this guide.
Graph rewriting provides the user-level summary of what is covered here. Feel free to revisit it once you’re done.
As described in Graph structures, PyTensor graphs are composed of sequences of Apply nodes, which link Variables that form the inputs and outputs of a computational Operation.
The list of inputs of an Apply node can be changed in place to modify the computational path that leads to it.
Consider the following simple example:
%env PYTENSOR_FLAGS=cxx=""
env: PYTENSOR_FLAGS=cxx=""
import pytensor
import pytensor.tensor as pt
x = pt.scalar("x")
y = pt.log(1 + x)
out = y * 2
pytensor.dprint(out, id_type="");
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Mul
├─ Log
│ └─ Add
│ ├─ 1
│ └─ x
└─ 2
A standard rewrite replaces pt.log(1 + x) with the more stable form pt.log1p(x).
We can do this by changing, in place, the inputs of the node that produces out.
out.owner.inputs[0] = pt.log1p(x)
pytensor.dprint(out, id_type="");
Mul
├─ Log1p
│ └─ x
└─ 2
There are two problems with this direct approach:
1. We are modifying variables in place
2. We have to know which nodes have as input the variable we want to replace
Point 1 is important because some rewrites are “destructive”, and the user may want to reuse the same graph in multiple functions.
Point 2 is important because it forces us to shift our attention from the operation we want to rewrite to the places where its output is used. It also risks unnecessary duplication of variables if we perform the same replacement independently for each use. This could make graph rewriting considerably slower!
PyTensor makes use of FunctionGraphs to solve these two issues.
By default, a FunctionGraph clones all the variables between the inputs and outputs, so that the copied graph can be rewritten without touching the original one.
In addition, it creates a clients dictionary that maps every variable to the nodes where it is used.
Let’s see how we can use a FunctionGraph to achieve the same rewrite:
from pytensor.graph import FunctionGraph
x = pt.scalar("x")
y = pt.log(1 + x)
out1 = y * 2
out2 = 2 / y
# Create an empty dictionary which FunctionGraph will populate
# with the mappings from old variables to cloned ones
memo = {}
fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)
fg_x = memo[x]
fg_y = memo[y]
print("Before:\n")
pytensor.dprint(fg.outputs)
# Create expression of interest with cloned variables
fg_y_repl = pt.log1p(fg_x)
# Update all uses of old variable to new one
# Each entry in the clients dictionary,
# contains a node and the input index where the variable is used
# Note: Some variables could be used multiple times in a single node
for client, idx in fg.clients[fg_y]:
    client.inputs[idx] = fg_y_repl
print("\nAfter:\n")
pytensor.dprint(fg.outputs);
Before:
Mul [id A]
├─ Log [id B]
│ └─ Add [id C]
│ ├─ 1 [id D]
│ └─ x [id E]
└─ 2 [id F]
True_div [id G]
├─ 2 [id H]
└─ Log [id B]
└─ ···
After:
Mul [id A]
├─ Log1p [id B]
│ └─ x [id C]
└─ 2 [id D]
True_div [id E]
├─ 2 [id F]
└─ Log1p [id B]
└─ ···
We can see that both uses of log(1 + x) were replaced by the new log1p(x).
If we wanted to perform another rewrite, we would also need to update the clients dictionary, as well as other bookkeeping attributes of the FunctionGraph (such as its sets of variables and apply nodes).
There is no point in doing all this bookkeeping manually: FunctionGraph offers a replace method that takes care of it for the user.
# We didn't modify the variables in place so we can just reuse them!
memo = {}
fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)
fg_x = memo[x]
fg_y = memo[y]
print("Before:\n")
pytensor.dprint(fg.outputs)
# Create expression of interest with cloned variables
fg_y_repl = pt.log1p(fg_x)
fg.replace(fg_y, fg_y_repl)
print("\nAfter:\n")
pytensor.dprint(fg.outputs);
Before:
Mul [id A]
├─ Log [id B]
│ └─ Add [id C]
│ ├─ 1 [id D]
│ └─ x [id E]
└─ 2 [id F]
True_div [id G]
├─ 2 [id H]
└─ Log [id B]
└─ ···
After:
Mul [id A]
├─ Log1p [id B]
│ └─ x [id C]
└─ 2 [id D]
True_div [id E]
├─ 2 [id F]
└─ Log1p [id B]
└─ ···
There is still one big limitation with this approach. We have to know in advance “where” the variable we want to replace is present. It also doesn’t scale to multiple instances of the same pattern.
A more sensible approach would be to iterate over the nodes in the FunctionGraph and apply the rewrite wherever log(1 + x) may be present.
To keep code organized we will create a function that takes as input a node and returns a valid replacement.
from pytensor.graph import Constant
def local_log1p(node):
    # Check that this node is a Log op
    if node.op != pt.log:
        return None
    # Check that the input is another node (it could be an input variable)
    add_node = node.inputs[0].owner
    if add_node is None:
        return None
    # Check that the input to this node is an Add op
    # with 2 inputs (Add can have more inputs)
    if add_node.op != pt.add or len(add_node.inputs) != 2:
        return None
    # Check whether we have add(1, y) or add(x, 1)
    [x, y] = add_node.inputs
    if isinstance(x, Constant) and x.data == 1:
        return [pt.log1p(y)]
    if isinstance(y, Constant) and y.data == 1:
        return [pt.log1p(x)]
    return None
# We no longer need the memo, because our rewrite works with the node information
fg = FunctionGraph([x], [out1, out2], clone=True)
# Toposort gives a list of all nodes in a graph in topological order
# The strategy of iteration can be important when we are dealing with multiple rewrites
for node in fg.toposort():
    repl = local_log1p(node)
    if repl is None:
        continue
    # We should get one replacement for each output of the node
    assert len(repl) == len(node.outputs)
    # We could use `fg.replace_all` to avoid this loop
    for old, new in zip(node.outputs, repl):
        fg.replace(old, new)
pytensor.dprint(fg);
Mul [id A] 1
├─ Log1p [id B] 0
│ └─ x [id C]
└─ 2 [id D]
True_div [id E] 2
├─ 2 [id F]
└─ Log1p [id B] 0
└─ ···
This is starting to look much more scalable!
We are still reinventing many wheels that already exist in PyTensor, but we’re getting there. Before we move up the ladder of abstraction, let’s discuss two gotchas:
1. The replacement variables should have types that are compatible with the original ones.
2. We have to be careful about introducing circular dependencies.
For point 1, let’s look at a simple graph simplification, where we replace a costly operation that is ultimately multiplied by zero.
x = pt.vector("x", dtype="float32")
zero = pt.zeros(())
zero.name = "zero"
y = pt.exp(x) * zero
fg = FunctionGraph([x], [y], clone=False)
try:
fg.replace(y, pt.zeros(()))
except TypeError as exc:
print(f"TypeError: {exc}")
TypeError: Cannot convert Type Scalar(float64, shape=()) (of Variable Alloc.0) into Type Vector(float64, shape=(?,)). You can try to manually convert Alloc.0 into a Vector(float64, shape=(?,)).
The first achievement of a new PyTensor developer is unlocked by stumbling upon an error like that!
It’s important to keep in mind the Tensor part of PyTensor.
The problem here is that we are trying to replace the y variable, which is a float64 vector (the float32 exp(x) is upcast when multiplied by the float64 zero), with a float64 scalar!
pytensor.dprint(fg.outputs, id_type="", print_type=True);
Mul <Vector(float64, shape=(?,))>
├─ Exp <Vector(float32, shape=(?,))>
│ └─ x <Vector(float32, shape=(?,))>
└─ ExpandDims{axis=0} <Vector(float64, shape=(1,))>
└─ Alloc <Scalar(float64, shape=())> 'zero'
└─ 0.0 <Scalar(float64, shape=())>
vector_zero = pt.zeros(x.shape)
vector_zero.name = "vector_zero"
fg.replace(y, vector_zero)
pytensor.dprint(fg.outputs, id_type="", print_type=True);
Alloc <Vector(float64, shape=(?,))> 'vector_zero'
├─ 0.0 <Scalar(float64, shape=())>
└─ Subtensor{i} <Scalar(int64, shape=())>
├─ Shape <Vector(int64, shape=(1,))>
│ └─ x <Vector(float32, shape=(?,))>
└─ 0 <int64>
Now to the second (less common) gotcha: introducing circular dependencies.
x = pt.scalar("x")
y = x + 1
y.name = "y"
z = y + 1
z.name = "z"
fg = FunctionGraph([x], [z], clone=False)
fg.replace(x, z)
pytensor.dprint(fg.outputs);
Add [id A] 'z'
├─ Add [id B] 'y'
│ ├─ Add [id A] 'z'
│ │ └─ ···
│ └─ 1 [id C]
└─ 1 [id D]
Oops! There is not much to say about this one, other than don’t do it!
Using graph rewriters
from pytensor.graph.rewriting.basic import NodeRewriter
class LocalLog1pNodeRewriter(NodeRewriter):
    def tracks(self):
        return [pt.log]
    def transform(self, fgraph, node):
        return local_log1p(node)
    def __str__(self):
        return "local_log1p"
local_log1p_node_rewriter = LocalLog1pNodeRewriter()
A NodeRewriter is only required to implement the transform method.
As before, this method expects a node and should return either a valid replacement for each output or None.
It also receives the FunctionGraph object, as some node rewriters may want to use global information to decide whether to return a replacement.
For example, rewrites that skip intermediate computations may not be useful if those intermediate computations are also used by other variables.
The optional tracks method is very useful for filtering out “useless” rewrites. When a NodeRewriter only applies to a specific, rare Op, it can be skipped entirely when that Op is not present in the graph.
On its own, a NodeRewriter isn’t any better than what we had before. It becomes useful when included inside a GraphRewriter, which applies it to a whole FunctionGraph.
from pytensor.graph.rewriting.basic import in2out
x = pt.scalar("x")
y = pt.log(1 + x)
out = pt.exp(y)
fg = FunctionGraph([x], [out])
in2out(local_log1p_node_rewriter, name="local_log1p").rewrite(fg)
pytensor.dprint(fg.outputs);
Exp [id A]
└─ Log1p [id B]
└─ x [id C]
Here we used in2out(), which creates a GraphRewriter (specifically a WalkingGraphRewriter) that walks from the inputs to the outputs of a FunctionGraph, trying to apply whichever node rewriters are registered in it.
Wrapping simple functions in NodeRewriters is so common that PyTensor offers a decorator for it.
Let’s create a new rewrite that removes a useless abs, abs(exp(x)) -> exp(x).
from pytensor.graph.rewriting.basic import node_rewriter
@node_rewriter(tracks=[pt.abs])
def local_useless_abs_exp(fgraph, node):
    # Because of the tracks we don't need to check
    # that `node` has an `Abs` Op.
    # We still need to check whether its input is an `Exp` Op
    exp_node = node.inputs[0].owner
    if exp_node is None or exp_node.op != pt.exp:
        return None
    return exp_node.outputs
Another very useful helper is the PatternNodeRewriter, which allows you to specify a rewrite via “template matching”.
from pytensor.graph.rewriting.basic import PatternNodeRewriter
local_useless_abs_square = PatternNodeRewriter(
(pt.abs, (pt.pow, "x", 2)),
(pt.pow, "x", 2),
name="local_useless_abs_square",
)
This is very useful for simple Elemwise rewrites, but becomes a bit cumbersome with Ops that must be parametrized every time they are used.
x = pt.scalar("x")
y = pt.exp(x)
z = pt.abs(y)
w = pt.log(1.0 + z)
out = pt.abs(w ** 2)
fg = FunctionGraph([x], [out])
in2out_rewrite = in2out(
local_log1p_node_rewriter,
local_useless_abs_exp,
local_useless_abs_square,
name="custom_rewrites"
)
in2out_rewrite.rewrite(fg)
pytensor.dprint(fg.outputs);
Pow [id A]
├─ Log1p [id B]
│ └─ Exp [id C]
│ └─ x [id D]
└─ 2 [id E]
Besides WalkingGraphRewriters, there are:
SequentialGraphRewriters, which apply a set of GraphRewriters sequentially
EquilibriumGraphRewriters, which apply a set of GraphRewriters (and NodeRewriters) repeatedly until the graph stops changing.
Registering graph rewriters in a database
Finally, at the top of the rewrite mountain, there are RewriteDatabases! These allow “querying” for subsets of rewrites registered in a database.
Most users trigger this implicitly when they change the mode of a PyTensor function: mode="FAST_COMPILE", mode="FAST_RUN" or mode="JAX" each lead to a different rewrite database query being applied to the function before compilation.
The most relevant RewriteDatabase is called optdb and contains all the standard rewrites in PyTensor. You can manually register your GraphRewriter in it.
More often than not, you will want to register your rewrite in a pre-existing sub-database, like canonicalize, stabilize, or specialize.
from pytensor.compile.mode import optdb
optdb["canonicalize"].register(
"local_log1p_node_rewriter",
local_log1p_node_rewriter,
"fast_compile",
"fast_run",
"custom",
)
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], out, mode="FAST_COMPILE")
print("")
pytensor.dprint(fn);
rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)
Abs [id A] 4
└─ Pow [id B] 3
├─ Log1p [id C] 2
│ └─ Abs [id D] 1
│ └─ Exp [id E] 0
│ └─ x [id F]
└─ 2 [id G]
There’s also a decorator, register_canonicalize(), that automatically registers a NodeRewriter in the canonicalize database. (It lives in a somewhat unexpected location: pytensor.tensor.rewriting.basic.)
from pytensor.tensor.rewriting.basic import register_canonicalize
@register_canonicalize("custom")
@node_rewriter(tracks=[pt.abs])
def local_useless_abs_exp(fgraph, node):
    # Because of the tracks we don't need to check
    # that `node` has an `Abs` Op.
    # We still need to check whether its input is an `Exp` Op
    exp_node = node.inputs[0].owner
    if exp_node is None or exp_node.op != pt.exp:
        return None
    return exp_node.outputs
register_canonicalize can also be called directly on an existing rewriter:
register_canonicalize(local_useless_abs_square, "custom")
local_useless_abs_square
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], out, mode="FAST_COMPILE")
print("")
pytensor.dprint(fn);
rewriting: rewrite local_useless_abs_square replaces Abs.0 of Abs(Pow.0) with Pow.0 of Pow(Log.0, 2)
rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)
rewriting: rewrite local_useless_abs_exp replaces Abs.0 of Abs(Exp.0) with Exp.0 of Exp(x)
Pow [id A] 2
├─ Log1p [id B] 1
│ └─ Exp [id C] 0
│ └─ x [id D]
└─ 2 [id E]
And if you wanted to exclude your custom rewrites you can do it like this:
from pytensor.compile.mode import get_mode
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], out, mode=get_mode("FAST_COMPILE").excluding("custom"))
print("")
pytensor.dprint(fn);
rewriting: rewrite local_upcast_elemwise_constant_inputs replaces Add.0 of Add(1.0, Abs.0) with Add.0 of Add(Cast{float64}.0, Abs.0)
rewriting: rewrite constant_folding replaces Cast{float64}.0 of Cast{float64}(1.0) with 1.0 of None
Abs [id A] 5
└─ Pow [id B] 4
├─ Log [id C] 3
│ └─ Add [id D] 2
│ ├─ 1.0 [id E]
│ └─ Abs [id F] 1
│ └─ Exp [id G] 0
│ └─ x [id H]
└─ 2 [id I]
References
Watermark
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor
Last updated: Sat Jan 11 2025
Python implementation: CPython
Python version : 3.12.0
IPython version : 8.31.0
pytensor: 2.26.4+16.g8be5c5323.dirty
sys : 3.12.0 | packaged by conda-forge | (main, Oct 3 2023, 08:43:22) [GCC 12.3.0]
Watermark: 2.5.0
License notice
All the notebooks in this example gallery are provided under a 3-Clause BSD License, which allows modification and redistribution for any use, provided the copyright and license notices are preserved.
Citing Pytensor Examples
To cite this notebook, please use the suggested citation below.
Important
Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.
Also remember to cite the relevant libraries used by your code.
Here is an example citation template in bibtex:
@incollection{citekey,
author = "<notebook authors, see above>",
title = "<notebook title>",
editor = "Pytensor Team",
booktitle = "Pytensor Examples",
}
which once rendered could look like: