NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

Primary weights profiling question

afcruzs opened this issue

Hello, I've been using the torch profiler and modified one of your examples (see script below) to profile a run using fp8_model_init. I do see a decrease in memory, especially in the activations, but peak memory does not decrease by much; in fact, enabling fp8 without setting the primary weights to fp8 is worse than not enabling fp8 at all. Would you have an explanation for this behavior? The peaks seem to come from an "unknown" source the profiler doesn't capture (kernels?).

The data captured below is for a transformer with 4096 sequence length, 32 heads, 32 layers, 128 head size, and batch size of 1.

This is with transformer_engine version 1.0.0+66d91d5 on a single H100 machine, without FSDP (see code below).

No fp8 enabled

image

Fp8 without primary weights
image

Fp8 with primary weights
image

# eg to run:  torchrun --standalone --nnodes=1 --nproc-per-node=1 fsdp.py --primary-weights-fp8
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import argparse
from functools import partial

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
# (c) Meta Platforms, Inc. and affiliates. 
import logging
import socket
from datetime import datetime, timedelta

from torch.autograd.profiler import record_function
from torchvision import models

logging.basicConfig(
   format="%(levelname)s:%(asctime)s %(message)s",
   level=logging.INFO,
   datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

def trace_handler(prof: torch.profiler.profile):
   # Prefix for file names.
   host_name = socket.gethostname()
   timestamp = datetime.now().strftime(TIME_FORMAT_STR)
   file_prefix = f"{host_name}_{timestamp}"

   # Construct the trace file.
   prof.export_chrome_trace(f"{file_prefix}.json.gz")

   # Construct the memory timeline file.
   prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")

def lowercase(s):
    return str(s).lower()

def torch_dtype(d):
    typemap = {
        'fp32' : torch.float32,
        'float32' : torch.float32,
        'fp16' : torch.float16,
        'float16' : torch.float16,
        'bf16' : torch.bfloat16,
        'bfloat16' : torch.bfloat16
    }
    if lowercase(d) not in typemap.keys():
        raise TypeError
    return typemap[lowercase(d)]

te_layer_map = {
    'linear': te.Linear,
    'layernorm': te.LayerNorm,
    'rmsnorm': te.RMSNorm,
    'layernormlinear': te.LayerNormLinear,
    'layernormmlp': te.LayerNormMLP,
    'multiheadattention': te.MultiheadAttention,
    'transformerlayer': te.TransformerLayer
}
def te_layer(l):
    if lowercase(l) not in te_layer_map.keys():
        raise TypeError
    return te_layer_map[lowercase(l)]

def get_layer_args(args):
    hidden_size = args.num_heads * args.head_dim
    layer_args = (hidden_size, )
    layer_kwargs = {
        'params_dtype': args.dtype,
        'device': 'meta' if args.defer_init else 'cuda'
    }
    if args.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
        ffn_hidden_size = 3 * hidden_size if args.num_layers == 1 else hidden_size
        layer_args += (ffn_hidden_size, )
        layer_kwargs['bias'] = True
        if args.layer_type == te.LayerNormMLP:
            layer_kwargs['seq_length'] = args.seq_length
    elif args.layer_type == te.MultiheadAttention:
        layer_args += (args.num_heads, )
        layer_kwargs['fuse_qkv_params'] = True
    elif args.layer_type == te.TransformerLayer:
        layer_args += (3 * hidden_size, args.num_heads)
        layer_kwargs['fuse_qkv_params'] = True
        layer_kwargs['seq_length'] = args.seq_length
    return layer_args, layer_kwargs

def parse_fsdp_args():
    parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " +
                                    "torch.distributed.fsdp.FullyShardedDataParallel strategy.")
    parser.add_argument("-t", "--layer-type", type=te_layer, default=te.TransformerLayer,
                        choices=list(te_layer_map.values()),
                        help="TE module type used to construct the test model.")
    
    parser.add_argument("--primary-weights-fp8", action="store_true", default=False)

    parser.add_argument("--no-fp8", action="store_true", default=False,
                        help="Disables the te.fp8_autocast() context.")
    parser.add_argument('-i', "--num-iters", type=int, default=10,
                        help="Number of dummy 'training' iterations.")
    parser.add_argument('-b', "--batch-size", type=int, default=1,
                        help="Input batch size.")
    parser.add_argument('-s', "--seq-length", type=int, default=4096,
                        help="Input sequence length.")
    parser.add_argument('-n', "--num-heads", type=int, default=32,
                        help="Number of attention heads.")
    parser.add_argument('-d', "--head-dim", type=int, default=128,
                        help="Dimension of each attention head (number of KV channels).")
    parser.add_argument('-l', "--num-layers", type=int, default=32,
                        help="Number of modules chained together with nn.Sequential.")
    parser.add_argument("--seed", type=int, default=1234,
                        help="PyTorch RNG seed.")
    parser.add_argument("--defer-init", action="store_true",
                        help="Defer module parameter initialization until after FSDP sharding.")
    parser.add_argument('-v', "--verbose", action="store_true", default=False,
                        help="Print out information from all GPUs instead of only the root GPU-0.")
    parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16,
                        help="Data type for input tensor and Transformer Engine module parameters.")
    return parser.parse_args()

def train(args, torch_profiler):
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Initialize torch.distributed global process group
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    if local_rank == 0:
        print(f"[GPU-0] WORLD_SIZE = {world_size}\n\n", end='')
    torch.manual_seed(args.seed)

    # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
    layer_args, layer_kwargs = get_layer_args(args)
    if args.num_layers > 1:
        te_layer_list = []
        for i in range(args.num_layers):
            if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
                layer_kwargs['layer_number'] = i+1
            # Append the layer for every layer type, not only attention/transformer layers,
            # so nn.Sequential is never empty when num_layers > 1.
            te_layer_list.append(args.layer_type(*layer_args, **layer_kwargs))
        te_model = nn.Sequential(*te_layer_list)
    else:
        # Single layer model
        te_model = args.layer_type(*layer_args, **layer_kwargs)
    if local_rank == 0:
        print(f"[GPU-0] TransformerEngine Model:\n{te_model}\n", end='')

    # Print out allocated device memory before the model parameters are sharded by FSDP
    pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Pre-FSDP memory use = {pre_mem_use}MiB\n", end='')

    # Wrap the model with FSDP
    # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
    #       controls all communication.
    all_gpus = dist.new_group(backend='nccl')
    fsdp_wrap_policy = always_wrap_policy
    if args.layer_type == te.TransformerLayer:
        # NOTE: FSDP causes illegal memory access without this special policy for Transformers
        fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
                                   transformer_layer_cls={te.TransformerLayer})
    # te_model = FullyShardedDataParallel(te_model,
    #                                     process_group=all_gpus,
    #                                     use_orig_params=True,
    #                                     mixed_precision=MixedPrecision(
    #                                         param_dtype=args.dtype,
    #                                         reduce_dtype=torch.float32,
    #                                     ),
    #                                     sync_module_states=True,
    #                                     auto_wrap_policy=fsdp_wrap_policy)

    # Print out allocated device memory after the model parameters are sharded
    post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Post-FSDP memory use = {post_mem_use}MiB\n", end='')

    # Fp8 setup for TE
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
    optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

    # Start and time dummy "training" iterations
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for i in range(args.num_iters):
        # Generate a random input batch
        x = torch.rand(args.seq_length, args.batch_size,
                       args.num_heads*args.head_dim).to(dtype=args.dtype).cuda()
        # fp8_autocast needs to be given the FSDP process group for amax reductions
        with te.fp8_autocast(enabled=not args.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
            y = te_model(x)
            loss = y.sum()
        # calculate gradient and take training step outside the fp8_autocast context
        loss.backward()
        optim.step()
        torch_profiler.step()
        del x
        if local_rank == 0:
            print(f"[GPU-0] Iter. {i+1}\n", end='')
    end.record()
    torch.cuda.synchronize()

    # Print out "training" time and peak memory use stats
    train_time = start.elapsed_time(end)/1000.
    max_memory_alloc = int(torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") * 1e-6)
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Training Time: {train_time}s\n" +
              f"[GPU-{local_rank}] Avg. Iter. Time: {train_time /args.num_iters}s\n" +
              f"[GPU-{local_rank}] Peak memory use = {max_memory_alloc}MiB\n\n", end='')


if __name__ == "__main__":
    
    args = parse_fsdp_args()
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=0, warmup=2, active=args.num_iters, repeat=1),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
        on_trace_ready=trace_handler,
    ) as prof:    
        with te.fp8_model_init(enabled=args.primary_weights_fp8):
            train(args, prof)

  • For the non-FP8 case, the 40 GB steady-state memory usage is expected: 5B params with 8 bytes/param (2 bytes param, 2 bytes grad, 4 bytes optim state). It looks like the activations take ~20 GB. (See the rough arithmetic sketched after this list.)
  • For FP8 w/o primary weights, it seems like there is ~3 GB overhead in the activations. I suspect this is because we cache an FP8 copy of the weights and an FP8 transpose that is used in the backward pass. This increase is counteracted by the fact we can store some activations in FP8.
  • For FP8 w/ primary weights, we get a 5 GB reduction in param memory, partly offset by the 3 GB overhead from FP8.
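To put rough numbers on that, here is a back-of-the-envelope sketch (mine, not an official TE calculation): it counts only the large attention and MLP weight matrices of each TransformerLayer, ignores biases and norms, and keeps the Adam states in the parameter dtype as described in the first bullet.

hidden = 32 * 128                          # num_heads * head_dim = 4096
ffn = 3 * hidden                           # 12288, matches get_layer_args() above
attn_params = 4 * hidden * hidden          # QKV + output projection weights
mlp_params = 2 * hidden * ffn              # fc1 + fc2 weights
params = 32 * (attn_params + mlp_params)   # ~5.4e9 parameters over 32 layers

GiB = 1024 ** 3
# bf16 param (2 B) + bf16 grad (2 B) + Adam exp_avg/exp_avg_sq in bf16 (2 B + 2 B)
print(f"bf16 steady state: {params * 8 / GiB:.1f} GiB")           # ~40 GiB
# keeping primary weights in FP8 saves ~1 byte/param on the parameter copy
print(f"FP8 primary-weight saving: {params * 1 / GiB:.1f} GiB")   # ~5 GiB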

Thanks @timmoon10 - a follow-up question. I also captured torch's memory snapshot, and I see an interesting difference between fp8 with and without primary weights: there is a spike in the optimizer states w/o primary weights that I don't see with primary weights. Is this spike caused by the fp8 copy of the weights, or what else explains it?

W/o primary weights:
image

With primary weights
image

I just updated the original script with this at the bottom:


from pathlib import Path  # Path is used below but was not imported at the top of the script

def torch_save_memory_snapshot():
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime("%b_%d_%H_%M_%S")
    rank = 0
    folder = Path(".")
    # folder.mkdir(parents=True, exist_ok=True)
    file_prefix = str(folder / f"rank-{rank}_{host_name}_{timestamp}")
    try:
        torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
    except Exception as e:
        logger.error(f"Failed to capture memory snapshot {e}")

if __name__ == "__main__":
    
    args = parse_fsdp_args()
    # with torch.profiler.profile(
    #     enabled=False,
    #     activities=[
    #         torch.profiler.ProfilerActivity.CPU,
    #         torch.profiler.ProfilerActivity.CUDA,
    #     ],
    #     schedule=torch.profiler.schedule(wait=0, warmup=2, active=args.num_iters, repeat=1),
    #     # schedule=torch.profiler.schedule(wait=0, warmup=0, active=0, repeat=0),
    #     record_shapes=True,
    #     profile_memory=True,
    #     with_stack=True,
    #     with_flops=True,
    #     # on_trace_ready=trace_handler,
    #     on_trace_ready=lambda prof: None,
    # ) as prof:
    # Minimal no-op stand-in for the torch profiler so train() can still call prof.step()
    class Prof:
        def step(self, *args, **kwargs):
            pass
    prof = Prof()
    torch.cuda.memory._record_memory_history(
        max_entries=100_000
    )
    with te.fp8_model_init(enabled=args.primary_weights_fp8):
        train(args, prof)

    torch_save_memory_snapshot()
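(For anyone reproducing this: the .pickle written by torch.cuda.memory._dump_snapshot can be dropped onto the interactive viewer at https://pytorch.org/memory_viz to inspect the allocations over time.)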

I'm not 100% confident, but I suspect this is expected behavior within PyTorch Adam. It seems that without FP8 primary weights, PyTorch is choosing to use a multi-tensor implementation of the Adam step: https://github.com/pytorch/pytorch/blob/a8e7c98cb95ff97bb30a728c6b2a1ce6bff946eb/torch/optim/adam.py#L565
There was probably a deliberate decision to accept greater memory usage (proportional to model size) in exchange for fewer kernel launches (although maybe torch.compile or nvFuser is smart enough to avoid materializing the intermediates?). On the other hand, the FP8 primary weights are weird enough that PyTorch falls back to the unfused Adam implementation, which only needs to handle a single param at a time but requires more kernel launches.
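One quick way to test this hypothesis would be to pin the Adam implementation explicitly via the foreach flag on torch.optim.Adam (available in recent PyTorch releases) so that both runs take the same code path, e.g. in train() above:

    # Force the single-tensor Adam loop regardless of parameter type; foreach=True
    # would request the multi-tensor path instead (assuming the FP8 params support it).
    optim = torch.optim.Adam(te_model.parameters(), lr=0.0001, foreach=False)

If the spike in optimizer-state memory disappears with foreach=False and no FP8 primary weights, that would confirm the multi-tensor path is responsible.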

This is quite interesting, thanks @timmoon10!