RFC-0012: Functional lazy traces from XLA to PyTorch #18
Conversation
Signed-off-by: Edward Z. Yang <[email protected]>
also cc @wconstab
It seems that torch-xla relies on XLA for shape inference, and PyTorch accesses the sizes/numel/dim after each op. If we strip out XLA, this probably needs to come from somewhere else. I also noticed that you added
@byronyi If we strip out XLA, shape inference can come from core PyTorch. Some of the helpers might need to be exposed, but they're already written. The XLA inference was just more convenient since we already had to lower to it and could avoid making changes to the core, but the shape inference results are exactly the same and we can substitute the one in core for the one provided via XLA. In terms of access to sizes/numel/dim, that works exactly the same with or without XLA; I'm not sure what you mean. What's your concern there? Finally, lazy tensors don't depend deeply on shape inference or even meta tensors. There's a moderate amount of work to be done for a few operators which currently use static sizes, but nothing fundamental in the design assumes or requires static shapes. We'll still require static ranks, but that should be flexible enough for all practical applications of this.
Would appreciate it if you could give me a pointer here to core :)
I am just wondering if it is possible to lazily evaluate shapes of lazy tensors as well.
If shapes must be materialized before actually computing the value of lazy tensors, then shape inference seems to be a must to me.
One such example: https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/aten/src/ATen/native/ConvUtils.h#L25. They're not centralized, but they're quite easy to find, and the most interesting/difficult ones tend to be exposed already.
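For concreteness, here is a rough Python rendering of the kind of shape helper that ConvUtils.h provides; the function name and signature below are illustrative, not the actual C++ API:

```python
from typing import List

def conv_output_size(input_size: List[int], weight_size: List[int],
                     padding: List[int], stride: List[int],
                     dilation: List[int]) -> List[int]:
    # Batch and output-channel dims pass through; each spatial dim follows
    # the standard convolution output-size formula.
    out = [input_size[0], weight_size[0]]  # N, C_out
    for d in range(2, len(input_size)):
        kernel = dilation[d - 2] * (weight_size[d] - 1) + 1
        out.append((input_size[d] + 2 * padding[d - 2] - kernel) // stride[d - 2] + 1)
    return out

# e.g. a 3x3 conv, stride 2, padding 1, 16 output channels, on a 1x3x32x32 input
print(conv_output_size([1, 3, 32, 32], [16, 3, 3, 3], [1, 1], [2, 2], [1, 1]))
# -> [1, 16, 16, 16]
```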
I think you're conflating two separate aspects here:
I think I've addressed this above. Shapes don't need to be materialized, but they need to be queried with a runtime mechanism which relies on the representation the vendor picks to represent an operator. How the vendor does that is entirely outside of our control, since we don't even know how the vendor will represent runtime values. We do know, though, that any operand in the graph will be a pair of value and shape; there'll be a way to associate concrete sizes from the inputs with symbolic shapes in the graph at launch time, and there'll be a way to query those, so we're certain it can be done. In our
To add some of my own perspective on top of @asuhan's:
In the model I'm describing, this would translate to an unavoidable (even assuming an infinitely smart compiler) size-query runtime call in the vendor backend. By construction it cannot propagate to any dimension size since we erase those (if the vendor chooses the dynamic mode). In other words, it'd only be present during shape checks in higher layers, but once we know those pass we erase the sizes and the vendor backend is responsible for emitting the proper runtime queries. It's possible the graph would also change, but I'm not sure how frequent that would be in practice. Tracing is part of the solution and not a full solution. It'll work very well for some models by itself, but we'll definitely need to complement it with TorchScript annotations in the parts of models which are vulnerable to the worst aspects of tracing. However, in terms of correctness, I'm pretty sure what I'm describing is sound.
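To make the "operand is a pair of value and shape, with symbolic dims bound at launch time" idea above concrete, here is a minimal sketch; all names (SymDim, Operand, bind) are hypothetical and not part of any PyTorch or vendor API:

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union

@dataclass(frozen=True)
class SymDim:
    """A named symbolic dimension, bound to a concrete size at launch time."""
    name: str

Dim = Union[int, SymDim]

@dataclass
class Operand:
    """Every operand in the lazy graph is a pair of runtime value and shape."""
    value: object                 # backend-specific runtime value handle
    shape: Tuple[Dim, ...]        # mix of static ints and symbolic dims

def bind(shape: Tuple[Dim, ...], env: Dict[str, int]) -> Tuple[int, ...]:
    """Resolve symbolic dims against concrete input sizes at launch time."""
    return tuple(d if isinstance(d, int) else env[d.name] for d in shape)

print(bind((SymDim("batch"), 128), {"batch": 32}))   # (32, 128)
```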
So these size helper functions (for each and every operator in PT core) will be extracted into a central place in a new repo (for lazy tensors)? I am asking this because they seem pretty sparse to me, but maybe that is only because I am not that familiar with the codebase.
I am actually okay with occasional dynamically shaped/ranked ops, as they are probably going to be handled by some fallback logic (i.e. host CPU in the XLA case) anyway. But I do hope we could get a static/flat trace for a spanning list of uninteresting ops (element-wise, mm, etc.), so vendor backends need not be queried for the output shape of each op in the functional lazy trace. Just as an illustration: a wrapper-based backend fallback to CPU gives us traces along the lines sketched below. Ideally the functional lazy traces need not run the operator on CPU to get sizes and strides; right now in torch-xla this is supplied by XLA shape inference.
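The original wrapper and trace snippets are not reproduced here. As a rough approximation of the idea, the sketch below uses TorchDispatchMode (a later PyTorch API, not the snippet originally posted) to run every op eagerly on CPU and record the op name and output shape; this is exactly the per-op shape query the comment hopes to avoid for the uninteresting ops:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class ShapeTraceMode(TorchDispatchMode):
    """Run every ATen op eagerly (here on CPU) and record op name + output shape."""
    def __init__(self):
        super().__init__()
        self.trace = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        shape = tuple(out.shape) if isinstance(out, torch.Tensor) else None
        self.trace.append((str(func), shape))
        return out

with ShapeTraceMode() as mode:
    x = torch.randn(4, 8)
    w = torch.randn(8, 16)
    y = torch.relu(x @ w)

for op, shape in mode.trace:
    print(op, shape)
# e.g. aten.randn.default (4, 8), aten.mm.default (4, 16), aten.relu.default (4, 16)
```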
PT/XLA already handles a fraction of operators, not the full set - not all operators occur in real models, especially if we're talking about performance-sensitive paths. Further, only some of those have non-trivial shape equations. I've checked most of them in the past and I remember finding usable helpers for nearly all of them. I agree with @ezyang that meta tensors are probably the way to go about this given a longer time horizon, but the situation isn't dire in the short term either. I wouldn't bother to centralize the helpers but rather bring up meta tensors.
The vendor can hoist / elide unnecessary checks for sequences of element-wise operations, for example. It's just a compiler optimization pass in their back-end. We could probably offer some basic, generic tools to assist with some of that, but this competes against everything else we want to get done and isn't particularly hard to implement on the vendor side either. But you're right that it'd be nice to have; I'm open-minded about it.
That would be the idea: you'd never run an operator on CPU just to get sizes and strides. A vendor could choose to punt to CPU whenever a runtime size query would be needed, but that's an implementation quality issue - a good implementation would come with a runtime (native to the accelerator) to support such queries on the value representation chosen by the vendor. Think metadata similar to CPU or GPU tensors, but stored in the accelerator memory; nothing prevents that. Taken to the extreme, a pure interpreter for the tensor computation graph could run on the accelerator itself (not saying that'd be a truly good option, just making an informal argument that what I'm describing is feasible). I'm assuming a general-purpose core (probably quite slow in absolute terms) which can execute shape computation and run a simple memory allocator is present on the accelerator, which I think is a fairly safe assumption for training accelerators.
So, there are two main APIs we are envisioning for accessing sizes as we add more support for structured kernels. First is the direct (but somewhat inefficient) API: suppose that you are implementing "add" in XLA and you want to compute what the output shape should be; you can call the meta function for add. That call will do all the error checking you need on the inputs, and give you an output type that says what the output sizes, dtype, etc. should be. You can then make use of this information as necessary for your error checking. (NB: the output size isn't that useful, because if you actually are implementing a lowering you will have a far more detailed understanding of what the operator does, but it might still be helpful.) Second is a more pimped-up version of @bdhirsh's proposal at pytorch/xla#2871, where we generate the scaffolding that calls into meta, so by the time you are writing lowerings you can assume all error checking has already happened. (We haven't really gotten that far in the design process here.)
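As a Python-level illustration of the same idea (this is not the C++ snippet referenced above): tensors on the "meta" device carry only sizes, strides, and dtype, so running the op on them performs all the checking and shape/dtype inference without computing any data.

```python
import torch

# Shape/dtype inference without allocating data or running a kernel:
a = torch.empty(4, 1, 8, device="meta")
b = torch.empty(8, device="meta", dtype=torch.float64)

out = torch.add(a, b)         # broadcasting and type promotion are checked here
print(out.shape, out.dtype)   # torch.Size([4, 1, 8]) torch.float64
print(out.device)             # meta -- no data was ever computed
```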
Yeah, so structured kernels are trying to make it so that you can easily get the sizes and strides from PyTorch directly, rather than assuming you already have an accurate lowering. Coverage is not very high right now but getting better!
Thanks, I was just reading through https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md#goals and that is exactly what I am asking for. We'd love to help with this effort!
There's a nice tracker issue on the subject :) pytorch/pytorch#55070
I see @asuhan pushed to pytorch/pytorch@fafb8ab and corresponding changes on the torch-xla side: https://github.com/pytorch/xla/tree/asuhan/xla_ltc_plugin. Nice one! I do see a TODO on shape inference: https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/compiler/node_lowering.h#L19-L20. Do you plan to provide such facilities from core using meta tensors?
@byronyi In the short term I'll just expose/extract the required helpers from core and use them to provide shape inference. Longer term, yes, I really hope we'll be doing that with meta tensors instead - it'd be much nicer.
Generalized a bit (from xla-specific) and added some context around backend integration API.
Update Lazy Tensor RFC
@ezyang @asuhan We are working on a similar approach to integrate our compiler backend with PyTorch. It's great to see the work you are doing. Do you have a timeline of when and which parts will be moved into PyTorch core? Also, feel free to let us know if there's anything we can help with. Moreover, do you have plans to integrate LazyTensor together with TorchScript so we can add an extra layer to scripted modules, as @hzfan mentioned in pytorch/xla#2957 (comment)? cc @mli
@yzhliu We intend to put this part into the core, the lazy_tensor_core sub-folder: https://github.com/pytorch/pytorch/tree/lazy_tensor_staging. This was derived from the "upper half" of pytorch/xla - it offers lazy tensor infrastructure which is independent of the actual backend. We have a working XLA backend: https://github.com/pytorch/xla/tree/asuhan/xla_ltc_plugin, derived from the "lower half" of pytorch/xla, and I'm working on a TorchScript backend too - that'll allow us to reuse nvFuser and other TorchScript backends more easily, while providing the graph capture mechanisms and infrastructure.

We have ideas to use TorchScript together with lazy tensors (NB: this is different from TorchScript as a backend, mentioned above) to address the caveats - undesired loop unrolling and other similar issues. The idea would be to use it sparingly in problem areas (which have control flow) and still lean on lazy tensors for most of the model, which would maintain usability.

The plan is to be ready to merge lazy_tensor_core to master by July-August, which might be ambitious but should at least be doable. There are several things which need to happen: code base reduction (we can autogenerate most nodes in csrc/ops), removing the last dependencies on absl (we have a handful of uses of StrCat and Span, which we can re-implement), etc. That being said, I'm trying to keep that branch reasonably fresh, 1-2 weeks away from master, until we can finally merge it.
@asuhan Thanks for sharing the details and the estimated timeline. We'll keep an eye on the project and might bother you folks later :) as we start to make progress.
Only saw this RFC now, sorry. The complications I've hit so far:
For the dtype & shape inference, I wrote a program that executes all ATen ops with different combinations of shapes/dtypes and checks which of my hand-written rules applies. The dispatch wrappers are generated automatically from this information. My code is available here: https://github.com/nunoplopes/torchy. Would you guys (FB, others?) be interested in chatting about integration, directions, sharing lessons learned, sharing data from models already running, etc.?
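As a hypothetical illustration of that rule-checking approach (the rule names and the search loop below are made up for this sketch and are not taken from the torchy code base):

```python
import itertools
import torch

# Candidate dtype rules; the op's behavior is probed to see which rule survives.
RULES = {
    "same_as_first": lambda d1, d2: d1,
    "promote": lambda d1, d2: torch.promote_types(d1, d2),
}

def infer_dtype_rule(op, dtypes=(torch.float32, torch.float64, torch.int64)):
    candidates = set(RULES)
    for d1, d2 in itertools.product(dtypes, repeat=2):
        try:
            out = op(torch.ones(2, 2, dtype=d1), torch.ones(2, 2, dtype=d2))
        except RuntimeError:
            continue                       # op rejects this dtype combination
        candidates = {r for r in candidates if RULES[r](d1, d2) == out.dtype}
    return candidates

print(infer_dtype_rule(torch.add))   # expected: {'promote'}
print(infer_dtype_rule(torch.mul))   # expected: {'promote'}
```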
@nunoplopes Thanks for reaching out! We are also continuing on the branch
I'm building a JIT compiler for PyTorch. I use lazy tensors to assemble traces and then ship those traces to an off-the-shelf compiler, like TorchScript, for optimization. Traces are cached and reused when seen again.
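A minimal sketch of that caching scheme, assuming a trace can be keyed by the function plus the input shapes/dtypes; here torch.jit.trace stands in for the off-the-shelf compiler, and the cache layout is purely illustrative:

```python
import torch

_cache = {}

def cached_compile(fn, *inputs):
    # Key by function identity plus input shapes/dtypes; compile once, reuse after.
    key = (id(fn),) + tuple((tuple(t.shape), t.dtype) for t in inputs)
    if key not in _cache:
        _cache[key] = torch.jit.trace(fn, inputs)   # compile on first sight
    return _cache[key](*inputs)                     # cache hit on later calls

def f(x, w):
    return torch.relu(x @ w)

y1 = cached_compile(f, torch.randn(4, 8), torch.randn(8, 16))  # compiles
y2 = cached_compile(f, torch.randn(4, 8), torch.randn(8, 16))  # reuses the trace
```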
Do we have a timeline to merge
Hi @ezyang, as mentioned before, we are building features on top of the lazy tensor support, and we noticed the branch is being actively updated. Would you mind sharing the plan for making the branch officially available to PyTorch users?
We are in the process of merging lazy tensor support into pytorch master right now. (You can already see some files landing in The actual functionality of lazy tensors won't be enabled 'by default' for users (of the cpu/cuda device) this year either. Right now, lazy is a 'virtual device', meaning you have to move your tensors to the 'lazy' (or 'xla') device explicitly and use other flags to configure the hardware used by the backend. In the future we'd like to explore making a 'lazy mode' available to existing devices (e.g. cpu, cuda).
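As an illustration of the 'virtual device' workflow using the existing torch_xla frontend (the eventual 'lazy' device is meant to follow the same pattern, though module names and flags may differ):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()      # the 'virtual' device; ops on it are traced lazily

x = torch.randn(4, 8, device=device)
w = torch.randn(8, 16, device=device)
y = torch.relu(x @ w)         # recorded into the lazy graph, not yet executed

xm.mark_step()                # cut the trace and hand it to the backend to run
print(y.cpu())                # materializing the value also forces execution
```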
@wconstab Could you elaborate a bit more on the TorchScript backend? How do users use TorchScript together with lazy tensor core? Previously I extended ltc to support TorchScript for control flow, and I am super excited to hear that we have official support for it now.
It's still WIP, but the idea is to construct a lazy trace of 'functional ATen ops', which are then trivially convertible to TorchScript IR. The lazy-traced TorchScript IR is a subset of overall TS IR capabilities, since it is functional, has no control flow, and has no Python classes. The IR is then fed to a TS GraphExecutor and used with existing TS passes/backends.
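For a rough sense of what such a functional, control-flow-free graph looks like in TorchScript IR, here is a sketch that uses torch.jit.trace to produce a comparable graph (lazy tracing itself constructs the IR through a different path):

```python
import torch

def f(x, w):
    return torch.relu(torch.mm(x, w))

# A purely functional graph of ATen ops with no control flow or Python classes.
g = torch.jit.trace(f, (torch.randn(4, 8), torch.randn(8, 16)))
print(g.graph)
# Prints roughly:
#   graph(%x, %w):
#     %1 : Tensor = aten::mm(%x, %w)
#     %2 : Tensor = aten::relu(%1)
#     return (%2)
```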
Not sure what this means. We don't support control flow in lazy tracing currently. (No specific plans to support it either, although there is some ongoing discussion on the topic.)
This is not a complete RFC in the traditional sense, but reflects our current thinking related to the work @asuhan has been doing in torch_xla and how this should affect PyTorch core proper. This doc predates pytorch/xla#2854 but seeing mlir-npcomp pursuing a similar avenue I wanted to get this RFC out there so people outside of FB can have more context about what is going on in XLA internally.