mode – controlling compilation
Guide
The mode parameter to pytensor.function() controls how the inputs-to-outputs graph is transformed into a callable object.
PyTensor defines the following modes by name:
- 'FAST_COMPILE': Apply just a few graph rewrites and only use Python implementations.
- 'FAST_RUN': Apply all rewrites, and use C implementations where possible.
- 'NUMBA': Apply all relevant rewrites and compile the whole graph using Numba.
- 'JAX': Apply all relevant rewrites and compile the whole graph using JAX.
- 'PYTORCH': Apply all relevant rewrites and compile the whole graph using PyTorch compile.
- 'DebugMode': A mode for debugging. See DebugMode for details.
- 'NanGuardMode': NaN detector.
- 'DEBUG_MODE': Deprecated. Use the string DebugMode.
The default mode is typically FAST_RUN, but it can be controlled via the configuration variable config.mode, which in turn can be overridden for a single function by passing the mode keyword argument to pytensor.function().
For Numba, JAX, and PyTorch, we exclude rewrites that introduce C-only Ops, as well as BLAS optimizations, as those are done automatically by the respective backends.
For JAX we also exclude fusion and inplace optimizations, as JAX does not support them at the user level. They are performed automatically by JAX.
Todo
For a finer level of control over which rewrites are applied, and whether C or Python implementations are used, read…. what exactly?
Reference
- class pytensor.compile.mode.Mode(object)
  Compilation is controlled by two attributes: the optimizer controls how an expression graph will be transformed; the linker controls how the rewritten expression graph will be evaluated.
  - including(*tags)
    Return a new Mode instance like this one, but with its optimizer modified by including the given tags.