TPU dynamo speed up on Training analysis
JackCaoG opened this issue
Context:
I am running training benchmarks using the dynamo bridge + TorchBench on a single TPU v4 device. This thread is mainly to track the current findings and some TODOs.
Setting | Value |
---|---|
PyTorch Version | Nightly at 01/2023 |
PyTorch/XLA Version | Nightly at 01/2023 |
TorchBench + TorchText + TorchAudio Version | Nightly at 01/2023 |
Accelerator | v4-8 |
Runtime | PJRT |
Average Speedup | 0.8315 |
Result
Sample Profile
gs://tpu-pytorch/tmp/dynamo_profile/dynamo_tracing_train/try1/
Execution | Device Execution Time (ms) |
---|---|
Lazy | 17.85 |
Dynamo Input Sync | 0.083 |
Dynamo Forward | 9.6 |
Dynamo Cleanup 1 | 0.6 |
Dynamo Backward | 13.1 |
Dynamo Cleanup 2 | 0.289 |
Dynamo Combined | 23.672 |
Mode | Tracing Time (ms) | Total Device Execution Time (ms) | Wall Time (ms) |
---|---|---|---|
Lazy | 13.67 | 17.85 | 39 |
Dynamo | N/A | 23.672 (0.75x) | 52.6 (0.74x) |
The first device execution comes from the Lazy Tensor run: the lazy tracing takes 13.67 ms and the device execution takes 17.85 ms, so LTC's wall time is ~39 ms. This is expected since each `step_marker` and `ExecuteGraph` has some overhead.
If we look at the Dynamo side, there are 5 graphs being executed. Only the forward and backward take more than 1 ms, but there are very noticeable gaps between the device executions (9.91 ms between forward and cleanup 1, and 5.137 ms between cleanup 1 and backward). The key to improving dynamo training speed is to understand and minimize the gaps between device executions.
Going forward, I think there are three areas we should investigate (a minimal sketch of how a training step is currently split into multiple graphs follows this list):
- What causes the “huge” gaps between device executions: are they introduced by PyTorch (AOTAutograd) or by PyTorch/XLA?
- Is the overhead introduced by AOTAutograd fixed, or does it scale with model size?
- How to enhance TorchDynamo to emit a single graph for each step.
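For context, here is a minimal sketch of what a dynamo training step on XLA looks like (a toy linear model, not the torchbench harness): `torch.compile` with the `openxla` backend captures the forward graph, and AOTAutograd produces a separate backward graph, which is why the profile above shows distinct forward/backward (plus cleanup) device executions.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device)                # toy model, for illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model, backend="openxla")   # dynamo + AOTAutograd via the XLA bridge

data = torch.randn(32, 128, device=device)
target = torch.randint(0, 10, (32,), device=device)

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(compiled_model(data), target)  # dynamo forward graph
loss.backward()                                            # separate AOTAutograd backward graph
optimizer.step()                                           # optimizer update runs outside the compiled graphs
xm.mark_step()                                             # flush any remaining lazy ops
```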
Note
This setup is more advantageous for dynamo because the models we benchmark are relatively small, which makes the tracing cost more significant.
We only run a single step at a time in this benchmark, but in the real world we run multiple steps, which can easily hide the tracing cost behind execution. Lazy will have much better performance in that scenario, so dynamo needs to outperform lazy significantly here to be faster in real-world use cases.
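To make the multi-step point concrete, here is a minimal sketch (same toy model as above, purely illustrative) of a lazy-tensor training loop: `xm.mark_step()` dispatches the traced graph asynchronously, so the host can start tracing step N+1 while the device is still executing step N, hiding the tracing cost that is fully exposed in the single-step numbers above.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(10):
    data = torch.randn(32, 128, device=device)
    target = torch.randint(0, 10, (32,), device=device)
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # async dispatch: this step's device execution overlaps with tracing of the next step
```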
As a next step, I want to investigate how much overhead is caused by `PJRTComputationClient::ExecuteComputation`, which is the host-side call that triggers device execution. Looking at the profile, it looks something like this:
PJRT Time (ms) | Device Execution Time (ms) | Overhead (ms) |
---|---|---|
21.831 | 17.832 | 3.999 |
1.892 | 0.08347 | 1.809 |
13.902 | 9.627 | 4.275 |
2.429 | 0.118 | 2.311 |
16.698 | 13.082 | 3.616 |
3.507 | 0.289 | 3.218 |
Note that the first execution is from lazy and the remaining five are from dynamo. The overhead (PJRT time minus device execution time) is roughly fixed per execution, so by triggering 5 executions instead of 1, dynamo incurs ~11 ms of additional overhead just from launching graphs. The takeaway is that we really need to reduce the number of executions triggered by dynamo. My next step is to figure out why we execute so many graphs for dynamo.
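To spell out the ~11 ms figure, here is a back-of-the-envelope check on the numbers from the table above (not part of the benchmark, just the arithmetic):

```python
# PJRT call time vs. device execution time from the profile table above;
# row 0 is the lazy execution, rows 1-5 are the dynamo executions.
pjrt   = [21.831, 1.892, 13.902, 2.429, 16.698, 3.507]
device = [17.832, 0.08347, 9.627, 0.118, 13.082, 0.289]
overhead = [p - d for p, d in zip(pjrt, device)]   # roughly 2-4 ms per execution

lazy_overhead = overhead[0]             # ~4.0 ms for the single lazy execution
dynamo_overhead = sum(overhead[1:])     # ~15.2 ms across the five dynamo executions
print(dynamo_overhead - lazy_overhead)  # ~11.2 ms extra, purely from launching more graphs
```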
Update:
I was able to lower the per-step execution count from 5 to 3 (forward + backward + a long sequence of copy_ from AOTAutograd) for dynamo; I need to clean up the PR a bit and rerun some benchmarks.
With the latest change in #4523, I am seeing a much more promising result
Model | Old Speedup | New Speedup |
---|---|---|
resnet50 | 0.758 | 0.937 |
resnet18 | 0.66 | 1.003 |
BERT_pytorch | 1.441 | 1.869 |
resnext50_32x4d | 0.87 | 1.139 |
alexnet | 0.632 | 0.802 |
mobilenet_v2 | 0.549 | 0.672 |
mnasnet1_0 | 0.698 | 0.967 |
vgg16 | 0.712 | 0.742 |
timm_vision_transformer | 1.275 | 1.69 |
squeezenet1_1 | 0.72 | 0.958 |
Avg | 0.8315 | 1.0779 |
Did a new run with today's nightly (with some patches which I will try to merge soon).
The command I used is:
:/src/pytorch# python benchmarks/dynamo/torchbench.py --randomize-input --performance --training --trace-on-xla --backend=openxla --only ${model}
Model | Old Speedup | New Speedup |
---|---|---|
resnet50 | 0.937 | 1.344 |
resnet18 | 1.003 | 1.322 |
BERT_pytorch | 1.869 | 3.356 |
resnext50_32x4d | 1.139 | 1.417 |
alexnet | 0.802 | 1.264 |
mobilenet_v2 | 0.672 | 1.408 |
mnasnet1_0 | 0.967 | 1.232 |
vgg16 | 0.742 | 0.821 |
timm_vision_transformer | 1.69 | 1.740 |
squeezenet1_1 | 0.958 | 1.563 |
Avg | 1.0779 | 1.516 |
It seems like the dynamo training numbers got a lot better when the lazy tensor does not overlap across steps. In a real-world training case we did some resnet50 training but found that dynamo still slows down training quite a lot. I need to look into the profile and see what the real status is here.