mosaicml / llm-foundry

LLM training code for Databricks foundation models

Home Page: https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm


Optimizer's `all_reduce` operation for `OptimizerMonitor` returns wrong results

m1kol opened this issue

Hello there! We've found an issue with the values that OptimizerMonitor logs to TensorBoard when training with FSDP and the SHARD_GRAD_OP strategy: the logged values are the result of an incorrect all_reduce operation.

We first noticed that after resuming training (relaunching it with autoresume=True), some logged norms (gradient and parameter norms) showed a "jump", i.e. the value after resuming differed from the value logged for that step before the restart. The same happened when continuing training from a specified checkpoint without autoresume.

After digging into the OptimizerMonitor and DecoupledAdamW code, we found that the all_reduce operation does not produce the expected results. The norm values for a specific attention weight matrix (named in the Additional context section) are present only on ranks 0 and 1, but after the optimizer's pre_reduce_metrics and dist_reduce_metrics, the values on the other ranks differ from each other and are not correct. Moreover, the reduced value that should hold the result of the all_reduce operation can take all kinds of values: a very small value on ranks where it was previously zero, a value barely changed from the original non-zero value (on ranks 0 and 1), or a seemingly random value such as 5 or 45. At the same time, a test variable initialized in the same code block with the rank number as its value is reduced correctly.
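For reference, here is a minimal pure-Python sketch of what the reduction is expected to compute. It assumes, as the debug code below suggests, that each rank contributes the squared L2 norm of its local shard, the squares are summed with all_reduce(SUM), and the square root is taken afterwards. The helper names and the shard values are hypothetical, and `simulate_all_reduce_sum` only stands in for the real collective.

```python
import math


def local_sq_norm(shard):
    """Squared L2 norm of the part of a parameter held by one rank (0 if it holds none)."""
    return sum(x * x for x in shard)


def simulate_all_reduce_sum(per_rank_values):
    """Stand-in for all_reduce(SUM): every rank ends up holding the same total."""
    total = sum(per_rank_values)
    return [total for _ in per_rank_values]


# Hypothetical weight split across 8 ranks: only ranks 0 and 1 hold elements,
# so the other six ranks contribute 0 for this metric.
full_weight = [3.0, 4.0, 12.0, 0.0]
shards = [full_weight[:2], full_weight[2:]] + [[] for _ in range(6)]

reduced = simulate_all_reduce_sum([local_sq_norm(s) for s in shards])
global_norms = [math.sqrt(v) for v in reduced]

# Every rank should report the same value, equal to the norm of the full tensor.
print(global_norms[0])  # 13.0, i.e. sqrt(9 + 16 + 144)

# The sanity test described above: reducing each rank's own rank number.
print(simulate_all_reduce_sum(list(range(8)))[0])  # 28 = 0 + 1 + ... + 7
```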

The all_gather operation also returns wrong results, and they are correlated with the all_reduce results. See the example output in the Additional context section.

In short, OptimizerMonitor does log the value it is supposed to (the value on rank 0), but rank 0 holds the wrong value. With regular DDP (as opposed to its FSDP equivalent, NO_SHARD), the results and behavior are correct.

This behavior occurs when pretraining both the MPT 1B and 7B models with the provided configs. We tested mostly on 8 GPUs, but also on 4. With 4 GPUs, the logged value (the value on rank 0) is correct most of the time, since the inspected weight does not appear to be split among GPUs and is present only on rank 0.

Could you please look into this?

The options changed relative to the provided configs are:

  • tokenizer -- using our own,
  • dataset -- for 7B model that is the custom mix of streams with proportions; for 1B it is the subset of C4, provided by the scripts/train README,
  • FSDP config -- using SHARD_GRAD_OP instead of FULL_SHARD and using mixed precision (bf16 for forward, fp32 for others),
  • saving and loading folder.

Environment

Output of collect_env.py script:

PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.2
Libc version: glibc-2.31
Python version: 3.9.16 | packaged by conda-forge | (main, Feb  1 2023, 21:39:03)  [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1027-nvidia-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.54.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   48 bits physical, 48 bits virtual
CPU(s):                          256
On-line CPU(s) list:             0-255
Thread(s) per core:              2
Core(s) per socket:              64
Socket(s):                       2
NUMA node(s):                    4
Vendor ID:                       AuthenticAMD
CPU family:                      25
Model:                           1
Model name:                      AMD EPYC 7763 64-Core Processor
Stepping:                        1
Frequency boost:                 enabled
CPU MHz:                         2450.000
CPU max MHz:                     3529.0520
CPU min MHz:                     1500.0000
BogoMIPS:                        4900.00
Virtualization:                  AMD-V
L1d cache:                       4 MiB
L1i cache:                       4 MiB
L2 cache:                        64 MiB
L3 cache:                        512 MiB
NUMA node0 CPU(s):               0-31,128-159
NUMA node1 CPU(s):               32-63,160-191
NUMA node2 CPU(s):               64-95,192-223
NUMA node3 CPU(s):               96-127,224-255
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni
pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx
 cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero
irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] pytorch-ranger==0.1.1
[pip3] pytorch-triton==2.0.0+0d7e753227
[pip3] torch==2.0.1
[pip3] torch-optimizer==0.3.0
[pip3] torchaudio==2.0.0.dev20230205+cu117
[pip3] torchdata==0.6.1
[pip3] torchmetrics==0.11.4
[pip3] torchtext==0.15.2
[pip3] torchvision==0.15.0.dev20230205+cu117
[pip3] triton==2.0.0
[pip3] triton-pre-mlir==2.0.0
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] pytorch-ranger            0.1.1                    pypi_0    pypi
[conda] pytorch-triton            2.0.0+0d7e753227          pypi_0    pypi
[conda] torch                     2.0.1                    pypi_0    pypi
[conda] torch-optimizer           0.3.0                    pypi_0    pypi
[conda] torchaudio                2.0.0.dev20230205+cu117          pypi_0    pypi
[conda] torchdata                 0.6.1                    pypi_0    pypi
[conda] torchmetrics              0.11.4                   pypi_0    pypi
[conda] torchtext                 0.15.2                   pypi_0    pypi
[conda] torchvision               0.15.0.dev20230205+cu117          pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi
[conda] triton-pre-mlir           2.0.0                    pypi_0    pypi

To reproduce

Steps to reproduce the behavior:

  1. Clone composer (dev branch, the default one) from the GitHub repo and install it in editable mode (pip install -e) so that it can be modified.
  2. Use the MPT 1B config with the C4 subset from the README, the SHARD_GRAD_OP FSDP sharding strategy, and mixed precision with bf16 for parameters (forward) and fp32 for buffers and gradient reduction.
  3. In composer's optimizer, print the L2 norm values for the block 12 attention weight matrix (layer name under Additional context) before and after the all_reduce operation that produces the logged L2 norms.
  4. Run training with scripts/train/train.py.

Expected behavior

The L2 norms of parameters (and gradients) should be identical across all ranks and match the norm computed from the state_dict.
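Stated as a formula, assuming the parameter W is split over R ranks into disjoint shards W^(r) (a rank holding no part of W contributes an empty shard, and any zero padding does not change the sum), the value every rank should end up with is:

```latex
\|W\|_2
  = \sqrt{\sum_{r=0}^{R-1} \bigl\|W^{(r)}\bigr\|_2^2}
  = \sqrt{\sum_{r=0}^{R-1} \sum_{i \in \mathcal{I}_r} w_i^2}
  = \sqrt{\sum_{i} w_i^2},
```

where I_r is the set of element indices held on rank r. Because the I_r partition all elements of W, the last expression is exactly the norm of the full tensor as read from the state_dict.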

Additional context

Layer name: model._fsdp_wrapped_module.transformer.blocks.12._fsdp_wrapped_module.attn.Wqkv.weight.

Test config used for the 1B MPT model:

data_local: /path/to/downloaded/c4/subset
data_remote: # If blank, files must be present in data_local
max_seq_len: 2048
global_seed: 20

# Run Name
# If left blank, will be read from env var $RUN_NAME
run_name: debug_run

# Model
model:
  name: mpt_causal_lm
  init_device: meta
  d_model: 2048
  n_heads: 16 # Modified 24->16 so that d_head == 128 to satisfy FlashAttention
  n_layers: 24
  expansion_ratio: 4
  max_seq_len: ${max_seq_len}
  vocab_size: 131072  # new vocab size
  attn_config:
    attn_impl: triton
    alibi: true
    alibi_bias_max: 8

# Tokenizer
tokenizer:
  name: /path/to/tokenizer/dir
  kwargs:
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: train_small
    shuffle: true
    max_seq_len: ${max_seq_len}
    shuffle_seed: ${global_seed}
  drop_last: true
  num_workers: 8
  cache_limit: "512gb"

eval_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split: val_small
    shuffle: false
    max_seq_len: ${max_seq_len}
    shuffle_seed: ${global_seed}
  drop_last: false
  num_workers: 8
  cache_limit: "512gb"

# Optimization
scheduler:
  name: cosine_with_warmup
  t_warmup: 100ba
  alpha_f: 0.1

optimizer:
  name: decoupled_adamw
  lr: 2.0e-4
  betas:
  - 0.9
  - 0.95
  eps: 1.0e-08
  weight_decay: 0.0

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0

max_duration: 1000ba
eval_interval: 200ba
eval_first: false
eval_subset_num_batches: -1
global_train_batch_size: 512

# System
seed: ${global_seed}
device_eval_batch_size: 4
device_train_microbatch_size: 4
# device_train_microbatch_size: auto
precision: amp_bf16

# FSDP
fsdp_config:
  sharding_strategy: SHARD_GRAD_OP  # ZeRO-2 equivalent; FULL_SHARD (default) is ZeRO-3
  mixed_precision:
    param_dtype: bf16
    reduce_dtype: fp32
    buffer_dtype: fp32
  activation_checkpointing: false
  activation_checkpointing_reentrant: false
  activation_cpu_offload: false
  limit_all_gathers: true
  verbose: false

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 50ba

callbacks:
  speed_monitor:
    window_size: 50
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}
  optimizer_monitor:
    batch_log_interval: 10

loggers:
  tensorboard:
    log_dir: /path/to/tensorboard/dir
    rank_zero_only: true

# Checkpoint to local filesystem or remote object store
# save_num_checkpoints_to_keep: 1  # Important, this cleans up checkpoints saved to DISK
save_interval: 200ba
save_folder: /path/to/save/dir

# Load from local filesystem or remote object store
autoresume: false
# load_path: /path/to/checkpoint/if/exists

Modified dist_reduce_metrics code for DecoupledAdamW, with debug prints:

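        # Excerpt from DecoupledAdamW.dist_reduce_metrics with debug prints added. Assumes the
        # module-level imports `math`, `torch` and `from composer.utils import dist`, and that
        # `rank` and `opt_metrics_name` (the metric name to inspect) are defined earlier.
        # sorted() is stable: within each group, keys keep the incoming order of all_keys, so if
        # that order differs between ranks, the all_reduce calls below pair up different metrics.
        all_keys = sorted(all_keys, key=lambda metric: 0 if 'l2_norm' in metric else 1)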
        for metric in all_keys:
            if metric.startswith('l2_norm'):
                dist.barrier()

                tensors_gather_list = [
                    torch.tensor(0.0, dtype=torch.float32, device=torch.cuda.current_device())
                    for _ in range(dist.get_world_size())
                ]
                # tensors_gather_list = None

                test_reduce = torch.tensor(rank, dtype=torch.float32, device=torch.cuda.current_device())

                if metric in optimizer_metrics:
                    reduced_l2 = torch.tensor(optimizer_metrics[metric].item(), dtype=torch.float32, device=torch.cuda.current_device())
                else:
                    reduced_l2 = torch.tensor(0.0, dtype=torch.float32, device=torch.cuda.current_device())

                # reduced = optimizer_metrics.get(metric, torch.tensor(0.0, device=torch.cuda.current_device()))

                # ------------------- DEBUG PRINTS -------------------
                if metric == opt_metrics_name:
                    print(f'rank {rank} | dist_reduce_metrics initial reduced value: {reduced_l2}')
                    print(f'rank {rank} | dist_reduce_metrics initial reduced value device: {reduced_l2.device}')

                    print(f'rank {rank} | test_reduce value: {test_reduce}')
                    print(f'rank {rank} | test_reduce value device: {test_reduce.device}')

                    print('--')
                # ------------------- DEBUG PRINTS -------------------

                if dist.get_world_size() > 1:
                    dist.barrier()

                    # tensors_gather_list = dist.all_gather(test_reduce)
                    torch.distributed.all_gather(tensors_gather_list, reduced_l2)

                    # torch.distributed.reduce(reduced_l2, dst=0)
                    dist.all_reduce(reduced_l2, reduce_operation='SUM')


                    dist.all_reduce(test_reduce, reduce_operation='SUM')

                dist.barrier()

                # ------------------- DEBUG PRINTS -------------------
                if metric == opt_metrics_name:
                    print(f'rank {rank} | tensors_gather_list: {tensors_gather_list}')

                    print(f'rank {rank} | dist_reduce_metrics reduced value: {reduced_l2}')
                    print(f'rank {rank} | dist_reduce_metrics reduced value sq root: {math.sqrt(reduced_l2)}')

                    print(f'rank {rank} | test_reduce value reduced: {test_reduce}')
                    print(f'rank {rank} | test_reduce value reduced device: {test_reduce.device}')

                    print('--------------------------------------------------------')
                # ------------------- DEBUG PRINTS -------------------

                dist.barrier()

                optimizer_metrics[metric] = math.sqrt(reduced_l2)
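A cheap guard that would surface this kind of mismatch early: before the reduction loop, gather the key order every rank is about to iterate and check that all ranks agree. This is a sketch; `check_same_key_order` is a hypothetical helper, and the gather function is passed in so the check can run without a process group (inside composer one would presumably pass `dist.all_gather_object`, which returns one object per rank). Note that it should be applied to the merged all_keys list, not to the local optimizer_metrics keys, because the local key sets legitimately differ when a layer lives only on some ranks.

```python
from typing import Callable, List, Sequence


def check_same_key_order(
    local_keys: Sequence[str],
    all_gather_object: Callable[[List[str]], List[List[str]]],
) -> List[str]:
    """Raise if ranks would iterate over the metric keys in different orders.

    Returns the agreed-upon key order when every rank matches rank 0.
    """
    gathered = all_gather_object(list(local_keys))
    reference = gathered[0]
    for rank, keys in enumerate(gathered):
        if keys != reference:
            raise RuntimeError(
                f"metric key order on rank {rank} differs from rank 0: "
                f"{keys} != {reference}"
            )
    return reference


# Fake two-rank gathers for illustration (hypothetical key orders).
def fake_gather_consistent(obj):
    return [obj, list(obj)]


def fake_gather_mismatch(obj):
    return [obj, list(reversed(obj))]


keys = ["l2_norm/grad/a", "l2_norm/grad/b"]
print(check_same_key_order(keys, fake_gather_consistent))
try:
    check_same_key_order(keys, fake_gather_mismatch)
except RuntimeError as err:
    print("caught:", err)
```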

Example output: [screenshots "example_output" and "example_output_v2"]

After looking into the issue further, I also posted it to the composer repo, which is mentioned above.


Found that the issue was caused by inconsistent ordering of the optimizer_metrics keys across ranks, so the all_reduce operation could combine different quantities (parameter, gradient, or moment norms) belonging to different layers.
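A pure-Python simulation of that failure mode, assuming each rank walks its key list in its own order and issues one all_reduce(SUM) per key, so the k-th collective on one rank is paired with the k-th collective on every other rank regardless of which metric it carries. Metric names and values are made up. Giving every rank the same deterministic order (here, plain sorted() on the merged keys) removes the mismatch. One plausible source of differing orders, not confirmed in this issue, is iterating a Python set of strings: with string hash randomization, set iteration order can differ between processes.

```python
from typing import Dict, List


def reduce_in_order(
    per_rank_metrics: List[Dict[str, float]],
    per_rank_order: List[List[str]],
) -> List[Dict[str, float]]:
    """Simulate one all_reduce(SUM) per key, matched by call position, not by name."""
    num_calls = len(per_rank_order[0])
    results: List[Dict[str, float]] = [{} for _ in per_rank_metrics]
    for k in range(num_calls):
        # The k-th collective sums whatever each rank passes as its k-th value;
        # a rank missing the metric contributes 0.0, as in dist_reduce_metrics.
        total = sum(
            metrics.get(order[k], 0.0)
            for metrics, order in zip(per_rank_metrics, per_rank_order)
        )
        for rank, order in enumerate(per_rank_order):
            results[rank][order[k]] = total
    return results


# Squared local norms (hypothetical): layer A lives only on rank 0, layer B only on rank 1.
metrics = [{"l2_norm/A": 4.0}, {"l2_norm/B": 9.0}]
all_keys = {"l2_norm/A", "l2_norm/B"}

# Inconsistent order: rank 1 walks the keys in the opposite order to rank 0.
bad = reduce_in_order(metrics, [["l2_norm/A", "l2_norm/B"], ["l2_norm/B", "l2_norm/A"]])
print(bad[0])  # {'l2_norm/A': 13.0, 'l2_norm/B': 0.0}

# Same deterministic order on every rank.
order = sorted(all_keys)
good = reduce_in_order(metrics, [order, order])
print(good[0])  # {'l2_norm/A': 4.0, 'l2_norm/B': 9.0}
```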

Closing the issue here, as it concerns composer itself.