TPU dynamo speed up on inference analysis #4328
Labels: dynamo, performance, triaged
Context:
I am running inference benchmarks using the dynamo bridge + torch_bench on a single TPU v4 device. This thread is mostly to record the current status and some TODOs. We have done a similar benchmark in https://docs.google.com/document/d/1xXwCDdQl1n2aCaJ8Lu3qn060Hp18pwj4MELVTZ3mP4g/edit. @shunting314 has made an optimization to trace the model on the XLA device instead of the CPU device, which results in better performance.
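For reference, a minimal sketch of this kind of setup (not the exact benchmark harness): compile the model through the dynamo bridge and run it on the XLA device. The "openxla" backend name and the torchvision resnet50 stand-in are assumptions; the actual runs use the dynamo branch listed below, where the backend spelling may differ.

```python
import torch
import torchvision
import torch_xla.core.xla_model as xm

device = xm.xla_device()                       # single TPU v4 device
model = torchvision.models.resnet50().to(device).eval()
example = torch.randn(1, 3, 224, 224, device=device)

# Dynamo bridge: trace/compile the model for XLA instead of running lazily.
compiled = torch.compile(model, backend="openxla")

with torch.no_grad():
    out = compiled(example)                    # compiled XLA execution
    xm.wait_device_ops()                       # make sure device work finished
```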
PyTorch branch:
pytorch/pytorch#88449 + some profiler code (caveat: use avg_pool instead of maxpool, this is fixed now)
XLA branch:
nightly + a patch (check #4306 (comment))
TorchBench + TorchAudio + TorchText branch
nightly
Runtime
PJRT, check https://github.com/pytorch/xla/blob/master/docs/pjrt.md
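A minimal sketch of opting into the PJRT runtime, assuming the environment-variable approach from the linked pjrt.md; the variable must be set before torch_xla initializes its runtime.

```python
import os
os.environ["PJRT_DEVICE"] = "TPU"   # select the PJRT TPU client before importing torch_xla

import torch_xla.core.xla_model as xm
print(xm.xla_device())              # e.g. xla:0, now backed by PJRT
```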
Command
Sample profiles
gs://tpu-pytorch/tmp/dynamo_profile/dynamo_tracing/try13/
I believe this one is a resnet50
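For context, a hedged sketch of how such a trace can be captured with torch_xla's profiler; the port, log directory, and duration below are placeholders, and the profiles in the bucket above may have been produced differently.

```python
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)      # profiler server inside the benchmark process

# ... run warm-up and benchmark iterations here ...

# From another process (or a background thread), capture a window of the run:
xp.trace("localhost:9012", logdir="/tmp/dynamo_profile", duration_ms=10000)
```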
The first part of the trace (before wait_device_ops) is the lazy run and the remaining part is dynamo. Note that lazy spends some time tracing the graph before execution, while dynamo's walltime is mostly just device execution.
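A rough sketch of that walltime comparison (an assumed measurement loop, not the actual benchmark code): the lazy path pays for graph tracing plus execution on every call, while the dynamo path mostly pays for device execution.

```python
import time
import torch_xla.core.xla_model as xm

def lazy_walltime(model, inp):
    start = time.perf_counter()
    out = model(inp)            # lazy tensors: records the graph op by op
    xm.mark_step()              # cut the graph and dispatch it to the TPU
    xm.wait_device_ops()        # block until device execution finishes
    return time.perf_counter() - start

def dynamo_walltime(compiled_model, inp):
    start = time.perf_counter()
    out = compiled_model(inp)   # replays the pre-compiled XLA computation
    xm.wait_device_ops()
    return time.perf_counter() - start
```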
Result
squeezenet1_1 --> RuntimeError: Fail to extact the compiled graph because of fallback: aten::avg_pool2d=3 (see the fallback-check sketch below the results)
timm_vision_transformer --> Segmentation fault (core dumped)
geomean --> model can't be found (it seems it was removed from torch bench)
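As referenced in the squeezenet1_1 entry above, the bridge fails to extract a compiled graph when ops fall back to aten kernels. A hedged sketch of how such fallbacks can be surfaced with torch_xla's metrics (run the model once on the XLA device first):

```python
import torch_xla.debug.metrics as met

# ... run the failing model once on the XLA device ...

print(met.metrics_report())   # full report; CPU-fallback ops show up as aten::* counters
print(met.counter_names())    # quick way to scan for any "aten::" entries
```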
TODO
timm_vision_transformer crashes

FYI @shunting314 @wconstab @ezyang @miladm @alanwaketan @wonjoolee95