(beta) Utilizing Torch Function modes with torch.compile¶
Author: Michael Lazos
This recipe covers how to use a key torch extensibility point, torch function modes, in tandem with torch.compile to override the behavior of torch operators (also known as ops) at trace time, with no runtime overhead.
Note
This recipe requires PyTorch 2.7.0 or later.
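As background, a torch function mode intercepts every call into the torch API while it is active. The minimal eager-mode sketch below (the LogEveryOp name is our own, for illustration only) shows this interception. Note that without torch.compile, the ``__torch_function__`` hook runs on every single op call; that per-call dispatch is exactly the runtime overhead this recipe avoids by applying the mode at trace time.

import torch
from torch.overrides import TorchFunctionMode

class LogEveryOp(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Runs on every torch API call made while the mode is active.
        print(f"intercepted: {func}")
        kwargs = kwargs or {}
        return func(*args, **kwargs)

with LogEveryOp():
    # every op here is intercepted: both ``torch.ones`` calls and ``torch.add``
    torch.add(torch.ones(2), torch.ones(2))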
Rewriting a torch op (torch.add -> torch.mul)¶
For this example, we’ll use torch function modes to rewrite occurrences of addition as multiplication instead. This type of override can be common if a certain backend has a custom implementation that should be dispatched for a given op.
import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
# (check ``is_available`` first so this doesn't raise on CPU-only machines)
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)
from torch.overrides import BaseTorchFunctionMode

# Define our mode. Note: ``BaseTorchFunctionMode``
# implements the actual invocation of ``func(..)`` for us.
class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.add:
            func = torch.mul

        return super().__torch_function__(func, types, args, kwargs)
@torch.compile()
def test_fn(x, y):
    return x + y * x  # Note: infix operators map to torch.Tensor.* methods


x = torch.rand(2, 2)
y = torch.rand_like(x)

with AddToMultiplyMode():
    z = test_fn(x, y)

assert torch.allclose(z, x * y * x)
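The same mode also works in plain eager mode; it just pays the ``__torch_function__`` dispatch cost on every call. As a quick sanity check (our own addition, not part of the original recipe), calling the method form directly makes ``func`` compare equal to ``torch.Tensor.add``:

# Sanity check (not in the original recipe): the rewrite also applies in
# eager mode, where ``__torch_function__`` runs on each call.
with AddToMultiplyMode():
    z_eager = torch.Tensor.add(x, y)  # dispatched with func == torch.Tensor.add

assert torch.allclose(z_eager, x * y)  # add was rewritten to mul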
# The mode can also be used inside the compiled region, like this:
@torch.compile()
def test_fn(x, y):
    with AddToMultiplyMode():
        return x + y * x  # Note: infix operators map to torch.Tensor.* methods


x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)
assert torch.allclose(z, x * y * x)
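Rewrites do not have to swap one op for another wholesale; a mode can also reroute an op to an alternate implementation, as in the backend scenario mentioned above. The sketch below is a hypothetical illustration (FastGeluMode and gelu_fn are our own names, not from the recipe): it forces torch.nn.functional.gelu to use its tanh approximation.

import torch.nn.functional as F

# Hypothetical illustration (not from the original recipe): force ``F.gelu``
# to use the tanh approximation, e.g. for a backend that only implements
# that variant.
class FastGeluMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func is F.gelu:
            # overrides any ``approximate=`` the caller passed
            kwargs = dict(kwargs or {}, approximate="tanh")
        return super().__torch_function__(func, types, args, kwargs)


@torch.compile()
def gelu_fn(x):
    return F.gelu(x)


x = torch.rand(2, 2)
with FastGeluMode():
    out = gelu_fn(x)

assert torch.allclose(out, F.gelu(x, approximate="tanh"))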
Conclusion¶
In this recipe we demonstrated how to override the behavior of torch.* operators using torch function modes from within torch.compile. This enables users to utilize the extensibility benefits of torch function modes without the runtime overhead of calling ``__torch_function__`` on every op invocation.
See Extending Torch API with Modes for other examples and background on Torch Function modes.