
(beta) Utilizing Torch Function modes with torch.compile

Author: Michael Lazos

This recipe covers how to use a key torch extensibility point, torch function modes, in tandem with torch.compile to override the behavior of torch operators (also known as ops) at trace time, with no runtime overhead.

Note

This recipe requires PyTorch 2.7.0 or later.

Rewriting a torch op (torch.add -> torch.mul)

For this example, we’ll use torch function modes to rewrite occurrences of addition with multiplication. This type of override is common when a certain backend has a custom implementation that should be dispatched for a given op.

import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
# (``get_device_capability`` raises if no CUDA device is present, so check
# availability first)
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Exiting because torch.compile is not supported on this device.")
    import sys
    sys.exit(0)

from torch.overrides import BaseTorchFunctionMode

# Define our mode. Note: ``BaseTorchFunctionMode``
# implements the actual invocation of ``func(...)``
class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.add:
            func = torch.mul

        return super().__torch_function__(func, types, args, kwargs)

@torch.compile()
def test_fn(x, y):
    return x + y * x # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)

with AddToMultiplyMode():
    z = test_fn(x, y)

assert torch.allclose(z, x * y * x)
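Torch function modes are a general extensibility point, so the same mode also works in eager mode without torch.compile (the trade-off is that eager incurs the per-op dispatch overhead that torch.compile removes). Here is a minimal, self-contained sketch using the explicit method-call form, which dispatches ``torch.Tensor.add`` just like the infix ``+``:

```python
import torch
from torch.overrides import BaseTorchFunctionMode

# Same mode as above, repeated so this sketch runs on its own
class AddToMultiplyMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func == torch.Tensor.add:
            func = torch.mul
        return super().__torch_function__(func, types, args, kwargs)

x = torch.full((2, 2), 3.0)
y = torch.full((2, 2), 2.0)

with AddToMultiplyMode():
    z = x.add(y)  # dispatches torch.Tensor.add, rewritten to torch.mul

assert torch.allclose(z, torch.full((2, 2), 6.0))  # 3 * 2, not 3 + 2
```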

# The mode can also be used inside the compiled region, like this:

@torch.compile()
def test_fn(x, y):
    with AddToMultiplyMode():
        return x + y * x # Note: infix operators map to torch.Tensor.* methods

x = torch.rand(2, 2)
y = torch.rand_like(x)
z = test_fn(x, y)

assert torch.allclose(z, x * y * x)
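The check on ``func`` determines exactly which calls are rewritten: the functional spelling ``torch.add(x, y)`` dispatches ``torch.add`` rather than the ``torch.Tensor.add`` method, so a mode meant to cover both spellings can match against either. A sketch of that variant (the two-entry check and the class name ``AddToMultiplyBothForms`` are illustrative assumptions, not part of the recipe above):

```python
import torch
from torch.overrides import BaseTorchFunctionMode

class AddToMultiplyBothForms(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Match both the functional and the method spelling of addition
        if func in (torch.add, torch.Tensor.add):
            func = torch.mul
        return super().__torch_function__(func, types, args, kwargs)

x = torch.full((2, 2), 4.0)
y = torch.full((2, 2), 2.0)

with AddToMultiplyBothForms():
    a = torch.add(x, y)  # functional form
    b = x.add(y)         # method form

assert torch.allclose(a, torch.full((2, 2), 8.0))
assert torch.allclose(b, torch.full((2, 2), 8.0))
```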

Conclusion

In this recipe we demonstrated how to override the behavior of torch.* operators using torch function modes from within torch.compile. This enables users to utilize the extensibility benefits of torch function modes without the runtime overhead of calling torch function on every op invocation.

