pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)

Home Page: https://pytorch.org/xla


TPU dynamo speedup on training analysis

JackCaoG opened this issue

Context:

I am running training benchmarks using the dynamo bridge + torch_bench on a single TPU v4 device. This thread is mostly to track the current numbers and some TODOs.

| Title | Value |
| --- | --- |
| PyTorch Version | Nightly at 01/2023 |
| PyTorch/XLA Version | Nightly at 01/2023 |
| TorchBench + TorchText + TorchAudio Version | Nightly at 01/2023 |
| Accelerator | v4-8 |
| Runtime | PJRT |
| Average Speedup | 0.8315 |
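
For context, here is a minimal sketch of what a single training step through the dynamo bridge looks like on one XLA device. The model, loss, and data are placeholders, and the `openxla` backend name is taken from the benchmark command later in this thread; exact backend names and flags vary across nightlies.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Stand-in for a torchbench model; any nn.Module is handled the same way.
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# torch.compile hands the captured FX graph(s) to the PyTorch/XLA bridge.
compiled_model = torch.compile(model, backend="openxla")

def train_step(inputs, targets):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(compiled_model(inputs), targets)
    loss.backward()
    optimizer.step()
    # Flush anything that is still traced lazily (e.g. the optimizer update).
    xm.mark_step()
    return loss
```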

Result

(screenshot of per-model speedup results)

Sample Profile

gs://tpu-pytorch/tmp/dynamo_profile/dynamo_tracing_train/try1/

| | Device Execution Time (ms) |
| --- | --- |
| Lazy | 17.85 |
| Dynamo Input Sync | 0.083 |
| Dynamo Forward | 9.6 |
| Dynamo Cleanup 1 | 0.6 |
| Dynamo Backward | 13.1 |
| Dynamo Cleanup 2 | 0.289 |
| Dynamo Combined | 23.672 |
| | Tracing Time (ms) | Total Device Execution Time (ms) | Wall time (ms) |
| --- | --- | --- | --- |
| Lazy | 13.67 | 17.85 | 39 |
| Dynamo | N/A | 23.672 (0.75x) | 52.6 (0.74x) |

FYI @shunting314 @wconstab @alanwaketan @wonjoolee95

The first device execution comes from Lazy Tensor. The lazy tracing takes 13.67 ms and the device execution takes 17.85 ms, so LTC's wall time is ~39 ms. This is expected, since each step_marker and ExecuteGraph call carries some overhead.
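
For readers less familiar with the LTC flow, here is a rough sketch of where those two costs live in a lazy training step; the model, loss, and optimizer are placeholders, and the timing comments refer to the numbers above.

```python
import torch
import torch_xla.core.xla_model as xm

def lazy_train_step(model, optimizer, inputs, targets):
    # Every tensor op below only records IR nodes for the XLA device; this
    # host-side lazy tracing is where the ~13.67 ms above is spent.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # mark_step() is the step marker: it cuts the recorded graph, compiles it
    # (cached after the first step), and triggers ExecuteGraph, which accounts
    # for the ~17.85 ms of device execution time.
    xm.mark_step()
    return loss
```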

Looking at the Dynamo side, there are 5 graphs executed per step. Only the forward and backward take more than 1 ms, but there are very noticeable gaps between the device executions (9.91 ms between forward and cleanup 1, and 5.137 ms between cleanup 1 and backward). The key to improving dynamo training speed is to understand and minimize the gaps between device executions.

Going forward, I think there are three areas we should investigate:

  1. What is the cause of the "huge" gaps between device executions: were they introduced by PyTorch (AOTAutograd) or by PyTorch/XLA?
  2. Is the overhead introduced by AOTAutograd fixed, or does it scale with model size?
  3. How to enhance Torch Dynamo to emit a single graph for each step.

Note

This setup is more favorable to dynamo than real workloads, because the models we benchmark are relatively small, which makes the tracing cost more significant.

We only run a single step at a time in this benchmark, but in real-world training we run many steps back to back. Since lazy execution is asynchronous, the host can trace step N+1 while the device is still executing step N, which easily hides the tracing cost behind execution. Lazy will have much better performance in that scenario, so dynamo needs to outperform lazy significantly here to come out ahead in real-world use cases.

As a next step, I want to investigate how much overhead is caused by PJRTComputationClient::ExecuteComputation, the host-side call that triggers the device execution. Looking at the profile, the numbers are roughly as follows:

| PJRT Time (ms) | Device Execution Time (ms) | Overhead (ms) |
| --- | --- | --- |
| 21.831 | 17.832 | 3.999 |
| 1.892 | 0.08347 | 1.809 |
| 13.902 | 9.627 | 4.275 |
| 2.429 | 0.118 | 2.311 |
| 16.698 | 13.082 | 3.616 |
| 3.507 | 0.289 | 3.218 |

Note that the first execution comes from lazy and the remaining 5 come from dynamo. The overhead (PJRT time minus device execution time) is roughly fixed per call. By launching 5 executions instead of 1, dynamo incurs ~11 ms of additional overhead just from execution. The takeaway is that we really need to reduce the number of executions triggered by dynamo. My next step is to figure out why dynamo executes so many graphs.
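
For anyone who wants to reproduce this kind of trace, here is a minimal sketch using the torch_xla profiler. The port, output directory, step count, and the `train_step` call are arbitrary/placeholder choices.

```python
import torch_xla.debug.profiler as xp

# Start the profiler server inside the training process so host activity
# (including PJRTComputationClient::ExecuteComputation) and device activity
# can be captured.
server = xp.start_server(9012)

for step in range(20):
    # Annotate each step so it shows up as a named span in the trace viewer.
    with xp.StepTrace('train_step', step_num=step):
        train_step(inputs, targets)  # placeholder step, e.g. the sketch above

# From another shell, capture ~10 s of trace while the loop is running:
#   python -c "import torch_xla.debug.profiler as xp; \
#              xp.trace('localhost:9012', '/tmp/dynamo_profile', duration_ms=10000)"
```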

Update:

I was able to lower the per-step execution count for dynamo from 5 to 3 (forward + backward + a long sequence of copy_from ops coming from AOTAutograd). I need to clean up the PR a bit and rerun some benchmarks.
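
As a quick way to check the per-step execution count without capturing a full profile, something like the following sketch can be used; it assumes the `ExecuteTime` metric records one sample per device execution, and `step_fn` is a placeholder for the training step.

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def executions_per_step(step_fn, *args):
    # metric_data returns (num_samples, accumulated_time, samples); the change
    # in num_samples across one step is the number of graphs launched.
    before = met.metric_data('ExecuteTime')
    before_count = before[0] if before else 0
    step_fn(*args)
    xm.wait_device_ops()  # make sure all launched executions have finished
    after = met.metric_data('ExecuteTime')
    return (after[0] if after else 0) - before_count
```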

With the latest change in #4523, I am seeing much more promising results:

| Model | Old Speedup | New Speedup |
| --- | --- | --- |
| resnet50 | 0.758 | 0.937 |
| resnet18 | 0.66 | 1.003 |
| BERT_pytorch | 1.441 | 1.869 |
| resnext50_32x4d | 0.87 | 1.139 |
| alexnet | 0.632 | 0.802 |
| mobilenet_v2 | 0.549 | 0.672 |
| mnasnet1_0 | 0.698 | 0.967 |
| vgg16 | 0.712 | 0.742 |
| timm_vision_transformer | 1.275 | 1.69 |
| squeezenet1_1 | 0.72 | 0.958 |
| Avg | 0.8315 | 1.0779 |
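
For clarity, the Avg row here is the plain arithmetic mean of the per-model speedups; a quick check against the old column:

```python
old_speedups = [0.758, 0.66, 1.441, 0.87, 0.632, 0.549, 0.698, 0.712, 1.275, 0.72]
print(sum(old_speedups) / len(old_speedups))  # -> 0.8315
```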

I did a new run with today's nightly (with some patches that I will try to merge soon).

Command I used is:

```bash
# run from the pytorch source tree (/src/pytorch)
python benchmarks/dynamo/torchbench.py --randomize-input --performance --training --trace-on-xla --backend=openxla --only ${model}
```
| Model | Old Speedup | New Speedup |
| --- | --- | --- |
| resnet50 | 0.937 | 1.344 |
| resnet18 | 1.003 | 1.322 |
| BERT_pytorch | 1.869 | 3.356 |
| resnext50_32x4d | 1.139 | 1.417 |
| alexnet | 0.802 | 1.264 |
| mobilenet_v2 | 0.672 | 1.408 |
| mnasnet1_0 | 0.967 | 1.232 |
| vgg16 | 0.742 | 0.821 |
| timm_vision_transformer | 1.69 | 1.740 |
| squeezenet1_1 | 0.958 | 1.563 |
| Avg | 1.0779 | 1.516 |

It seems like the dynamo training numbers got a lot better in this setting, where lazy tensor cannot overlap consecutive steps. In a real-world training case we did some resnet50 training and found that dynamo still slows training down quite a lot. I need to look into the profile and see what the real status is.