pytorch / torchtitan

A native PyTorch Library for large model training


FSDP2 incurs higher GPU memory usage in 2D compared to FSDP1

wanchaol opened this issue

We recently found that when training in 2D (FSDP + TP), FSDP2 incurs much higher GPU memory usage than FSDP1, triggering OOM for the 70B model.

A local test on an H100 devgpu (Llama 13B, global batch size 16, local batch size 8, 2-way DP x 4-way TP, selective op AC) shows a memory increase from 71 GiB to 87 GiB (a 16 GiB regression):
FSDP1: [memory profile screenshot, 2024-04-03, ~71 GiB]

FSDP2: [memory profile screenshot, 2024-04-03, ~87 GiB]
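
For context, the 2D layout above composes FSDP2 over the data-parallel mesh dimension with DTensor-based TP over the tensor-parallel dimension. A minimal sketch of the mesh setup (not torchtitan's exact code; the dim names and composition comments are illustrative):

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 GPUs arranged as 2-way DP x 4-way TP (illustrative dim names).
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh_2d["dp"]  # FSDP2 (fully_shard) shards parameters over this dim
tp_mesh = mesh_2d["tp"]  # DTensor / parallelize_module applies TP over this dim

# The two are then composed roughly as:
#   parallelize_module(block, tp_mesh, tp_plan)  # TP per transformer block
#   fully_shard(model, mesh=dp_mesh)             # FSDP2 over the DP dimension
```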

Update: to rule out complicating factors, using full AC instead of selective op AC shows the same regression.
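
For reference, "full AC" here means checkpointing each whole transformer block rather than a selected set of ops. A minimal sketch using PyTorch's checkpoint_wrapper utilities (torchtitan has its own AC wrapping; the block class name below is a stand-in):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

def apply_full_ac(model: nn.Module) -> None:
    # Wrap every transformer block in full (non-reentrant) activation checkpointing.
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=lambda m: checkpoint_wrapper(
            m, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        # "TransformerBlock" is a placeholder for the model's block class.
        check_fn=lambda m: type(m).__name__ == "TransformerBlock",
    )
```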

The issue is that DTensor's async funcols call recordStream on the collective tensors, which makes the caching allocator hold onto their memory longer than necessary. FSDP1's CPU rate limiter implicitly mitigated the recordStream issue coming from TP, but FSDP2 no longer has that rate limiter.
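
To make the recordStream behavior concrete: once a tensor is record_stream'd on another stream, the caching allocator will not reuse its block until that stream's pending work (as of the time the tensor is freed) has completed, so memory can stay reserved well past the last use. A toy standalone illustration (requires a CUDA device; this is not DTensor code):

```python
import torch

side = torch.cuda.Stream()

x = torch.empty(1024 * 1024, device="cuda")  # block owned by the default stream
with torch.cuda.stream(side):
    y = x * 2                                # side-stream kernel reads x
x.record_stream(side)                        # mark x as "in use" on `side`
del x                                        # block is NOT immediately reusable:
                                             # the allocator waits for the work
                                             # already enqueued on `side` to finish
```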

If we run the job with TORCH_NCCL_AVOID_RECORD_STREAMS=1, FSDP2's 2D run uses 68.97 GiB for the Llama 13B selective AC setup:

[rank0]:2024-04-04 08:02:19,445 - root - INFO - step:  1  loss: 10.8906  memory: 56.32GiB(59.26%)  wps: 138  mfu: 1.23%
[rank0]:2024-04-04 08:02:19,445 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-04-04 08:02:54,632 - root - INFO - step: 10  loss:  9.4660  memory: 68.97GiB(72.57%)  wps: 2,095  mfu: 18.68%
[rank0]:2024-04-04 08:03:24,819 - root - INFO - step: 20  loss:  7.9382  memory: 68.97GiB(72.57%)  wps: 2,714  mfu: 24.19%
[rank0]:2024-04-04 08:03:55,005 - root - INFO - step: 30  loss:  7.1089  memory: 68.97GiB(72.57%)  wps: 2,714  mfu: 24.19%
[rank0]:2024-04-04 08:04:25,228 - root - INFO - step: 40  loss:  6.6824  memory: 68.97GiB(72.57%)  wps: 2,711  mfu: 24.16%
[rank0]:2024-04-04 08:04:55,514 - root - INFO - step: 50  loss:  6.7918  memory: 68.97GiB(72.57%)  wps: 2,705  mfu: 24.11%
[rank0]:2024-04-04 08:05:25,759 - root - INFO - step: 60  loss:  6.5128  memory: 68.97GiB(72.57%)  wps: 2,709  mfu: 24.15%
[rank0]:2024-04-04 08:05:55,993 - root - INFO - step: 70  loss:  6.2168  memory: 68.97GiB(72.57%)  wps: 2,710  mfu: 24.15%
[rank0]:2024-04-04 08:06:26,297 - root - INFO - step: 80  loss:  6.0477  memory: 68.97GiB(72.57%)  wps: 2,703  mfu: 24.10%
[rank0]:2024-04-04 08:06:56,527 - root - INFO - step: 90  loss:  5.9257  memory: 68.97GiB(72.57%)  wps: 2,710  mfu: 24.16%
[rank0]:2024-04-04 08:07:27,035 - root - INFO - step: 100  loss:  5.8014  memory: 68.97GiB(72.57%)  wps: 2,685  mfu: 23.94%
[rank0]:2024-04-04 08:08:02,221 - root - INFO - step: 110  loss:  5.7214  memory: 68.97GiB(72.57%)  wps: 2,328  mfu: 20.75%
[rank0]:2024-04-04 08:08:32,497 - root - INFO - step: 120  loss:  5.6474  memory: 68.97GiB(72.57%)  wps: 2,706  mfu: 24.12%
[rank0]:2024-04-04 08:09:02,770 - root - INFO - step: 130  loss:  5.6189  memory: 68.97GiB(72.57%)  wps: 2,706  mfu: 24.12%
[rank0]:2024-04-04 08:09:33,034 - root - INFO - step: 140  loss:  5.6629  memory: 68.97GiB(72.57%)  wps: 2,707  mfu: 24.13%
[rank0]:2024-04-04 08:10:03,782 - root - INFO - step: 150  loss:  5.5157  memory: 68.97GiB(72.57%)  wps: 2,664  mfu: 23.75%
[rank0]:2024-04-04 08:10:34,085 - root - INFO - step: 160  loss:  5.4447  memory: 68.97GiB(72.57%)  wps: 2,703  mfu: 24.10%

The MFU numbers look slightly lower though 😢
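
For anyone reproducing the mitigation: as I understand it, with this flag ProcessGroupNCCL stashes references to the collective tensors until the work completes instead of calling recordStream, so blocks return to the allocator sooner. The flag is read when the NCCL process group is constructed, so it has to be in the environment before init. A minimal sketch (the launch command shown in the comment is illustrative):

```python
import os

# Must be set before the NCCL process group is created, since ProcessGroupNCCL
# reads it at construction time. Equivalently, export it in the launch shell:
#   TORCH_NCCL_AVOID_RECORD_STREAMS=1 torchrun ... train.py
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
```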

Closing as a dup of #208

@gnadathur Sorry for the confusion from the similar title. These are not duplicates: this one is about GPU memory, and the other is about CPU memory 😅