mosaicml / llm-foundry

LLM training code for Databricks foundation models

Home Page: https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm

Observing 1/2 the throughput on AMD MI250

staghado opened this issue · comments

I am benchmarking the throughput of a transformer model (GPT architecture) on an MI250, and I consistently get roughly half the per-GPU throughput that I get on an A100, and half of what is reported in the blog here.
The MI250 I'm using has only 64 GB of memory, but it should still reach throughput comparable to an A100 with the same amount of memory.

The relevant stack is:

flash-attn               2.0.4
pytorch-triton-rocm      2.2.0
torch                    2.2.0+rocm5.6
torch-optimizer          0.3.0
torchmetrics             1.3.0.post0
mosaicml                 0.19.0
mosaicml-cli             0.6.22

More details on the benchmark:

|  Model | Batch size |   GPU    | # GPUs | tok/s   | tok/s/gpu |
|--------|------------|----------|-------|---------|----------|
| 7.0B   |   4        |   A100   | 4     | 14k     | 3700     |
| 7.0B   |   4        |   MI250  | 8     | 12.8k   | 1602     |
| 13.3B  |   4        |   A100   | 8     | 12.9k   | 1613     |
| 13.3B  |   4        |   MI250  | 8     | 6k      | 753      |

Regarding the blog post here:
Did you normalize the per-GPU throughput by considering each MI250 as one GPU or as two GPUs?
By default Composer will detect 8 GPUs; did you manually multiply by 2 to account for the fact that an MI250 is actually two CUDA devices?

PyTorch detects each 1xMI250 as 2 devices = 2 GCDs (Graphics Compute Dies), so there is a 2x multiplier you have to apply to get the throughput per MI250.
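
As a quick sanity check (a minimal sketch, assuming a 4xMI250 node; not taken from the blog's benchmark code):

```python
# ROCm exposes each GCD as a separate "cuda" device to PyTorch,
# so a 4xMI250 node shows up as 8 devices.
import torch

n_devices = torch.cuda.device_count()  # expected to be 8 on a 4xMI250 node
n_mi250_packages = n_devices // 2      # 2 GCDs per MI250 package -> 4 packages
print(f"visible devices (GCDs): {n_devices}, MI250 packages: {n_mi250_packages}")
```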

To be specific, the systems profiled in the blog post were 4xMI250 systems, but PyTorch and Composer detect them as 8 devices, so when Composer reports tokens/s/gpu it is really tokens/s/GCD, and I then had to multiply by 2 to get the throughput per MI250. Based on your table, I think you're also profiling a 4xMI250 system, is that right?
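
Applying that to your numbers (a rough sketch of the conversion, using the 7B MI250 row from your table):

```python
# Composer's reported tok/s/gpu on an MI250 node is really tok/s/GCD;
# multiply by 2 to get tok/s per MI250 package.
GCDS_PER_MI250 = 2

def tok_per_s_per_mi250(tok_per_s_per_gcd: float) -> float:
    return tok_per_s_per_gcd * GCDS_PER_MI250

# 7B MI250 row above: 1602 tok/s/GCD -> ~3204 tok/s per MI250,
# which is the number to compare against the ~3700 tok/s/gpu of a single A100.
print(tok_per_s_per_mi250(1602))  # 3204
```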

Also, just FYI, this behavior is being fixed with the upcoming AMD MI300X GPUs: they come in 8xMI300X systems and will be detected as 8 devices by PyTorch.

Closing as resolved, let us know if you have any more questions!

Thanks for your answer, that solves it!