pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)

Home Page: https://pytorch.org/xla

TPU dynamo speed up on inference analysis

JackCaoG opened this issue

Context:
I am running inference benchmarks using the dynamo bridge + torch_bench on a single TPU v4 device. This thread is mostly to record the current status and some TODOs. We did a similar benchmark in https://docs.google.com/document/d/1xXwCDdQl1n2aCaJ8Lu3qn060Hp18pwj4MELVTZ3mP4g/edit. @shunting314 has made an optimization to trace the model on the XLA device instead of the CPU device, which results in better performance.
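
For reference, a minimal sketch of what "trace the model on the XLA device" means in practice (my reconstruction, not the actual benchmark harness; resnet18 is just an example, and the backend name is as the XLA bridge registered it at the time): the model and inputs are moved to the XLA device before dynamo captures the graph, so the capture sees device-side tensors rather than CPU ones.

import torch
import torch._dynamo
import torch_xla.core.xla_model as xm
import torchvision

device = xm.xla_device()  # the TPU v4 core in this setup
model = torchvision.models.resnet18().to(device).eval()
example = torch.randn(1, 3, 224, 224, device=device)  # input already on XLA

# Capture the graph once on the XLA device via the bridge backend.
dynamo_model = torch._dynamo.optimize("torchxla_trace_once")(model)
with torch.no_grad():
    out = dynamo_model(example)
xm.wait_device_ops()  # wait for the async device execution to finish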

PyTorch branch:

pytorch/pytorch#88449 + some profiler code (caveat: it used avg_pool instead of max_pool; this is fixed now)

XLA branch:

nightly + a patch (check #4306 (comment)) that blocks on each result buffer's ready future before taking ownership of it:

diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc
index 207c8874..fa847c0d 100755
--- a/third_party/xla_client/pjrt_computation_client.cc
+++ b/third_party/xla_client/pjrt_computation_client.cc
@@ -308,6 +308,7 @@ PjRtComputationClient::ExecuteComputation(
   std::vector<DataPtr> datas;
   datas.reserve(results.size());
   for (auto& result : results) {
+    auto status = result->GetReadyFuture().Await();
     std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);
 
     std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(

TorchBench + TorchAudio + TorchText branch

nightly

Runtime

PJRT, check https://github.com/pytorch/xla/blob/master/docs/pjrt.md

Command

XLA_HLO_DEBUG=0 XLA_IR_DEBUG=0 USE_FAKE_TENSOR=0 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --backend=torchxla_trace_once --only $MODEL -n 30

Sample profiles

gs://tpu-pytorch/tmp/dynamo_profile/dynamo_tracing/try13/
I believe this one is a resnet50
[screenshot: profiler trace]

The first part of the trace (before wait_device_ops) is lazy and the remainder is dynamo: lazy spends some time tracing the graph before execution, while dynamo's wall time is almost entirely device execution.
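
As a rough sketch of how that split can be measured (reusing model, dynamo_model, and example from the sketch above; this is illustrative, not the benchmark's actual timing method):

import time
import torch_xla.core.xla_model as xm

def wall_time(fn, arg):
    start = time.perf_counter()
    fn(arg)
    xm.wait_device_ops()  # include pending device execution in the wall time
    return time.perf_counter() - start

# Lazy path: every call re-traces the graph in Python before executing it.
lazy_s = wall_time(model, example)
# Dynamo path: the graph was captured once, so this is mostly device time.
dynamo_s = wall_time(dynamo_model, example)
print(f"lazy {lazy_s:.4f}s vs dynamo {dynamo_s:.4f}s")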

Result

cpu  eval  resnet18                           1.768x p=0.00
cpu  eval  resnet50                           1.610x p=0.00
cpu  eval  resnext50_32x4d                    1.328x p=0.00
cpu  eval  alexnet                            1.261x p=0.00
cpu  eval  mobilenet_v2                       2.017x p=0.00
cpu  eval  mnasnet1_0                         1.686x p=0.00
cpu  eval  vgg16                              1.155x p=0.00
cpu  eval  BERT_pytorch                       3.502x SAME

squeezenet1_1 --> RuntimeError: Fail to extact the compiled graph because of fallback: aten::avg_pool2d=3
timm_vision_transformer --> Segmentation fault (core dumped)
geomean --> model not found (it seems to have been removed from torchbench)
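
An aside on reading the squeezenet1_1 failure (my gloss): counters named aten::<op> in torch_xla mark ops that fell back to the CPU path, so "fallback: aten::avg_pool2d=3" says avg_pool2d fell back three times, which prevents extracting a single compiled graph. A quick way to inspect fallbacks after a run:

import torch_xla.debug.metrics as met

# Any counter whose name starts with "aten::" is an op that fell back
# to the CPU implementation instead of lowering to XLA.
for name in met.counter_names():
    if name.startswith("aten::"):
        print(name, met.counter_value(name))

# Or dump the full report (compile times, transfer counters, fallbacks).
print(met.metrics_report())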

TODO

  1. Investigate why timm_vision_transformer crashes.
  2. Enable more models on torch bench.

FYI @shunting314 @wconstab @ezyang @miladm @alanwaketan @wonjoolee95

Thanks for sharing the results, Jack.

For squeezenet1_1, we should be able to run it if we don't use the workaround that replaces max_pool2d with avg_pool2d. I've updated my PR to remove the workaround.

Rebuilt with head, removed the avg_pool2d workaround, and got:

cpu  eval  squeezenet1_1                      1.674x SAME
cpu  eval  timm_vision_transformer            3.138x p=0.00

VIT also passed for some reason lol.

For torchxla_trivial, I ran with head and

XLA_HLO_DEBUG=0 XLA_IR_DEBUG=0 USE_FAKE_TENSOR=0 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --backend=torchxla_trivial --only $MODEL_NAME -n 30

and got:

cpu  eval  resnet18                           0.969x SAME
cpu  eval  resnet50                           0.999x SAME
cpu  eval  resnext50_32x4d                    1.004x SAME
cpu  eval  alexnet                            1.003x SAME
cpu  eval  mobilenet_v2                       1.027x SAME
cpu  eval  mnasnet1_0                         1.001x SAME
cpu  eval  vgg16                              0.987x SAME
cpu  eval  BERT_pytorch                       1.050x SAME
cpu  eval  squeezenet1_1                      0.998x SAME
cpu  eval  timm_vision_transformer            1.012x SAME
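
For context, my reading of these numbers (not stated explicitly above): torchxla_trivial just runs the model on the XLA device through the lazy tensor path, without the one-time graph capture, so it tracks the lazy baseline at ~1.0x. A sketch of the two backends side by side, under that assumption (reusing model from the earlier sketch):

import torch
import torch._dynamo

# trivial: re-traces lazily on every call, hence ~1.0x vs the lazy baseline.
trivial = torch._dynamo.optimize("torchxla_trivial")(model)
# trace_once: captures and compiles the graph once, then replays it,
# which is where the 1.2x-3.5x speedups above come from.
trace_once = torch._dynamo.optimize("torchxla_trace_once")(model)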

I think we have a good understanding of inference + the dynamo bridge; closing this issue as complete.