Primary weights profiling question
afcruzs opened this issue · comments
Hello, I've been using the torch profiler and modified one of your examples (see script below) to profile using fp8_model_init. I do see a decrease in memory, especially in the activations, but when looking at peak memory, it does not decrease that much; in fact, enabling fp8 without setting the primary weights to fp8 is worse than not enabling fp8 at all. Would you have an explanation for this behavior? The peaks seem to come from an "unknown" source the profiler doesn't capture (kernels?).
The data captured below is for a transformer with 4096 sequence length, 32 heads, 32 layers, 128 head size, and batch size of 1.
This is with transformer_engine version 1.0.0+66d91d5 and a single H100 machine, without FSDP (see code below).
No fp8 enabled
# eg to run: torchrun --standalone --nnodes=1 --nproc-per-node=1 fsdp.py --primary-weights-fp8
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import argparse
from functools import partial
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
# (c) Meta Platforms, Inc. and affiliates.
import logging
import socket
from datetime import datetime, timedelta
import torch
from torch.autograd.profiler import record_function
from torchvision import models
# Configure root logging once at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(levelname)s:%(asctime)s %(message)s",
)

# Module-level logger for this script.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# strftime pattern used to timestamp exported profiler files.
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"
def trace_handler(prof: torch.profiler.profile):
    """Export the profiler results to files named after the host and current time.

    Two files are written using relative paths:
    ``<host>_<timestamp>.json.gz`` (Chrome trace) and
    ``<host>_<timestamp>.html`` (memory timeline for device ``cuda:0``).

    Args:
        prof: the profiler instance whose results are exported.
    """
    # Both outputs share the same "<hostname>_<timestamp>" stem.
    stamp = datetime.now().strftime(TIME_FORMAT_STR)
    stem = "_".join((socket.gethostname(), stamp))
    chrome_trace_path = f"{stem}.json.gz"
    memory_timeline_path = f"{stem}.html"

    prof.export_chrome_trace(chrome_trace_path)
    prof.export_memory_timeline(memory_timeline_path, device="cuda:0")
def lowercase(s):
    """Return the string form of ``s`` converted to lower case."""
    text = str(s)
    return text.lower()
def torch_dtype(d):
    """Translate a precision name into the matching ``torch.dtype``.

    The lookup is case-insensitive and accepts both short and long spellings
    (``fp32``/``float32``, ``fp16``/``float16``, ``bf16``/``bfloat16``).

    Args:
        d: precision name; it is converted with ``lowercase`` before lookup.

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        TypeError: if ``d`` is not one of the supported names. The exception
            type is kept as ``TypeError`` because this function is used as an
            argparse ``type=`` converter.
    """
    typemap = {
        'fp32': torch.float32,
        'float32': torch.float32,
        'fp16': torch.float16,
        'float16': torch.float16,
        'bf16': torch.bfloat16,
        'bfloat16': torch.bfloat16,
    }
    # Normalize once and do a single lookup instead of a membership test
    # followed by a second normalization and indexing.
    key = lowercase(d)
    try:
        return typemap[key]
    except KeyError:
        supported = ", ".join(typemap)
        raise TypeError(
            f"Unsupported dtype {d!r}; expected one of: {supported}"
        ) from None
# Lookup table from lower-case layer name to the Transformer Engine module class.
te_layer_map = dict(
    linear=te.Linear,
    layernorm=te.LayerNorm,
    rmsnorm=te.RMSNorm,
    layernormlinear=te.LayerNormLinear,
    layernormmlp=te.LayerNormMLP,
    multiheadattention=te.MultiheadAttention,
    transformerlayer=te.TransformerLayer,
)
def te_layer(l):
    """Translate a layer name into the matching Transformer Engine module class.

    The lookup in ``te_layer_map`` is case-insensitive.

    Args:
        l: layer name; it is converted with ``lowercase`` before lookup.

    Returns:
        The class stored in ``te_layer_map`` under the normalized name.

    Raises:
        TypeError: if the name is not a key of ``te_layer_map``. The exception
            type is kept as ``TypeError`` because this function is used as an
            argparse ``type=`` converter.
    """
    # Normalize once and do a single lookup instead of a membership test
    # followed by a second normalization and indexing.
    key = lowercase(l)
    try:
        return te_layer_map[key]
    except KeyError:
        supported = ", ".join(te_layer_map)
        raise TypeError(
            f"Unsupported layer type {l!r}; expected one of: {supported}"
        ) from None
def get_layer_args(args):
    """Build the constructor arguments for the layer class in ``args.layer_type``.

    Args:
        args: namespace providing ``num_heads``, ``head_dim``, ``dtype``,
            ``defer_init``, ``layer_type``, ``num_layers`` and ``seq_length``.

    Returns:
        A ``(layer_args, layer_kwargs)`` pair: a tuple of positional arguments
        and a dict of keyword arguments.
    """
    layer_type = args.layer_type
    hidden_size = args.num_heads * args.head_dim

    # Every layer type takes the hidden size first, plus dtype and device.
    positional = [hidden_size]
    keyword = {
        'params_dtype': args.dtype,
        'device': 'meta' if args.defer_init else 'cuda',
    }

    if layer_type in (te.Linear, te.LayerNormLinear, te.LayerNormMLP):
        # A single-layer model gets a 3x wider second dimension; chained
        # layers keep it equal to the hidden size.
        if args.num_layers == 1:
            ffn_hidden_size = 3 * hidden_size
        else:
            ffn_hidden_size = hidden_size
        positional.append(ffn_hidden_size)
        keyword['bias'] = True
        if layer_type == te.LayerNormMLP:
            keyword['seq_length'] = args.seq_length
    elif layer_type == te.MultiheadAttention:
        positional.append(args.num_heads)
        keyword['fuse_qkv_params'] = True
    elif layer_type == te.TransformerLayer:
        positional.extend((3 * hidden_size, args.num_heads))
        keyword['fuse_qkv_params'] = True
        keyword['seq_length'] = args.seq_length

    return tuple(positional), keyword
def parse_fsdp_args():
    """Define the script's command-line options and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run Transformer Engine modules with the "
            "torch.distributed.fsdp.FullyShardedDataParallel strategy."
        ),
    )

    # Model type and FP8 switches.
    parser.add_argument(
        "-t", "--layer-type",
        type=te_layer,
        default=te.TransformerLayer,
        choices=list(te_layer_map.values()),
        help="TE module type used to construct the test model.",
    )
    parser.add_argument("--primary-weights-fp8", action="store_true", default=False)
    parser.add_argument(
        "--no-fp8",
        action="store_true",
        default=False,
        help="Disables the te.fp8_autocast() context.",
    )

    # Integer-valued options, registered in their original order:
    # (flags, default, help text).
    integer_options = (
        (('-i', "--num-iters"), 10, "Number of dummy 'training' iterations."),
        (('-b', "--batch-size"), 1, "Input batch size."),
        (('-s', "--seq-length"), 4096, "Input sequence length."),
        (('-n', "--num-heads"), 32, "Number of attention heads."),
        (('-d', "--head-dim"), 128,
         "Dimension of each attention head (number of KV channels)."),
        (('-l', "--num-layers"), 32,
         "Number of modules chained together with nn.Sequential."),
        (("--seed",), 1234, "PyTorch RNG seed."),
    )
    for flags, default_value, help_text in integer_options:
        parser.add_argument(*flags, type=int, default=default_value, help=help_text)

    # Remaining flags.
    parser.add_argument(
        "--defer-init",
        action="store_true",
        help="Defer module parameter initialization until after FSDP sharding.",
    )
    parser.add_argument(
        '-v', "--verbose",
        action="store_true",
        default=False,
        help="Print out information from all GPUs instead of only the root GPU-0.",
    )
    parser.add_argument(
        "--dtype",
        type=torch_dtype,
        default=torch.bfloat16,
        help="Data type for input tensor and Transformer Engine module parameters.",
    )

    return parser.parse_args()
def train(args, torch_profiler):
    """Run dummy training iterations of a Transformer Engine model and print stats.

    The function initializes the NCCL process group, builds the model from
    ``args.layer_type`` (chained with ``nn.Sequential`` when
    ``args.num_layers > 1``), runs ``args.num_iters`` forward/backward/Adam
    steps on random input, and prints timing and memory figures. The FSDP
    wrapping is currently commented out, so the model runs unsharded.

    Args:
        args: parsed command-line namespace; reads ``seed``, ``num_layers``,
            ``layer_type``, ``verbose``, ``dtype``, ``num_iters``,
            ``seq_length``, ``batch_size``, ``num_heads``, ``head_dim`` and
            ``no_fp8`` (plus whatever ``get_layer_args`` reads).
        torch_profiler: object whose ``step()`` method is called once per
            training iteration.

    Raises:
        KeyError: if ``LOCAL_RANK`` or ``WORLD_SIZE`` is missing from the
            environment.
    """
    # torch.cuda.memory_allocated()/max_memory_allocated() return bytes; the
    # messages below are labelled "MiB", so convert with 2**20 rather than 1e-6.
    bytes_per_mib = 2 ** 20

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = f"cuda:{local_rank}"

    # Initialize torch.distributed global process group
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    if local_rank == 0:
        print(f"[GPU-0] WORLD_SIZE = {world_size}\n\n", end='')
    torch.manual_seed(args.seed)

    # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
    layer_args, layer_kwargs = get_layer_args(args)
    if args.num_layers > 1:
        te_layer_list = []
        for i in range(args.num_layers):
            if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
                layer_kwargs['layer_number'] = i+1
            te_layer_list.append(args.layer_type(*layer_args, **layer_kwargs))
        te_model = nn.Sequential(*te_layer_list)
    else:
        # Single layer model
        te_model = args.layer_type(*layer_args, **layer_kwargs)
    if local_rank == 0:
        print(f"[GPU-0] TransformerEngine Model:\n{te_model}\n", end='')

    # Print out allocated device memory before the model parameters are sharded by FSDP
    pre_mem_use = torch.cuda.memory_allocated(device=device) / bytes_per_mib
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Pre-FSDP memory use = {pre_mem_use}MiB\n", end='')

    # Wrap the model with FSDP
    # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
    #       controls all communication.
    all_gpus = dist.new_group(backend='nccl')
    # NOTE: fsdp_wrap_policy is only consumed by the FSDP wrapping below, which is
    #       currently commented out.
    fsdp_wrap_policy = always_wrap_policy
    if args.layer_type == te.TransformerLayer:
        # NOTE: FSDP causes illegal memory access without this special policy for Transformers
        fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
                                   transformer_layer_cls={te.TransformerLayer})
    # te_model = FullyShardedDataParallel(te_model,
    #                                     process_group=all_gpus,
    #                                     use_orig_params=True,
    #                                     mixed_precision=MixedPrecision(
    #                                         param_dtype=args.dtype,
    #                                         reduce_dtype=torch.float32,
    #                                     ),
    #                                     sync_module_states=True,
    #                                     auto_wrap_policy=fsdp_wrap_policy)

    # Print out allocated device memory after the model parameters are sharded
    post_mem_use = torch.cuda.memory_allocated(device=device) / bytes_per_mib
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Post-FSDP memory use = {post_mem_use}MiB\n", end='')

    # Fp8 setup for TE
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
    optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)

    # Start and time dummy "training" iterations
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for i in range(args.num_iters):
        # Generate a random input batch
        x = torch.rand(args.seq_length, args.batch_size,
                       args.num_heads*args.head_dim).to(dtype=args.dtype).cuda()
        # fp8_autocast needs to be given the FSDP process group for amax reductions
        with te.fp8_autocast(enabled=not args.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
            y = te_model(x)
            loss = y.sum()
        # calculate gradient and take training step outside the fp8_autocast context
        loss.backward()
        optim.step()
        # Clear gradients so the next backward pass does not accumulate into them.
        # set_to_none=False zeroes in place and keeps the gradient buffers
        # allocated, so the steady-state memory footprint stays the same.
        optim.zero_grad(set_to_none=False)
        torch_profiler.step()
        del x
        if local_rank == 0:
            print(f"[GPU-0] Iter. {i+1}\n", end='')
    end.record()
    torch.cuda.synchronize()

    # Print out "training" time and peak memory use stats
    train_time = start.elapsed_time(end)/1000.
    max_memory_alloc = int(torch.cuda.max_memory_allocated(device=device) / bytes_per_mib)
    if local_rank == 0 or args.verbose:
        print(f"[GPU-{local_rank}] Training Time: {train_time}s\n" +
              f"[GPU-{local_rank}] Avg. Iter. Time: {train_time /args.num_iters}s\n" +
              f"[GPU-{local_rank}] Peak memory use = {max_memory_alloc}MiB\n\n", end='')

    # Release the process group(s) created above now that training is finished.
    dist.destroy_process_group()
if __name__ == "__main__":
    args = parse_fsdp_args()

    # Record both host-side and device-side activity.
    profiled_activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    # No wait steps, two warm-up steps, then num_iters active steps, one cycle.
    profiler_schedule = torch.profiler.schedule(
        wait=0,
        warmup=2,
        active=args.num_iters,
        repeat=1,
    )

    # Train under the profiler; fp8_model_init is enabled only when
    # --primary-weights-fp8 was passed. trace_handler exports the results.
    with torch.profiler.profile(
        activities=profiled_activities,
        schedule=profiler_schedule,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
        on_trace_ready=trace_handler,
    ) as prof, te.fp8_model_init(enabled=args.primary_weights_fp8):
        train(args, prof)
- For the non-FP8 case, the 40 GB steady-state memory usage is expected: 5B params with 8 bytes/param (2 bytes param, 2 bytes grad, 4 bytes optim state). It looks like the activations take ~20 GB.
- For FP8 w/o primary weights, it seems like there is ~3 GB overhead in the activations. I suspect this is because we cache an FP8 copy of the weights and an FP8 transpose that is used in the backward pass. This increase is counteracted by the fact we can store some activations in FP8.
- For FP8 w/ primary weights, we have a 5 GB reduction in param memory, but offset with the 3 GB overhead from FP8.
Thanks @timmoon10 - A follow-up question. I included torch's memory snapshot, and I see this interesting difference between fp8 with and w/o primary weights. I see a peak in the optimizer states w/o primary weights which I don't see with primary weights. Is this spike caused by the fp8 copy of the weights, or what explains it?
I just updated the original script with this at the bottom:
def torch_save_memory_snapshot():
    """Dump the CUDA memory snapshot to a timestamped pickle file.

    The file is written with a relative path as
    ``rank-0_<host>_<timestamp>.pickle``. This is best-effort: any exception
    raised while dumping is logged through ``logger`` and not re-raised.
    """
    # Imported locally because `Path` is used below but is not among the
    # imports visible at the top of this script.
    from pathlib import Path

    host_name = socket.gethostname()
    timestamp = datetime.now().strftime("%b_%d_%H_%M_%S")
    rank = 0  # hard-coded; the file name always carries rank 0
    folder = Path(".")
    # folder.mkdir(parents=True, exist_ok=True)
    file_prefix = str(folder / f"rank-{rank}_{host_name}_{timestamp}")
    try:
        torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
    except Exception as e:
        # Deliberately swallow the failure; the message is formatted lazily
        # by the logging module.
        logger.error("Failed to capture memory snapshot %s", e)
if __name__ == "__main__":
    args = parse_fsdp_args()

    # The torch profiler is disabled for this run; the original setup is kept
    # here for reference.
    # with torch.profiler.profile(
    #     enabled=False,
    #     activities=[
    #         torch.profiler.ProfilerActivity.CPU,
    #         torch.profiler.ProfilerActivity.CUDA,
    #     ],
    #     schedule=torch.profiler.schedule(wait=0, warmup=2, active=args.num_iters, repeat=1),
    #     # schedule=torch.profiler.schedule(wait=0, warmup=0, active=0, repeat=0),
    #     record_shapes=True,
    #     profile_memory=True,
    #     with_stack=True,
    #     with_flops=True,
    #     # on_trace_ready=trace_handler,
    #     on_trace_ready=lambda prof: None,
    # ) as prof:

    class Prof:
        """Stand-in for the profiler: ``step()`` accepts anything and does nothing."""

        def step(self, *args, **kwargs):
            return None

    prof = Prof()

    # Start recording CUDA memory history, capped at 100,000 entries.
    torch.cuda.memory._record_memory_history(max_entries=100_000)

    with te.fp8_model_init(enabled=args.primary_weights_fp8):
        train(args, prof)

    # Write the recorded memory history to disk once training has finished.
    torch_save_memory_snapshot()
I'm not 100% confident, but I suspect this is expected behavior within PyTorch Adam. It seems that without FP8 primary weights, PyTorch is choosing to use a multi-tensor implementation of the Adam step: https://github.com/pytorch/pytorch/blob/a8e7c98cb95ff97bb30a728c6b2a1ce6bff946eb/torch/optim/adam.py#L565
There was probably a deliberate decision to accept greater memory usage (proportional to model size) in exchange for fewer kernel launches (although maybe torch.compile or nvFuser is smart enough to avoid materializing the intermediates?). On the other hand, the FP8 primary weights are weird enough that PyTorch falls back to the unfused Adam implementation, which only needs to handle a single param at a time but requires more kernel launches.
This is quite interesting, thanks @timmoon10!