Background
ATen device capture is an attempt to produce MLIR by tracing a PyTorch program running on CPU. Its very first version, contributed by @stephenneuendorffer from Xilinx, was largely modeled after torch/xla.
Old API (pseudo-device, similar to torch/xla):
import npcomp.frontends.pytorch as torch_mlir
dev = torch_mlir.mlir_device()
t0 = torch.randn((4,4), device=dev)
t1 = torch.randn((4,4)).to(dev)
t2 = t0 + t1
t2_mlir = torch_mlir.get_mlir(t2)
t2_cpu = t2.to('cpu')
The device API is still recommended in the official tutorial for adding backend support.
@stellaraccident refactored the code to utilize the c10 dispatcher (greatly reducing the LOC from ~10k to <1k). Stella also changed the user-facing API from a pseudo device to a Python context manager associated with a local dispatch key.
New API:
import torch
import torch_mlir
lhs = torch.rand(2, 3)
rhs = torch.rand(3, 4)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("mm", [lhs, rhs]) as f:
    result = torch.mm(lhs, rhs)
    f.returns([result])
mb.module.operation.print()
Output:
module {
  func @mm(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[3,4]:f32>) -> !numpy.ndarray<[2,4]:f32> {
    %0 = torch.kernel_call "aten::mm" %arg0, %arg1 : (!numpy.ndarray<[2,3]:f32>, !numpy.ndarray<[3,4]:f32>) -> !numpy.ndarray<[2,4]:f32> {sigArgTypes = ["Tensor", "Tensor"], sigIsMutable = false, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
    return %0 : !numpy.ndarray<[2,4]:f32>
  }
}

More examples can be found here: https://github.com/llvm/mlir-npcomp/tree/main/frontends/pytorch/test/acap_export
Under the hood of the new API, the c10 dispatcher picks up a local dispatch key associated with the mb.capture_function context manager, and a backend fallback function intercepts each boxed kernel call on the torch::jit::Stack and produces generic MLIR (a 1:1 mapping to the boxed kernel call). It then re-dispatches to the CPU backend to obtain the shapes and dtypes of the ATen tensors.
That said, this use of the c10 dispatcher has some caveats, and we suspect it is not really how the mode-based backend fallback mechanism is meant to be used, notably around convolution, copy_, and factory functions like arange. For us it also records shapes and dtypes too early (more on that later in the context of dynamic shape support), but this could be altered.
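For readers less familiar with the dispatcher, below is a minimal C++ sketch of a mode-based boxed backend fallback in the spirit described above. It is not the actual acap code: the PrivateUse1 key, the capture_kernel name, and the omitted MLIR emission are placeholders.

#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <torch/library.h>

// Boxed fallback: invoked for every kernel call made while the capture key
// is active, with the arguments already pushed onto the JIT stack.
void capture_kernel(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // 1. Emit a generic MLIR op (1:1 with the boxed call) from op.schema()
  //    and the IValues currently on *stack (omitted in this sketch).
  // 2. Re-dispatch with the capture key excluded so the real CPU kernel
  //    runs and the result tensors carry concrete shapes and dtypes.
  c10::impl::ExcludeDispatchKeyGuard no_capture(
      c10::DispatchKeySet(c10::DispatchKey::PrivateUse1));
  op.callBoxed(stack);
}

// Register the boxed function as the fallback for every operator under the
// (placeholder) PrivateUse1 key; acap uses its own key, toggled on and off
// by the mb.capture_function context manager.
TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&capture_kernel>());
}

The ops called out above (convolution, copy_, factory functions such as arange) are precisely the cases where this simple picture breaks down.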
Design goal
The reason we are looking into MLIR is two-fold. XLA HLO IR has some nice properties, but it does not support dynamic shapes (except for padding); MHLO will probably fix that in the future. In the meantime, we would like to plug in backend-specific intrinsics for some custom ops (notably torchvision ops, or CTC loss in seq2seq models) when migrating existing PyTorch users, and we are looking for better alternatives to XLA custom call.
Xilinx and (AFAIK) some of the custom training ASIC vendors are also moving towards MLIR for easier interoperability between frameworks and their software/hardware stacks. For them, being able to export MLIR somewhere in the frontend stack is sufficient.
Proposed changes
Option 1
Export the HLO graph from torch/xla and translate it into MHLO. This is not really what we are looking for in terms of adding dynamic shape and custom op support, as it does not enhance the expressiveness of the current torch/xla frontend.
Option 2
Add back the pseudo-device API in mlir-npcomp/acap, which would probably work around the caveats. It would be similar to a wrapper-based backend fallback, in contrast to the mode-based backend fallback used today (see the sketch below). It retains the CLOC advantage while switching to (IMHO) the canonical extension points for backends. We would also like to revisit static shape/dtype inference vs. runtime tracing in order to support dynamic shapes.
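To make the wrapper-based vs. mode-based distinction concrete, here is a rough C++ sketch, again using PrivateUse1 as a placeholder key and assuming the hypothetical capture_kernel fallback from the earlier sketch is registered; the difference lies in where the dispatch key comes from, not in how the fallback itself is written.

#include <c10/core/impl/LocalDispatchKeySet.h>
#include <functional>

// Mode-based (today's acap): the Python context manager pushes the capture
// key into thread-local dispatch state, so every ATen call in the region is
// intercepted regardless of which device the tensors live on.
void run_with_capture_mode(const std::function<void()>& region) {
  c10::impl::IncludeDispatchKeyGuard capture_on(
      c10::DispatchKeySet(c10::DispatchKey::PrivateUse1));
  region();  // all kernel calls here route through the registered fallback
}

// Wrapper/pseudo-device (Option 2): nothing is toggled in thread-local
// state. Tensors allocated on the pseudo device carry the backend key in
// their own key sets, so only ops touching those tensors hit the fallback,
// which is the extension point out-of-tree backends such as torch/xla use.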
Option 3
Migrate the torch/xla frontend to MLIR (with some overlap with option 2), e.g. add an XLA dialect and go directly to MHLO (probably mixed with untouched custom ops). This is basically a rewrite of the current XLA type dispatch using a backend fallback, and of the XLA IR as an MLIR dialect (probably using CHLO, which provides functionality similar to XlaBuilder). Different fallback strategies could be employed: either "copy tensors back to CPU and call CPU kernels", similar to what AtenXlaTypeDefault does today (sketched below), or plug in backend-specific intrinsics (which requires support from the runtime side).
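As an illustration of the first fallback strategy, here is a heavily simplified C++ sketch of a "copy tensors back to CPU and call CPU kernels" boxed fallback. It is not the AtenXlaTypeDefault code: the xla_cpu_fallback name is made up, the copy-back of results is omitted, and it assumes the backend provides device-to-host copy kernels so that .cpu() does not recurse into this fallback.

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

// Generic CPU fallback: move tensor arguments to CPU, run the CPU kernel,
// and (in a real implementation) copy the results back to the device.
void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto num_args = op.schema().arguments().size();
  const auto args_begin = stack->size() - num_args;

  // Replace every tensor argument on the stack with its CPU copy.
  for (size_t i = 0; i < num_args; ++i) {
    auto& ivalue = (*stack)[args_begin + i];
    if (ivalue.isTensor()) {
      ivalue = ivalue.toTensor().cpu();
    }
  }

  // Re-dispatch: the arguments are now CPU tensors, so this resolves to the
  // ordinary CPU kernels.
  op.callBoxed(stack);

  // Results are left on CPU in this sketch; copying them back to the XLA
  // device (and handling mutable/out= arguments) is omitted.
}

TORCH_LIBRARY_IMPL(_, XLA, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&xla_cpu_fallback>());
}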
Discussion
Option 1 could be a quick win, doable in 2-4 weeks. I am leaning towards option 2 in the mid term (1-2 months?) and gradually moving to option 3 in the long term (6-12 months), which would also enhance the UX of torch/xla. I guess the XLA/TPU stack is also moving to MLIR, so we are looking for some early feedback here.
PS: many thanks to @silvasean for initiating and coordinating the discussion. The overall roadmap of mlir-npcomp can be found here: https://github.com/llvm/mlir-npcomp/blob/main/docs/roadmap.md