Make torch.compile work with fsdp and xformers
sagadre opened this issue
some potentially helpful links:
So I think our issue might be a combination of bfloat16, memory_efficient_attention, and torch.compile. The following script works if param_dtype is float32, but not if it's bfloat16.
import functools
import torch
from torch.distributed.fsdp import MixedPrecision, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from open_lm.model import Block
from open_lm.model import create_model
from open_lm.params import parse_args
# Changing this to torch.bfloat16 leads to errors.
# File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 175, in meta_tensor
# assert not torch._C._dispatch_tls_local_exclude_set().has(
# AssertionError:
param_dtype = torch.float32
torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
args = parse_args("")
args.model = "open_lm_160m"
model = create_model(args)
model = model.to(device)
mp_policy = MixedPrecision(
    param_dtype=param_dtype,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)
transformer_auto_wrapper_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
model = FSDP(
    model,
    auto_wrap_policy=transformer_auto_wrapper_policy,
    device_id=device,
    mixed_precision=mp_policy,
    use_orig_params=True,
    limit_all_gathers=True,
)
model = torch.compile(model)
inputs = torch.randint(0, 10000, size=(1, 2048), device=device)
outputs, _ = model(inputs)
outputs.mean().backward()
Maybe pytorch/pytorch#112164 is related
Here's an even simpler reproduction that doesn't require open_lm. Run with torchrun --nproc_per_node 2 script.py.
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecision, FullyShardedDataParallel as FSDP
from xformers.ops import memory_efficient_attention
import xformers.ops as xops
class Layer(nn.Module):
    def __init__(self, n_feat):
        super().__init__()
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, x):
        B, N, C = x.shape
        x = memory_efficient_attention(x, x, x, attn_bias=xops.LowerTriangularMask())
        return self.linear_out(x.reshape([B, N, C]))
###
# dtype = torch.bfloat16
dtype = torch.float32
###
torch.distributed.init_process_group(backend="nccl")
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
torch.cuda.set_device(device)
FEAT_SIZE = 128
MAX_LEN = 100
BATCH_SIZE = 8
batch = torch.zeros(BATCH_SIZE, MAX_LEN, FEAT_SIZE).to(device)
mha = Layer(FEAT_SIZE).to(device)
mp_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
mha_fsdp = FSDP(mha, use_orig_params=True, device_id=device, mixed_precision=mp_policy)
compile_mha = torch.compile(mha_fsdp).to(device)
output = compile_mha(batch)
output.mean().backward()
This script runs fine, but changing dtype at the top to bfloat16 leads to the following:
Traceback (most recent call last):
File "/home/ubuntu/research/openlm/main-tri/tmp/mha.py", line 36, in <module>
output = compile_mha(batch)
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward
output = self._fsdp_wrapped_module(*args, **kwargs)
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/research/openlm/main-tri/tmp/mha.py", line 16, in forward
x = memory_efficient_attention(x, x, x, attn_bias=xops.LowerTriangularMask())
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 223, in memory_efficient_attention
return _memory_efficient_attention(
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 321, in _memory_efficient_attention
return _memory_efficient_attention_forward(
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py", line 341, in _memory_efficient_attention_forward
out, *_ = op.apply(inp, needs_gradient=False)
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/xformers/ops/fmha/flash.py", line 396, in apply
out, softmax_lse, rng_state = cls.OPERATOR(
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_ops.py", line 502, in __call__
[...]
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 502, in __call__
r = self.meta_tensor(
File "/home/ubuntu/research/openlm/venv/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py", line 175, in meta_tensor
assert not torch._C._dispatch_tls_local_exclude_set().has(
AssertionError:
Set torch._dynamo.config.verbose=True for more information
FWIW, the compile does go through if we add torch._dynamo.config.suppress_errors=True, but the speedup on a 7B is negligible (4350 -> 4370 tokens/s/GPU), which could just be noise. This might be because it falls back to eager mode during flash attention, but I'm not sure.
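For reference, a minimal sketch of that workaround, reusing mha_fsdp and batch from the repro script above (only the torch._dynamo.config.suppress_errors flag comes from the comment above; the rest is illustrative):

import torch
import torch._dynamo

# Let Dynamo fall back to eager for ops it fails to compile instead of raising.
torch._dynamo.config.suppress_errors = True

# mha_fsdp and batch are the FSDP-wrapped module and input from the repro above.
compile_mha = torch.compile(mha_fsdp)
output = compile_mha(batch)
output.mean().backward()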
Filed a bug with xformers: facebookresearch/xformers#920
Closing this as we've realized that torch attention + torch.compile works very well and is as fast as or faster than xformers attention for our use cases.
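For anyone hitting the same issue, here's a rough sketch of what that switch looks like for the Layer in the repro above: replacing the xformers call with torch.nn.functional.scaled_dot_product_attention and is_causal=True should be equivalent to LowerTriangularMask for these 3-D inputs, and torch.compile then traces a native op rather than the xformers custom kernel. The sketch is illustrative, not the exact open_lm change:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Layer(nn.Module):
    def __init__(self, n_feat):
        super().__init__()
        self.linear_out = nn.Linear(n_feat, n_feat)

    def forward(self, x):
        B, N, C = x.shape
        # Causal (lower-triangular) attention over the sequence dim; for (B, N, C)
        # inputs this should match
        # memory_efficient_attention(x, x, x, attn_bias=xops.LowerTriangularMask()).
        x = F.scaled_dot_product_attention(x, x, x, is_causal=True)
        return self.linear_out(x.reshape([B, N, C]))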