TPU dynamo speed up on inference analysis
JackCaoG opened this issue
Context:
I am running inference benchmarks using the dynamo bridge + torch_bench on a single TPU v4 device. This thread is mostly to record the current status and some TODOs. We did a similar benchmark in https://docs.google.com/document/d/1xXwCDdQl1n2aCaJ8Lu3qn060Hp18pwj4MELVTZ3mP4g/edit. @shunting314 has made an optimization to trace the model on the XLA device instead of the CPU device, which yields better performance.
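To make that concrete, here is a minimal sketch of running a model through the dynamo bridge with tracing on the XLA device (the model choice, the torchvision import, and the torch._dynamo.optimize call are my illustration, not the benchmark harness; the backend name comes from the command below):

```python
# Minimal sketch: keep the model and inputs on the XLA device so dynamo
# captures XLA tensors directly instead of CPU tensors.
import torch
import torchvision
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torchvision.models.resnet18().eval().to(device)
example = torch.randn(4, 3, 224, 224, device=device)

# "torchxla_trace_once" is the dynamo backend used by the benchmark command.
compiled = torch._dynamo.optimize("torchxla_trace_once")(model)
with torch.no_grad():
    out = compiled(example)
```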
PyTorch branch:
pytorch/pytorch#88449 + some profiler code (caveat: it used avg_pool instead of max_pool; this is fixed now)
XLA branch:
nightly + a patch that awaits each result buffer's ready future (check #4306 (comment)):
diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc
index 207c8874..fa847c0d 100755
--- a/third_party/xla_client/pjrt_computation_client.cc
+++ b/third_party/xla_client/pjrt_computation_client.cc
@@ -308,6 +308,7 @@ PjRtComputationClient::ExecuteComputation(
   std::vector<DataPtr> datas;
   datas.reserve(results.size());
   for (auto& result : results) {
+    auto status = result->GetReadyFuture().Await();
     std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);
     std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
TorchBench + TorchAudio + TorchText branch:
nightly
Runtime:
PJRT, check https://github.com/pytorch/xla/blob/master/docs/pjrt.md
Command:
XLA_HLO_DEBUG=0 XLA_IR_DEBUG=0 USE_FAKE_TENSOR=0 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --backend=torchxla_trace_once --only $MODEL -n 30
Sample profiles:
gs://tpu-pytorch/tmp/dynamo_profile/dynamo_tracing/try13/
I believe this one is a resnet50
The first part of the trace (before wait_device_ops) is the lazy run and the remainder is the dynamo run. You can see that lazy spends some time tracing the graph before execution, while dynamo's wall time is mostly just device execution.
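As a side note, here is a rough sketch of how the two phases can be separated when timing (the timing helper and the profiler port are illustrative, not the benchmark's actual code):

```python
# Sketch: xm.wait_device_ops() blocks until all queued device work finishes,
# so timestamps taken around it separate graph tracing/compilation from
# actual device execution in the trace.
import time
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # serve profiles; port is an arbitrary choice

def timed_run(fn, *args):
    xm.wait_device_ops()              # drain any pending device work first
    start = time.perf_counter()
    with torch.no_grad():
        out = fn(*args)
    xm.wait_device_ops()              # include the device execution itself
    return out, time.perf_counter() - start
```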
Results:
cpu eval resnet18 1.768x p=0.00
cpu eval resnet50 1.610x p=0.00
cpu eval resnext50_32x4d 1.328x p=0.00
cpu eval alexnet 1.261x p=0.00
cpu eval mobilenet_v2 2.017x p=0.00
cpu eval mnasnet1_0 1.686x p=0.00
cpu eval vgg16 1.155x p=0.00
cpu eval BERT_pytorch 3.502x SAME
squeezenet1_1 --> RuntimeError: Fail to extact the compiled graph because of fallback: aten::avg_pool2d=3
timm_vision_transformer --> Segmentation fault (core dumped)
geomean --> model not found (it seems to have been removed from torch bench)
TODO
- Investigate why timm_vision_transformer crashes
- Enable more models in torch bench
FYI @shunting314 @wconstab @ezyang @miladm @alanwaketan @wonjoolee95
Thanks for sharing the results, Jack.
For squeezenet1_1, we should be able to run it if we don't use the workaround that replaces max_pool2d with avg_pool2d. I've updated my PR to remove the workaround.
Rebuilt with head, removed the avg_pool2d workaround, and got:
cpu eval squeezenet1_1 1.674x SAME
cpu eval timm_vision_transformer 3.138x p=0.00
VIT also passed for some reason lol.
For torchxla_trivial, I ran with head and
XLA_HLO_DEBUG=0 XLA_IR_DEBUG=0 USE_FAKE_TENSOR=0 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --backend=torchxla_trivial --only $MODEL_NAME -n 30
and got:
cpu eval resnet18 0.969x SAME
cpu eval resnet50 0.999x SAME
cpu eval resnext50_32x4d 1.004x SAME
cpu eval alexnet 1.003x SAME
cpu eval mobilenet_v2 1.027x SAME
cpu eval mnasnet1_0 1.001x SAME
cpu eval vgg16 0.987x SAME
cpu eval BERT_pytorch 1.050x SAME
cpu eval squeezenet1_1 0.998x SAME
cpu eval timm_vision_transformer 1.012x SAME
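This is roughly what I'd expect: as I understand it, torchxla_trivial just runs the model through the normal lazy-tensor path on the XLA device without any dynamo graph capture, so it should land close to 1.0x against the lazy baseline. Roughly (a sketch of my mental model, not the backend's actual implementation):

```python
# Rough sketch of what the trivial comparison amounts to (an assumption, not
# the backend's real code): a plain lazy-tensor run on the XLA device, with
# no dynamo graph capture.
import torch
import torch_xla.core.xla_model as xm

def run_lazy(model, example):
    device = xm.xla_device()
    model, example = model.to(device).eval(), example.to(device)
    with torch.no_grad():
        out = model(example)
    xm.mark_step()          # cut the lazy graph and dispatch it
    xm.wait_device_ops()    # wait for the device to finish
    return out
```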
I think we have a good understanding of inference + dynamo bridge; closing this issue as complete.