PyTensor graph rewrites from scratch

Manipulating nodes directly

This section walks through the low-level details of PyTensor graph manipulation. Users are not expected to work with, or even be aware of, these details, but they may be helpful for developers. We start with very bad practices and move towards the right way of doing rewrites.

  • Graph structures is a prerequisite for this guide.

  • Graph rewriting provides a user-level summary of what is covered here. Feel free to revisit it once you’re done.

As described in Graph structures, PyTensor graphs are composed of sequences of Apply nodes, which link the Variables that form the inputs and outputs of a computational Operation.

The list of inputs of an Apply node can be changed in place to modify the computational path that leads to it. Consider the following simple example:

%env PYTENSOR_FLAGS=cxx=""
env: PYTENSOR_FLAGS=cxx=""
import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
y = pt.log(1 + x)
out = y * 2
pytensor.dprint(out, id_type="");
Mul
 ├─ Log
 │  └─ Add
 │     ├─ 1
 │     └─ x
 └─ 2

A standard rewrite replaces pt.log(1 + x) with the more numerically stable pt.log1p(x). We can do this by changing, in place, the inputs of the node that produces out (that is, out.owner).

out.owner.inputs[0] = pt.log1p(x)
pytensor.dprint(out, id_type="");
Mul
 ├─ Log1p
 │  └─ x
 └─ 2
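Why is log1p(x) considered more stable? For tiny x, forming 1 + x in floating point throws away the digits of x before the logarithm is ever taken. The quick check below uses Python's standard math module on plain floats rather than PyTensor, so it only illustrates the numerical motivation for the rewrite:

```python
import math

x = 1e-16
naive = math.log(1 + x)  # 1 + 1e-16 rounds to exactly 1.0 in float64, so this is log(1.0)
stable = math.log1p(x)   # evaluates log(1 + x) without ever forming 1 + x

print(naive)   # 0.0
print(stable)  # approximately 1e-16
```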

There are two problems with this direct approach:

  1. We are modifying variables in place

  2. We have to know which nodes have as input the variable we want to replace

Point 1 is important because some rewrites are “destructive”, and the user may want to reuse the same graph in multiple functions.

Point 2 is important because it forces us to shift our attention from the operation we want to rewrite to the places where its output is used. It also risks unnecessary duplication of variables if we perform the same replacement independently for each use. This could make graph rewriting considerably slower!
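Both problems can be reproduced with a toy graph in plain Python. The Node class and show helper below are hypothetical stand-ins written for this illustration, not PyTensor classes; the point is only how shared, mutable inputs behave:

```python
class Node:
    """Hypothetical stand-in for an Apply node: an op name and a mutable list of inputs."""

    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = list(inputs)


def show(v):
    """Render a toy expression as a nested string."""
    if isinstance(v, Node):
        return f"{v.op}({', '.join(show(i) for i in v.inputs)})"
    return str(v)


y = Node("log", Node("add", 1, "x"))
out1 = Node("mul", y, 2)
out2 = Node("true_div", 2, y)

# Problem 2: patching one use of `y` silently misses the other one.
out1.inputs[0] = Node("log1p", "x")
print(show(out1))  # mul(log1p(x), 2)
print(show(out2))  # true_div(2, log(add(1, x)))

# Problem 1: mutating the shared node itself rewrites *every* graph that holds it,
# including graphs (functions) that were supposed to stay untouched.
y.op, y.inputs = "log1p", ["x"]
print(show(out2))  # true_div(2, log1p(x))
```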

PyTensor uses FunctionGraphs to solve both issues. By default, a FunctionGraph clones all the variables between the inputs and outputs, so that the cloned graph can be rewritten without affecting the original one. In addition, it builds a clients dictionary that maps each variable to the nodes where it is used, together with the input position at which it is used.
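A minimal sketch of what a clients mapping contains, using a hypothetical toy Node class rather than PyTensor's Apply and Variable objects. Each variable maps to the (node, input_index) pairs that consume it, so a variable used twice by the same node appears twice:

```python
from collections import defaultdict


class Node:
    """Hypothetical stand-in for an Apply node whose single output is the node itself."""

    def __init__(self, op, *inputs):
        self.op, self.inputs = op, list(inputs)


def build_clients(outputs):
    """Map each variable to the (node, input_index) pairs where it is used."""
    clients = defaultdict(list)
    seen, stack = set(), list(outputs)
    while stack:
        var = stack.pop()
        if not isinstance(var, Node) or var in seen:
            continue  # root inputs have no inputs of their own; shared nodes are visited once
        seen.add(var)
        for idx, inp in enumerate(var.inputs):
            clients[inp].append((var, idx))
            stack.append(inp)
    return clients


y = Node("log", Node("add", "one", "x"))
out1 = Node("mul", y, "two")
out2 = Node("true_div", "two", y)
out3 = Node("mul", y, y)  # the same variable used twice by one node

clients = build_clients([out1, out2, out3])
for node, idx in clients[y]:
    print(node.op, idx)
```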

Let’s see how we can use a FunctionGraph to achieve the same rewrite:

from pytensor.graph import FunctionGraph

x = pt.scalar("x")
y = pt.log(1 + x)
out1 = y * 2
out2 = 2 / y

# Create an empty dictionary which FunctionGraph will populate
# with the mappings from old variables to cloned ones
memo = {}
fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)
fg_x = memo[x]
fg_y = memo[y]
print("Before:\n")
pytensor.dprint(fg.outputs)

# Create expression of interest with cloned variables
fg_y_repl = pt.log1p(fg_x)

# Update all uses of the old variable to the new one.
# Each entry in fg.clients[var] is a (node, input_index) pair
# telling us where `var` is used.
# Note: some variables could be used multiple times in a single node.
for client, idx in fg.clients[fg_y]:
    client.inputs[idx] = fg_y_repl

print("\nAfter:\n")
pytensor.dprint(fg.outputs);
Before:

Mul [id A]
 ├─ Log [id B]
 │  └─ Add [id C]
 │     ├─ 1 [id D]
 │     └─ x [id E]
 └─ 2 [id F]
True_div [id G]
 ├─ 2 [id H]
 └─ Log [id B]
    └─ ···

After:

Mul [id A]
 ├─ Log1p [id B]
 │  └─ x [id C]
 └─ 2 [id D]
True_div [id E]
 ├─ 2 [id F]
 └─ Log1p [id B]
    └─ ···

We can see that both uses of log(1 + x) were replaced by the new log1p(x).

Note, however, that we did not update the clients dictionary: it still lists the old fg_y as being used by those nodes. We would have to fix that before attempting another rewrite.

There are a few other attributes of the FunctionGraph that we would also need to update, but there is no point in doing all this bookkeeping manually. FunctionGraph offers a replace method that takes care of it for us.

# We didn't modify the variables in place so we can just reuse them!
memo = {}
fg = FunctionGraph([x], [out1, out2], clone=True, memo=memo)
fg_x = memo[x]
fg_y = memo[y]
print("Before:\n")
pytensor.dprint(fg.outputs)

# Create expression of interest with cloned variables
fg_y_repl = pt.log1p(fg_x)
fg.replace(fg_y, fg_y_repl)
    
print("\nAfter:\n")
pytensor.dprint(fg.outputs);
Before:

Mul [id A]
 ├─ Log [id B]
 │  └─ Add [id C]
 │     ├─ 1 [id D]
 │     └─ x [id E]
 └─ 2 [id F]
True_div [id G]
 ├─ 2 [id H]
 └─ Log [id B]
    └─ ···

After:

Mul [id A]
 ├─ Log1p [id B]
 │  └─ x [id C]
 └─ 2 [id D]
True_div [id E]
 ├─ 2 [id F]
 └─ Log1p [id B]
    └─ ···
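The bookkeeping that replace spares us can be sketched with a toy class. ToyFunctionGraph is hypothetical and much simpler than PyTensor's FunctionGraph (it skips cloning, validation and pruning), but it shows why clients must be kept in sync with the node inputs:

```python
from collections import defaultdict


class Node:
    """Hypothetical stand-in for an Apply node whose single output is the node itself."""

    def __init__(self, op, *inputs):
        self.op, self.inputs = op, list(inputs)


class ToyFunctionGraph:
    """Toy sketch of the bookkeeping behind a `replace` call (not PyTensor's implementation)."""

    def __init__(self, outputs):
        self.outputs = list(outputs)
        self.nodes = set()
        self.clients = defaultdict(list)
        for out in self.outputs:
            self._import(out)

    def _import(self, var):
        # Register every node reachable from `var` and record where each input is used.
        if not isinstance(var, Node) or var in self.nodes:
            return
        self.nodes.add(var)
        for idx, inp in enumerate(var.inputs):
            self.clients[inp].append((var, idx))
            self._import(inp)

    def replace(self, old, new):
        self._import(new)  # nodes introduced by the replacement must be tracked too
        # Move every use of `old` over to `new`, keeping `clients` in sync.
        for node, idx in self.clients.pop(old, []):
            node.inputs[idx] = new
            self.clients[new].append((node, idx))
        # Graph outputs count as uses as well.
        self.outputs = [new if out is old else out for out in self.outputs]
        # (A fuller version would also drop nodes that are no longer reachable.)


y = Node("log", Node("add", "one", "x"))
out1, out2 = Node("mul", y, "two"), Node("true_div", "two", y)
fg = ToyFunctionGraph([out1, out2])
new_y = Node("log1p", "x")
fg.replace(y, new_y)
print(out1.inputs[0] is new_y, out2.inputs[1] is new_y)  # True True
```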

There is still one big limitation with this approach. We have to know in advance “where” the variable we want to replace is present. It also doesn’t scale to multiple instances of the same pattern.

A more sensible approach would be to iterate over the nodes in the FunctionGraph and apply the rewrite wherever log(1 + x) may be present.

To keep code organized we will create a function that takes as input a node and returns a valid replacement.

from pytensor.graph import Constant

def local_log1p(node):
    # Check that this node is a Log op
    if node.op != pt.log:
        return None
    
    # Check that the input is the output of another node (it could be a root input variable)
    add_node = node.inputs[0].owner
    if add_node is None:
        return None
    
    # Check that the input to this node is an Add op
    # with 2 inputs (Add can have more inputs)
    if add_node.op != pt.add or len(add_node.inputs) != 2:
        return None
    
    # Check whether we have add(1, y) or add(x, 1)
    [x, y] = add_node.inputs
    if isinstance(x, Constant) and x.data == 1:
        return [pt.log1p(y)]
    if isinstance(y, Constant) and y.data == 1:
        return [pt.log1p(x)]

    return None
# We no longer need the memo, because our rewrite works with the node information
fg = FunctionGraph([x], [out1, out2], clone=True)

# Toposort gives a list of all nodes in a graph in topological order
# The strategy of iteration can be important when we are dealing with multiple rewrites
for node in fg.toposort():
    repl = local_log1p(node)
    if repl is None:
        continue
    # We should get one replacement of each output of the node
    assert len(repl) == len(node.outputs)
    # We could use `fg.replace_all` to avoid this loop
    for old, new in zip(node.outputs, repl):
        fg.replace(old, new)

pytensor.dprint(fg);
Mul [id A] 1
 ├─ Log1p [id B] 0
 │  └─ x [id C]
 └─ 2 [id D]
True_div [id E] 2
 ├─ 2 [id F]
 └─ Log1p [id B] 0
    └─ ···

This is starting to look much more scalable!
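For intuition on what toposort returns, here is a toy depth-first version over hypothetical Node objects: each node is emitted only after every node that produces one of its inputs. This is a sketch of the idea, not PyTensor's implementation, and its recursion limits it to small graphs:

```python
class Node:
    """Hypothetical graph node: an op name plus a list of inputs."""

    def __init__(self, op, *inputs):
        self.op, self.inputs = op, list(inputs)


def toposort(outputs):
    """Return nodes so that each one comes after the nodes producing its inputs."""
    order, visited = [], set()

    def visit(var):
        if not isinstance(var, Node) or var in visited:
            return  # root inputs are not nodes; shared nodes are emitted only once
        visited.add(var)
        for inp in var.inputs:
            visit(inp)  # emit all producers first...
        order.append(var)  # ...then the node itself

    for out in outputs:
        visit(out)
    return order


add = Node("add", 1, "x")
log = Node("log", add)
out1, out2 = Node("mul", log, 2), Node("true_div", 2, log)
print([node.op for node in toposort([out1, out2])])  # ['add', 'log', 'mul', 'true_div']
```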

We are still reinventing many wheels that already exist in PyTensor, but we’re getting there. Before we move up the ladder of abstraction, let’s discuss two gotchas:

  1. The replacement variables should have types that are compatible with the original ones.

  2. We have to be careful about introducing circular dependencies.

For point 1, let’s look at a simple graph simplification, where we replace a costly operation that is ultimately multiplied by zero.

x = pt.vector("x", dtype="float32")
zero = pt.zeros(())
zero.name = "zero"
y = pt.exp(x) * zero

fg = FunctionGraph([x], [y], clone=False)
try:
    fg.replace(y, pt.zeros(()))
except TypeError as exc:
    print(f"TypeError: {exc}")
TypeError: Cannot convert Type Scalar(float64, shape=()) (of Variable Alloc.0) into Type Vector(float64, shape=(?,)). You can try to manually convert Alloc.0 into a Vector(float64, shape=(?,)).

The first achievement of a new PyTensor developer is unlocked by stumbling upon an error like that!

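It’s important to keep the Tensor part of PyTensor in mind: a variable’s type includes both its dtype and its number of dimensions.

The problem here is that we are trying to replace y, which is a float64 vector (the float32 exp(x) is upcast when multiplied by the float64 zero), with a new float64 scalar!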

pytensor.dprint(fg.outputs, id_type="", print_type=True);
Mul <Vector(float64, shape=(?,))>
 ├─ Exp <Vector(float32, shape=(?,))>
 │  └─ x <Vector(float32, shape=(?,))>
 └─ ExpandDims{axis=0} <Vector(float64, shape=(1,))>
    └─ Alloc <Scalar(float64, shape=())> 'zero'
       └─ 0.0 <Scalar(float64, shape=())>
vector_zero = pt.zeros(x.shape)
vector_zero.name = "vector_zero"
fg.replace(y, vector_zero)
pytensor.dprint(fg.outputs, id_type="", print_type=True);
Alloc <Vector(float64, shape=(?,))> 'vector_zero'
 ├─ 0.0 <Scalar(float64, shape=())>
 └─ Subtensor{i} <Scalar(int64, shape=())>
    ├─ Shape <Vector(int64, shape=(1,))>
    │  └─ x <Vector(float32, shape=(?,))>
    └─ 0 <int64>
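The first gotcha can be distilled into a toy type check. ToyTensorType and check_replacement are hypothetical and use an exact-equality rule; PyTensor's real compatibility check is more involved, so treat this only as a mental model of why a scalar cannot stand in for a vector:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTensorType:
    """Hypothetical tensor type: a dtype and a number of dimensions (no static shape)."""
    dtype: str
    ndim: int


def check_replacement(old_type, new_type):
    """Toy rule: a replacement must have exactly the same type as the variable it replaces."""
    if old_type != new_type:
        raise TypeError(f"cannot replace a variable of type {old_type} with one of type {new_type}")


y_type = ToyTensorType("float64", 1)  # like `y = pt.exp(x) * zero` above
for candidate in (ToyTensorType("float64", 0), ToyTensorType("float64", 1)):
    try:
        check_replacement(y_type, candidate)
        print("ok:", candidate)
    except TypeError as exc:
        print("TypeError:", exc)
```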

Now to the second (less common) gotcha: introducing circular dependencies.

x = pt.scalar("x")
y = x + 1
y.name = "y"
z = y + 1
z.name = "z"

fg = FunctionGraph([x], [z], clone=False)
fg.replace(x, z)
pytensor.dprint(fg.outputs);
Add [id A] 'z'
 ├─ Add [id B] 'y'
 │  ├─ Add [id A] 'z'
 │  │  └─ ···
 │  └─ 1 [id C]
 └─ 1 [id D]

Oops! There is not much to say about this one, other than don’t do it!
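One way to guard against the second gotcha is to refuse any replacement whose new variable is itself computed from the variable being replaced. The sketch below uses hypothetical toy nodes and helper names; we are not claiming this is how FunctionGraph validates replacements:

```python
class Node:
    """Hypothetical graph node; a node without inputs plays the role of a root variable."""

    def __init__(self, label, *inputs):
        self.label, self.inputs = label, list(inputs)


def depends_on(var, target):
    """True if `target` is `var` or one of its ancestors (plain recursion, fine for small graphs)."""
    return var is target or any(depends_on(inp, target) for inp in var.inputs)


def check_no_cycle(old, new):
    """Refuse to replace `old` with a variable that is computed from `old`."""
    if new is not old and depends_on(new, old):
        raise ValueError(f"replacing {old.label} with {new.label} would create a cycle")


x = Node("x")
y = Node("y", x, Node("1"))
z = Node("z", y, Node("1"))

try:
    check_no_cycle(x, z)  # the replacement attempted above
except ValueError as exc:
    print("ValueError:", exc)  # ValueError: replacing x with z would create a cycle

check_no_cycle(y, Node("log1p", x))  # fine: the new variable does not use y
```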

Using graph rewriters

from pytensor.graph.rewriting.basic import NodeRewriter

class LocalLog1pNodeRewriter(NodeRewriter):
        
    def tracks(self):
        return [pt.log]
    
    def transform(self, fgraph, node):
        return local_log1p(node)    
    
    def __str__(self):
        return "local_log1p"
    
    
local_log1p_node_rewriter = LocalLog1pNodeRewriter()

A NodeRewriter only needs to implement the transform method. As before, this method receives a node and should return either a list with one valid replacement per node output, or None.

We also receive the FunctionGraph object, as some node rewriters may want to use global information to decide whether to return a replacement or not.

For example, some rewrites that skip intermediate computations may not be worthwhile if those intermediate computations are also used by other variables.

The optional tracks method is very useful for filtering out “useless” rewrites. When a NodeRewriter only applies to a specific, rare Op, it can be skipped entirely when that Op is not present in the graph.
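To make the benefit of tracks concrete, here is a toy dispatcher that indexes hypothetical rewriters by the ops they track, so each node is only shown the rewriters that could possibly apply. Ops are plain strings here, whereas PyTensor's tracks returns Ops, and this is not a description of how PyTensor's graph rewriters use tracks internally:

```python
from collections import defaultdict


class Node:
    """Hypothetical graph node: an op name plus a list of inputs."""

    def __init__(self, op, *inputs):
        self.op, self.inputs = op, list(inputs)


class ToyNodeRewriter:
    """Hypothetical node rewriter: `tracks` lists the ops it applies to."""

    def __init__(self, name, tracks, transform):
        self.name, self._tracks, self.transform = name, list(tracks), transform

    def tracks(self):
        return self._tracks


def rewriters_by_op(rewriters):
    """Index rewriters by tracked op, so each node only meets relevant rewriters."""
    index = defaultdict(list)
    for rewriter in rewriters:
        for op in rewriter.tracks():
            index[op].append(rewriter)
    return index


calls = []  # records every transform call that actually happens


def make_transform(name):
    def transform(node):
        calls.append((name, node.op))
        return None  # these toy rewriters never return a replacement
    return transform


rewriters = [
    ToyNodeRewriter("log_rewrite", ["log"], make_transform("log_rewrite")),
    ToyNodeRewriter("rare_rewrite", ["cholesky"], make_transform("rare_rewrite")),
]
index = rewriters_by_op(rewriters)

add = Node("add", 1, "x")
log = Node("log", add)
mul = Node("mul", log, 2)
for node in (add, log, mul):
    for rewriter in index.get(node.op, []):  # ops nobody tracks are skipped outright
        rewriter.transform(node)

print(calls)  # [('log_rewrite', 'log')]
```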

On its own, a NodeRewriter isn’t any better than what we had before. Where it becomes useful is when included inside a GraphRewriter, which will apply it to a whole FunctionGraph.

from pytensor.graph.rewriting.basic import in2out

x = pt.scalar("x")
y = pt.log(1 + x)
out = pt.exp(y)

fg = FunctionGraph([x], [out])
in2out(local_log1p_node_rewriter, name="local_log1p").rewrite(fg)

pytensor.dprint(fg.outputs);
Exp [id A]
 └─ Log1p [id B]
    └─ x [id C]

Here we used in2out(), which creates a GraphRewriter (specifically a WalkingGraphRewriter) that walks from the inputs to the outputs of a FunctionGraph, trying to apply whatever node rewriters are “registered” in it.
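The inputs-to-outputs walk can be sketched in plain Python. rewrite_in2out and the toy Node class are hypothetical; unlike PyTensor's WalkingGraphRewriter, which edits a FunctionGraph through replace, this sketch builds a rewritten copy and tries each rewriter at most once per node:

```python
class Node:
    """Hypothetical graph node; leaves are plain strings or numbers."""

    def __init__(self, op, *inputs):
        self.op, self.inputs = op, list(inputs)


def show(v):
    if isinstance(v, Node):
        return f"{v.op}({', '.join(show(i) for i in v.inputs)})"
    return str(v)


def local_log1p(node):
    """log(add(1, a)) or log(add(a, 1)) -> log1p(a)"""
    inner = node.inputs[0] if node.op == "log" else None
    if isinstance(inner, Node) and inner.op == "add" and len(inner.inputs) == 2:
        a, b = inner.inputs
        if a == 1:
            return Node("log1p", b)
        if b == 1:
            return Node("log1p", a)
    return None


def local_useless_abs_exp(node):
    """abs(exp(a)) -> exp(a)"""
    if node.op == "abs" and isinstance(node.inputs[0], Node) and node.inputs[0].op == "exp":
        return node.inputs[0]
    return None


def rewrite_in2out(outputs, rewriters):
    """Rebuild the graph from inputs to outputs, trying each rewriter once per node."""
    memo = {}  # original node -> rewritten variable, so shared nodes are handled once

    def visit(var):
        if not isinstance(var, Node):
            return var
        if var not in memo:
            new = Node(var.op, *(visit(inp) for inp in var.inputs))  # inputs are rewritten first
            for rewriter in rewriters:
                repl = rewriter(new)
                if repl is not None:
                    new = repl
                    break
            memo[var] = new
        return memo[var]

    return [visit(out) for out in outputs]


w = Node("log", Node("add", 1.0, Node("abs", Node("exp", "x"))))
[new_w] = rewrite_in2out([w], [local_log1p, local_useless_abs_exp])
print(show(w))      # log(add(1.0, abs(exp(x))))
print(show(new_w))  # log1p(exp(x))
```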

Wrapping simple functions in NodeRewriters is so common that PyTensor offers a decorator for it.

Let’s create a new rewrite that removes the useless abs in abs(exp(x)) -> exp(x).

from pytensor.graph.rewriting.basic import node_rewriter

@node_rewriter(tracks=[pt.abs])
def local_useless_abs_exp(fgraph, node):
    # Because of the tracks, we don't need to check
    # that `node` has an `Abs` Op.
    # We still need to check whether its input comes from an `Exp` Op
    exp_node = node.inputs[0].owner
    if exp_node is None or exp_node.op != pt.exp:
        return None
    return exp_node.outputs

Another very useful helper is the PatternNodeRewriter, which allows you to specify a rewrite via “template matching”.

from pytensor.graph.rewriting.basic import PatternNodeRewriter

local_useless_abs_square = PatternNodeRewriter(
    (pt.abs, (pt.pow, "x", 2)),
    (pt.pow, "x", 2),
    name="local_useless_abs_square",
)

This is very useful for simple Elemwise rewrites, but becomes a bit cumbersome with Ops that must be parametrized every time they are used.
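The template matching behind a rewrite like local_useless_abs_square can be sketched with nested tuples, where the first element is the op, any other string is a placeholder, and anything else is a literal constant. The match, substitute and pattern_rewrite helpers are hypothetical, not PyTensor's implementation:

```python
def match(pattern, expr, bindings):
    """Try to match `expr` against `pattern`, filling `bindings` for placeholder strings."""
    if isinstance(pattern, str):  # a placeholder such as "x"
        if pattern in bindings:
            return bindings[pattern] == expr  # a repeated placeholder must match the same subexpression
        bindings[pattern] = expr
        return True
    if isinstance(pattern, tuple):  # (op, *args): the op must match exactly
        return (
            isinstance(expr, tuple)
            and len(expr) == len(pattern)
            and expr[0] == pattern[0]
            and all(match(p, e, bindings) for p, e in zip(pattern[1:], expr[1:]))
        )
    return pattern == expr  # a literal constant such as 2


def substitute(template, bindings):
    """Build the output expression by plugging the bindings into the template."""
    if isinstance(template, str):
        return bindings[template]
    if isinstance(template, tuple):
        return (template[0],) + tuple(substitute(t, bindings) for t in template[1:])
    return template


def pattern_rewrite(in_pattern, out_pattern, expr):
    bindings = {}
    return substitute(out_pattern, bindings) if match(in_pattern, expr, bindings) else None


in_pattern = ("abs", ("pow", "x", 2))
out_pattern = ("pow", "x", 2)
print(pattern_rewrite(in_pattern, out_pattern, ("abs", ("pow", ("log", "w"), 2))))  # ('pow', ('log', 'w'), 2)
print(pattern_rewrite(in_pattern, out_pattern, ("abs", ("pow", "w", 3))))           # None
```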

x = pt.scalar("x")
y = pt.exp(x)
z = pt.abs(y)
w = pt.log(1.0 + z)
out = pt.abs(w ** 2)

fg = FunctionGraph([x], [out])
in2out_rewrite = in2out(
    local_log1p_node_rewriter, 
    local_useless_abs_exp, 
    local_useless_abs_square,
    name="custom_rewrites"
)
in2out_rewrite.rewrite(fg)

pytensor.dprint(fg.outputs);
Pow [id A]
 ├─ Log1p [id B]
 │  └─ Exp [id C]
 │     └─ x [id D]
 └─ 2 [id E]

Besides WalkingGraphRewriters, there are:

  • SequentialGraphRewriters, which apply a set of GraphRewriters sequentially

  • EquilibriumGraphRewriters, which apply a set of GraphRewriters (and NodeRewriters) repeatedly until the graph stops changing.
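Why repeat until nothing changes? In the toy example below, one rewrite creates a pattern that a single bottom-up pass never revisits, so another pass is needed. All names are hypothetical and the expressions are nested tuples rather than PyTensor graphs:

```python
def rewrite_once(expr, rewriters):
    """One bottom-up pass: rewrite the arguments first, then try each rewriter once on the node."""
    if not isinstance(expr, tuple):
        return expr  # a leaf such as "x"
    expr = (expr[0],) + tuple(rewrite_once(arg, rewriters) for arg in expr[1:])
    for rewriter in rewriters:
        new = rewriter(expr)
        if new is not None:
            return new  # the result is NOT revisited during this pass
    return expr


def rewrite_to_equilibrium(expr, rewriters, max_passes=100):
    """Repeat full passes until one of them leaves the expression unchanged."""
    for n_passes in range(1, max_passes + 1):
        new = rewrite_once(expr, rewriters)
        if new == expr:
            return expr, n_passes  # n_passes includes the final, no-op pass
        expr = new
    raise RuntimeError("no fixed point reached")


def distribute_neg(expr):
    """neg(add(a, b, ...)) -> add(neg(a), neg(b), ...)"""
    if expr[0] == "neg" and isinstance(expr[1], tuple) and expr[1][0] == "add":
        return ("add",) + tuple(("neg", arg) for arg in expr[1][1:])
    return None


def cancel_double_neg(expr):
    """neg(neg(a)) -> a"""
    if expr[0] == "neg" and isinstance(expr[1], tuple) and expr[1][0] == "neg":
        return expr[1][1]
    return None


rewriters = [distribute_neg, cancel_double_neg]
expr = ("neg", ("add", ("neg", "x"), "y"))
print(rewrite_once(expr, rewriters))            # ('add', ('neg', ('neg', 'x')), ('neg', 'y'))
print(rewrite_to_equilibrium(expr, rewriters))  # (('add', 'x', ('neg', 'y')), 3)
```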

Registering graph rewriters in a database

Finally, at the top of the rewrite mountain, there are RewriteDatabases! These allow “querying” for subsets of rewrites registered in a database.

Most users trigger this implicitly through the mode of a PyTensor function: mode="FAST_COMPILE", mode="FAST_RUN", or mode="JAX" each lead to a different rewrite database query being applied to the function before compilation.

The most relevant RewriteDatabase is called optdb and contains all the standard rewrites in PyTensor. You can manually register your GraphRewriter in it.

More often than not, you will want to register your rewrite in a pre-existing sub-database, like canonicalize, stabilize, or specialize.

from pytensor.compile.mode import optdb
optdb["canonicalize"].register(
    "local_log1p_node_rewriter",
    local_log1p_node_rewriter,
    "fast_compile",
    "fast_run",
    "custom",
)
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], out, mode="FAST_COMPILE")
    
print("")
pytensor.dprint(fn);
rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)

Abs [id A] 4
 └─ Pow [id B] 3
    ├─ Log1p [id C] 2
    │  └─ Abs [id D] 1
    │     └─ Exp [id E] 0
    │        └─ x [id F]
    └─ 2 [id G]

There’s also a decorator, register_canonicalize(), that automatically registers a NodeRewriter in one of these standard databases. (It lives in a somewhat unexpected module, pytensor.tensor.rewriting.basic.)

from pytensor.tensor.rewriting.basic import register_canonicalize

@register_canonicalize("custom")
@node_rewriter(tracks=[pt.abs])
def local_useless_abs_exp(fgraph, node):
    # Because of the tracks, we don't need to check
    # that `node` has an `Abs` Op.
    # We still need to check whether its input comes from an `Exp` Op
    exp_node = node.inputs[0].owner
    if exp_node is None or exp_node.op != pt.exp:
        return None
    return exp_node.outputs

You can also call the decorator directly on an existing rewriter:

register_canonicalize(local_useless_abs_square, "custom")
local_useless_abs_square
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], out, mode="FAST_COMPILE")
    
print("")
pytensor.dprint(fn);
rewriting: rewrite local_useless_abs_square replaces Abs.0 of Abs(Pow.0) with Pow.0 of Pow(Log.0, 2)
rewriting: rewrite local_log1p replaces Log.0 of Log(Add.0) with Log1p.0 of Log1p(Abs.0)
rewriting: rewrite local_useless_abs_exp replaces Abs.0 of Abs(Exp.0) with Exp.0 of Exp(x)

Pow [id A] 2
 ├─ Log1p [id B] 1
 │  └─ Exp [id C] 0
 │     └─ x [id D]
 └─ 2 [id E]
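The fact that register_canonicalize works both as a decorator and as a plain function call reflects a common Python pattern. Below is a minimal sketch with a hypothetical register helper and REGISTRY dictionary; we make no claim that it mirrors register_canonicalize's internals:

```python
REGISTRY = {}  # hypothetical registry: rewrite name -> set of tags


def register(*args):
    """Toy helper usable as @register("tag", ...) or as register(rewriter, "tag", ...)."""
    if args and callable(args[0]):
        # Direct call: the first argument is the rewriter itself.
        rewriter, *tags = args
        REGISTRY[rewriter.__name__] = set(tags)
        return rewriter
    tags = args

    def decorator(rewriter):
        REGISTRY[rewriter.__name__] = set(tags)
        return rewriter  # returned unchanged, so the function stays usable

    return decorator


@register("custom")
def abs_exp_rewrite(node):
    return None


def abs_square_rewrite(node):
    return None


register(abs_square_rewrite, "custom")
print(REGISTRY)  # {'abs_exp_rewrite': {'custom'}, 'abs_square_rewrite': {'custom'}}
```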

If you want to exclude your custom rewrites, you can do it like this:

from pytensor.compile.mode import get_mode

with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], out, mode=get_mode("FAST_COMPILE").excluding("custom"))
    
print("")
pytensor.dprint(fn);
rewriting: rewrite local_upcast_elemwise_constant_inputs replaces Add.0 of Add(1.0, Abs.0) with Add.0 of Add(Cast{float64}.0, Abs.0)
rewriting: rewrite constant_folding replaces Cast{float64}.0 of Cast{float64}(1.0) with 1.0 of None

Abs [id A] 5
 └─ Pow [id B] 4
    ├─ Log [id C] 3
    │  └─ Add [id D] 2
    │     ├─ 1.0 [id E]
    │     └─ Abs [id F] 1
    │        └─ Exp [id G] 0
    │           └─ x [id H]
    └─ 2 [id I]
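Conceptually, including and excluding tags filters the set of rewrites that a mode applies. ToyRewriteDB below is a hypothetical simplification of that idea (the rewrite names and tags are made up), not PyTensor's RewriteDatabase or its exact query semantics:

```python
class ToyRewriteDB:
    """Toy, tag-based rewrite database (not PyTensor's RewriteDatabase)."""

    def __init__(self):
        self._tags = {}  # rewrite name -> set of tags; dicts keep registration order

    def register(self, name, *tags):
        self._tags[name] = set(tags)

    def query(self, include, exclude=()):
        """Names of rewrites that have at least one `include` tag and no `exclude` tag."""
        include, exclude = set(include), set(exclude)
        return [
            name
            for name, tags in self._tags.items()
            if tags & include and not tags & exclude
        ]


db = ToyRewriteDB()
db.register("my_log1p", "fast_compile", "fast_run", "custom")
db.register("my_abs_exp", "fast_run", "custom")
db.register("builtin_cleanup", "fast_compile", "fast_run")

print(db.query(include=["fast_compile"]))                     # ['my_log1p', 'builtin_cleanup']
print(db.query(include=["fast_compile"], exclude=["custom"]))  # ['builtin_cleanup']
```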

Authors

  • Authored by Ricardo Vieira in May 2023

Watermark

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor
Last updated: Sat Jan 11 2025

Python implementation: CPython
Python version       : 3.12.0
IPython version      : 8.31.0

pytensor: 2.26.4+16.g8be5c5323.dirty

sys     : 3.12.0 | packaged by conda-forge | (main, Oct  3 2023, 08:43:22) [GCC 12.3.0]
pytensor: 2.26.4+16.g8be5c5323.dirty

Watermark: 2.5.0

License notice

All the notebooks in this example gallery are provided under a 3-Clause BSD License, which allows modification and redistribution for any use, provided the copyright and license notices are preserved.

Citing Pytensor Examples

To cite this notebook, please use the suggested citation below.

Important

Many notebooks are adapted from other sources: blogs, books… In such cases you should cite the original source as well.

Also remember to cite the relevant libraries used by your code.

Here is an example citation template in bibtex:

@incollection{citekey,
  author    = "<notebook authors, see above>",
  title     = "<notebook title>",
  editor    = "Pytensor Team",
  booktitle = "Pytensor Examples",
}

which once rendered could look like:

Ricardo Vieira. "PyTensor graph rewrites from scratch". In: Pytensor Examples. Ed. by Pytensor Team.