mlfoundations / open_lm

A repository for research on medium sized language models.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Make torch.compile work with FSDP and xFormers

sagadre opened this issue · comments

Some potentially helpful links (the links themselves were not preserved in this copy):

So I think our issue might be the combination of bfloat16, xFormers' memory_efficient_attention, and torch.compile. The following script works if param_dtype is torch.float32, but not if it is torch.bfloat16.

import functools

import torch
from torch.distributed.fsdp import MixedPrecision, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from open_lm.model import Block
from open_lm.model import create_model
from open_lm.params import parse_args

# Repro: open_lm model + FSDP (MixedPrecision) + torch.compile.
# It initializes an NCCL process group and derives the CUDA device from the
# rank, so it is presumably meant to be launched with torchrun, one process
# per GPU.
#
# param_dtype is the variable under test.
# Changing this to torch.bfloat16 leads to errors.
#   File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 175, in meta_tensor
#       assert not torch._C._dispatch_tls_local_exclude_set().has(
#   AssertionError:
param_dtype = torch.float32

# The global rank is used directly as the CUDA device index. This presumably
# assumes a single-node launch (one process per local GPU); TODO confirm for
# multi-node.
torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

# Build the 160M open_lm model. parse_args("") presumably yields the default
# arguments (no CLI flags); only the model name is overridden.
args = parse_args("")
args.model = "open_lm_160m"
model = create_model(args)
model = model.to(device)

# Mixed-precision policy:
# - parameters use param_dtype (the toggle above);
# - gradient reduction stays in float32;
# - buffers are bfloat16.
mp_policy = MixedPrecision(
    param_dtype=param_dtype,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)
# Wrap each open_lm Block as its own FSDP unit.
transformer_auto_wrapper_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
model = FSDP(
    model,
    auto_wrap_policy=transformer_auto_wrapper_policy,
    device_id=device,
    mixed_precision=mp_policy,
    # Expose the original (unflattened) parameters; presumably required for
    # combining FSDP with torch.compile.
    use_orig_params=True,
    limit_all_gathers=True,
)
# torch.compile is applied to the already FSDP-wrapped model.
model = torch.compile(model)
# One sequence of 2048 random integers in [0, 10000), presumably token ids.
# NOTE(review): inputs are created on CPU (no device=). This presumably relies
# on FSDP moving forward inputs to device_id; confirm.
inputs = torch.randint(0, 10000, size=(1, 2048))
# The model returns a tuple; only the first element is used.
outputs, _ = model(inputs)
# Reduce to a scalar so the backward pass also runs.
outputs.mean().backward()

Here's an even simpler reproduction that doesn't require open_lm. Run it with: torchrun --nproc_per_node 2 script.py

import torch
import torch.nn as nn

from torch.distributed.fsdp import MixedPrecision, FullyShardedDataParallel as FSDP
from xformers.ops import memory_efficient_attention
import xformers.ops as xops


class Layer(nn.Module):
    """Minimal layer: causal xformers attention followed by a linear projection.

    Query, key and value are all the raw input ``x`` (there are no learned
    Q/K/V projections). This is the module whose compiled forward fails when
    FSDP's param_dtype is bfloat16.
    """

    def __init__(self, n_feat):
        super().__init__()
        # Output projection; maps n_feat -> n_feat features.
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, x):
        # x must be 3-D; presumably (batch, sequence length, features).
        B, N, C = x.shape
        # Self-attention with q = k = v = x and a lower-triangular (causal) bias.
        # Under torch.compile + FSDP bf16 this is the call that raises the
        # AssertionError shown in the traceback.
        x = memory_efficient_attention(x, x, x, attn_bias=xops.LowerTriangularMask())
        # Restore the (B, N, C) shape before the output projection.
        return self.linear_out(x.reshape([B, N, C]))

###
# Dtype under test. It is used below for params, gradient reduction and buffers.
# float32 runs cleanly. Switching to the commented-out bfloat16 line reproduces
# the AssertionError raised from torch._subclasses.meta_utils.
# dtype = torch.bfloat16
dtype = torch.float32
###

# Launched via torchrun; the global rank is used as the CUDA device index.
torch.distributed.init_process_group(backend="nccl")
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
torch.cuda.set_device(device)
# Problem size: BATCH_SIZE sequences of MAX_LEN positions with FEAT_SIZE features.
FEAT_SIZE = 128
MAX_LEN = 100
BATCH_SIZE = 8

# All-zeros input. It is created in torch's default floating dtype (float32
# unless changed), independent of `dtype`.
batch = torch.zeros(BATCH_SIZE, MAX_LEN, FEAT_SIZE).to(device)
mha = Layer(FEAT_SIZE).to(device)

# The same dtype is used for parameters, gradient reduction and buffers.
mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
# No auto_wrap_policy is passed, so the whole Layer is wrapped as a single FSDP unit.
mha_fsdp = FSDP(mha, use_orig_params=True, device_id=device, mixed_precision=mp_policy)

# torch.compile is applied to the FSDP wrapper. The trailing .to(device) looks
# redundant (mha was already moved to `device` above) and is presumably harmless.
compile_mha = torch.compile(mha_fsdp).to(device)
# With dtype = torch.bfloat16, this forward call raises the AssertionError.
output = compile_mha(batch)
output.mean().backward()

This script runs fine, but changing dtype at the top to torch.bfloat16 leads to the following error:

Traceback (most recent call last):
  File "/home/ubuntu/research/openlm/main-tri/tmp/mha.py", line 36, in <module>
    output = compile_mha(batch)
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/research/openlm/main-tri/tmp/mha.py", line 16, in forward
    x = memory_efficient_attention(x, x, x, attn_bias=xops.LowerTriangularMask())
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 223, in memory_efficient_attention
    return _memory_efficient_attention(
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 321, in _memory_efficient_attention
    return _memory_efficient_attention_forward(
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 341, in _memory_efficient_attention_forward
    out, *_ = op.apply(inp, needs_gradient=False)
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/xformers/ops/fmha/flash.py", line 396, in apply
    out, softmax_lse, rng_state = cls.OPERATOR(
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_ops.py", line 502, in __call__
[...]
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 502, in __call__
    r = self.meta_tensor(
  File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 175, in meta_tensor
    assert not torch._C._dispatch_tls_local_exclude_set().has(
AssertionError:

Set torch._dynamo.config.verbose=True for more information

FWIW, compilation does go through if we set torch._dynamo.config.suppress_errors = True, but the speedup on a 7B model is negligible (4350 -> 4370 t/s/g), which could just be noise. This might be because it switches to eager mode for the flash-attention call, but I'm not sure.

Filed a bug with xFormers: facebookresearch/xformers#920.

Closing this, as we've realized that PyTorch's native attention with torch.compile works very well and is as fast as or faster than xFormers attention for our use cases.