pytorch / torchtitan

A native PyTorch Library for large model training

Question: Can TP run a model that cannot fit a single batch on one GPU?

lucasjinreal opened this issue · comments

Can it train a big model when a single GPU cannot even fit batch size = 1?

When a single GPU cannot even fit batch size 1, depending on where the memory is coming from, any form of parallelism may be able to help (e.g. FSDP, TP, PP).

Do you have an example workload (e.g. model size, model type, input size)?

@lucasjinreal As @awgu mentioned, any form of parallelism or activation checkpointing (AC) might help reduce the peak GPU memory usage.

If you have already applied some of these techniques (i.e. AC, or reducing the parameter and optimizer state memory with FSDP) and still cannot train with local batch_size = 1, TP would be an effective way to further reduce activation memory. It lets you set the global batch size below the number of GPUs, which reduces the effective local batch size; the minimum effective local batch size can be as small as 1/tp_degree.
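To put rough numbers on the point above, here is a back-of-the-envelope sketch. It assumes mixed-precision Adam at about 16 bytes per parameter for model states (2 for the low-precision weight, 2 for the gradient, 12 for the fp32 master weight and the two Adam moments), that sharding divides model states evenly, and that TP with sequence parallelism divides activations by tp_degree. The function name and the example sizes are made up for illustration; real runs also pay for buffers, fragmentation, and communication workspace.

```python
GIB = 1024 ** 3


def per_gpu_memory_gib(num_params, activation_gib, shard_degree=1, tp_degree=1,
                       bytes_per_param=16):
    """Idealized per-GPU memory estimate in GiB (hypothetical helper).

    Model states (weights, grads, optimizer states) are assumed to be split
    evenly across shard_degree * tp_degree ranks. Activations are assumed to
    shrink only with tp_degree, because data-parallel sharding (FSDP/ZeRO-3)
    still runs a full local batch on every rank.
    """
    model_state_gib = num_params * bytes_per_param / GIB
    model_state_gib /= shard_degree * tp_degree
    return model_state_gib + activation_gib / tp_degree


# Made-up workload: 1B parameters, 26 GiB of activations at batch size 1.
single = per_gpu_memory_gib(1_000_000_000, 26.0)
fsdp_2 = per_gpu_memory_gib(1_000_000_000, 26.0, shard_degree=2)
tp_2 = per_gpu_memory_gib(1_000_000_000, 26.0, tp_degree=2)
print(f"1 GPU: {single:.1f} GiB, FSDP x2: {fsdp_2:.1f} GiB, TP x2: {tp_2:.1f} GiB")
```

In this made-up case sharding alone still exceeds a 32 GB card because the activations are untouched, while 2-way TP fits; when model states dominate instead, sharding is the bigger win.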

I am currently using DeepSpeed ZeRO-3.

I have a model that needs at least 40GB of GPU memory in total, but I only have 32GB per card. DeepSpeed ZeRO-3 may reduce the GPU memory footprint somewhat, but it still OOMs even with bs = 1.

As a result, it's very hard to make it trainable on 2x32GB cards. There doesn't seem to be a mechanism to dispatch the bs = 1 computation across 2 devices.

In this situation, how would FSDP (or torchtitan) help?

@lucasjinreal - when you hit your OOM using zero3, is it during the forward pass or backward pass? (sounds like forward pass but want to confirm).

If it's during the forward pass then it's the memory consumed from activations causing the OOM - for that you can do activation checkpointing (if you are not already) and we also have prototypes in progress for activation offloading (leverage the CPU memory).
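As a minimal sketch of what activation checkpointing looks like in plain PyTorch, the toy model below wraps each block in `torch.utils.checkpoint.checkpoint`, so intermediate activations inside a block are recomputed during backward instead of being kept. The `Block` and `CheckpointedStack` classes are hypothetical stand-ins, not torchtitan or DeepSpeed code.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy residual MLP block standing in for a transformer layer."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.net(x)


class CheckpointedStack(nn.Module):
    """Stack of blocks with optional per-block activation checkpointing."""

    def __init__(self, dim, depth, use_checkpoint=True):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                # Only the block input is saved; the block's internal
                # activations are recomputed in the backward pass.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


model = CheckpointedStack(dim=8, depth=3)
loss = model(torch.randn(2, 8)).sum()
loss.backward()
```

The trade-off is extra compute in backward (roughly one additional forward per checkpointed block) in exchange for lower peak activation memory.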

If it's during the backward pass, we'll soon be adding pagedFused AdamW, which will also leverage CPU memory to offload the optimizer states.

I think it might be due to the forward pass, as I am not able to train at all.

I am using a ZeRO-3 config like this:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}

I suppose this should already include every technique that could reduce GPU memory, including activation offloading?

Can FSDP in torchtitan push this any further?
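One quick way to check what the JSON above actually requests is to look for the relevant DeepSpeed config keys. The sketch below uses a trimmed copy of that config and a hypothetical helper, `memory_features`; it only inspects the JSON, so a missing `activation_checkpointing` section does not prove checkpointing is off (it can also be enabled in the training code).

```python
import json


def memory_features(ds_config):
    """Report which memory-saving sections a DeepSpeed config dict contains.

    Hypothetical helper: it checks for the presence of keys only and does
    not validate their values.
    """
    zero = ds_config.get("zero_optimization", {})
    return {
        "zero_stage": zero.get("stage"),
        "offload_optimizer": "offload_optimizer" in zero,
        "offload_param": "offload_param" in zero,
        "activation_checkpointing": "activation_checkpointing" in ds_config,
    }


# Trimmed copy of the config posted above.
config = json.loads("""
{
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "stage3_max_live_parameters": 1e9
    }
}
""")

print(memory_features(config))
```

For the posted config this reports ZeRO stage 3 with neither offload section nor an `activation_checkpointing` section present, so CPU offloading is not being requested by this file.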

@lucasjinreal We may need some more information. If there is an OOM in PyTorch, an explicit error should be raised, and the backtrace for that error can show whether the execution was in the forward or backward pass.

If you want precise information about the memory usage, attaching an OOM observer that saves a memory snapshot may be helpful (https://zdevito.github.io/2022/08/16/memory-snapshots.html). Perhaps something like the following may work:

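# Add this early in your script, e.g. before model init
import pickle
import torch
import torch.distributed

torch.cuda.memory._record_memory_history()

def oom_observer(device, alloc, device_alloc, device_free):
    # Called on OOM: dump a per-rank memory snapshot for later inspection
    # (get_rank() assumes the process group is already initialized)
    snapshot = torch.cuda.memory._snapshot()
    rank = torch.distributed.get_rank()
    with open(f"oom_snapshot_{rank}.pickle", "wb") as f:
        pickle.dump(snapshot, f)

torch._C._cuda_attach_out_of_memory_observer(oom_observer)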

The saved .pickle files can be viewed on https://pytorch.org/memory_viz. More information on how to read the snapshot can be found in https://pytorch.org/docs/stable/torch_cuda_memory.html.

From this snapshot, we can break down whether the memory usage is mainly coming from model states or from activations.

Since we are not as familiar with DeepSpeed, we may not be able to tell whether you already have activation checkpointing enabled from your config. You may be able to ask in the DeepSpeed repo if you want confirmation, or you can save a profiler trace and examine it directly. If activation checkpointing is not enabled and activation memory usage is high, then like @wanchaol and @lessw2020 said, you can enable it to save memory.

Since you are already using DeepSpeed ZeRO-3, it is unlikely that switching to FSDP would help.

So can it be concluded that if one cannot train a model with ZeRO-3 even at bs = 1, then FSDP won't be able to either?

But I heard FSDP can make two GPU cards act as a single one. My model cannot fit bs = 1 on one card, but it can definitely fit bs = 1 on two cards. Will that work with FSDP?

> So can it be concluded that if one cannot train a model with ZeRO-3 even at bs = 1, then FSDP won't be able to either?

FSDP and DeepSpeed ZeRO implement the same underlying algorithm. There may be hyperparameters in the configs that you can tune to save additional memory (generally at the cost of less efficient communication). Given this, if a user has already tried aggressive memory-saving configs on DeepSpeed ZeRO, then FSDP is unlikely to provide any benefit.

> But I heard FSDP can make two GPU cards act as a single one. My model cannot fit bs = 1 on one card, but it can definitely fit bs = 1 on two cards. Will that work with FSDP?

It looks like you are referring to global batch size 1. If you have 2 GPUs and are using data parallelism (whether DDP, FSDP, or DeepSpeed ZeRO), then you need at least global batch size 2 (where each GPU has local batch size 1). If you want both GPUs to work on a single batch element, then you need something like tensor parallelism.
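The batch-size arithmetic above can be written down as a small sketch. The helper name is hypothetical; the only assumption is that the GPUs form a grid of data-parallel replicas times tensor-parallel ranks, so the data-parallel degree is world_size / tp_degree.

```python
def min_global_batch_size(world_size, tp_degree=1, local_batch_size=1):
    """Smallest global batch size for a given layout (hypothetical helper).

    Each data-parallel replica consumes local_batch_size samples per step,
    and the tp_degree GPUs inside one replica share the same samples.
    """
    if world_size % tp_degree != 0:
        raise ValueError("world_size must be divisible by tp_degree")
    dp_degree = world_size // tp_degree
    return dp_degree * local_batch_size


# 2 GPUs, pure data parallelism (DDP, FSDP, or ZeRO): one sample per GPU.
print(min_global_batch_size(world_size=2))               # 2
# 2 GPUs, 2-way tensor parallelism: both GPUs work on the same sample.
print(min_global_batch_size(world_size=2, tp_degree=2))  # 1
```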

Yes. How do I enable tensor parallelism? It seems I need to split the model across 2 GPUs and have both compute on a single batch element.

It looks like neither DeepSpeed nor FSDP offers this as a default setting in the HF configs.

@lucasjinreal Can you take a look at https://pytorch.org/tutorials/intermediate/TP_tutorial.html?

Note that fully sharded data parallelism (FSDP) is not tensor parallelism, so an FSDP config is not expected to enable tensor parallelism.

Hi, is there any built-in implementation in transformers that enables TP with a single config? It looks like users need to configure every single layer to use TP?

I think you need to configure each layer to use TP, but you can probably make this into a loop:

# Shard/Replicate live in torch.distributed._tensor on older releases
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    PrepareModuleInput,
    SequenceParallel,
    parallelize_module,
)

# model, tp_mesh, col_parallel_strategy and row_parallel_strategy are defined
# elsewhere; the strategies are assumed to be ColwiseParallel / RowwiseParallel
# or compatible variants.

# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in enumerate(model.layers):
    layer_plan = {
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": col_parallel_strategy(),
        "attention.wk": col_parallel_strategy(),
        "attention.wv": col_parallel_strategy(),
        "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
        "attention_norm": SequenceParallel(),
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": col_parallel_strategy(),
        "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
        "feed_forward.w3": col_parallel_strategy(),
        "ffn_norm": SequenceParallel(),
    }
    # Adjust attention module to use the local number of heads
    attn_layer = transformer_block.attention
    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_plan,
    )

This seems more like a fundamental property of tensor parallelism, where how you shard is operator specific, so you need to do some finer-grained specification compared to say FSDP. Some higher-level API might be able to be built on top, but for now, we need something like the above.
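To see why the sharding is operator specific, here is a single-process sketch that simulates 2-way tensor parallelism for a two-layer MLP on CPU. The first weight is split column-wise and the second row-wise, so each simulated rank computes a partial output for the same batch element and a sum (standing in for the all-reduce) recovers the full result. The function name and shapes are illustrative assumptions, and no real communication happens here.

```python
import torch


def tp_mlp_forward(x, w1, w2, tp_degree):
    """Simulate column-parallel then row-parallel matmuls on one process."""
    w1_shards = w1.chunk(tp_degree, dim=1)  # column-wise: split output features
    w2_shards = w2.chunk(tp_degree, dim=0)  # row-wise: split input features
    # Each "rank" holds one shard of each weight and sees the full input x.
    # ReLU is elementwise, so it can be applied to each column shard locally.
    partials = [torch.relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]
    # Summing the partial outputs plays the role of the all-reduce.
    return torch.stack(partials).sum(dim=0)


torch.manual_seed(0)
x = torch.randn(1, 8, dtype=torch.float64)    # a single batch element
w1 = torch.randn(8, 16, dtype=torch.float64)
w2 = torch.randn(16, 8, dtype=torch.float64)

reference = torch.relu(x @ w1) @ w2            # unsharded computation
sharded = tp_mlp_forward(x, w1, w2, tp_degree=2)
print(torch.allclose(sharded, reference))
```

Swapping the split dimensions would not give the same result without extra communication, which is why each layer in a plan has to be told whether it is column-wise or row-wise.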

It looks like torchtitan is able to do tensor parallelism by default?

What do you mean by default?

torchtitan has already written the tensor parallel configuration for the Llama model, so you can enable TP from the .toml file.

Does it support Qwen models? Also, can multimodal models such as LLaVA be supported?

These models are not supported yet. Only Llama is supported for now.

I am going to mark this as closed since it seems all questions have been answered. If you have follow-ups, feel free to re-open or open a new issue.