AMP + BF16 failing
jramapuram opened this issue · comments
Hi there,
Great work with dMoE! I'm trying to test dMoE with regular DDP + PyTorch AMP (BF16) and I get the following error:
```
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/miniconda/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 248, in _unscale_grads_
    torch._amp_foreach_non_finite_check_and_unscale_(
```
I'm just wrapping your existing `dmoe.dMoE(args)` logic.

Is this something that is currently unsupported? If I force the entire network to BF16, then everything works fine.

I've also seen some issues with AMP. I think there's something missing somewhere... but all the functions seem wrapped to me?
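For context, the training-loop pattern in question is roughly the following (a hedged sketch: `nn.Linear` stands in for `dmoe.dMoE(args)`, the shapes are invented, and this toy version runs cleanly — the failure only shows up with the real dMoE):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 16).to(device)          # stand-in for dmoe.dMoE(args)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
# GradScaler is the component whose _unscale_grads_ call appears in the traceback
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(8, 16, device=device)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    loss = model(x).float().pow(2).mean()     # forward under AMP (BF16)
scaler.scale(loss).backward()
scaler.unscale_(opt)                          # -> GradScaler._unscale_grads_ path
scaler.step(opt)
scaler.update()
```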
@mvpatel2000: this can be worked around for `moe.MoE` by force-casting `moe.to(torch.float32)`, and AMP works fine. When doing the same with `dmoe.dMoE`, I get a triton error:
```
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/miniconda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/megablocks/layers/mlp.py", line 270, in backward
    stk.backend.triton_kernels.sdd(
  File "/miniconda/lib/python3.10/site-packages/stk/backend/triton_kernels.py", line 336, in sdd
    _sdd_kernel[grid](
  File "/miniconda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 114, in run
    ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "<string>", line 63, in _sdd_kernel
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 26:25:
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # do matrix multiplication
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
```
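The float32 workaround mentioned above can be sketched like this (assumptions: a plain `nn.Linear` stands in for `moe.MoE(args)`, and autocast is disabled locally around the block so its compute actually stays in fp32 — the real MegaBlocks kernels may behave differently):

```python
import torch
import torch.nn as nn

moe = nn.Linear(32, 32)        # stand-in for moe.MoE(args)
moe.to(torch.float32)          # the force-cast workaround from the comment above

x = torch.randn(4, 32)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # keep the MoE block itself out of autocast so it runs fully in fp32
    with torch.autocast(device_type="cpu", enabled=False):
        y = moe(x.float())
print(y.dtype)  # torch.float32
```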
Some more small updates on the AMP bugs, @mvpatel2000.
What works:
- glu/mlp + sparse + dmoe
- glu/mlp + sparse + moe
- glu/mlp + grouped + moe
What doesn't work:
- glu (and MLP) + grouped + dmoe
```
  File "/miniconda/lib/python3.10/site-packages/megablocks/layers/glu.py", line 158, in forward
    x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
  File "/miniconda/lib/python3.10/site-packages/grouped_gemm/ops.py", line 33, in gmm
    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/miniconda/lib/python3.10/site-packages/grouped_gemm/ops.py", line 11, in forward
    return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
  File "/miniconda/lib/python3.10/site-packages/grouped_gemm/backend.py", line 27, in gmm
    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
RuntimeError: Expected b.scalar_type() == torch::kBFloat16 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
```
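For what it's worth, the `RuntimeError` is consistent with autocast not covering the grouped-GEMM custom op: inside an autocast (BF16) region, activations come out as bfloat16 while raw fp32 master weights pass through unchanged, so a dtype check on `b` would fail. A pure-PyTorch illustration of that mismatch (`torch.mm` stands in for the autocast-covered ops; no `grouped_gemm` needed, and this is an inference about the cause, not a confirmed diagnosis):

```python
import torch
import torch.nn as nn

w1 = nn.Parameter(torch.randn(16, 16))   # fp32 weight, like glu.py's w1
x = torch.randn(8, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    a = torch.mm(x, x.t())               # autocast-covered op -> bf16 output
    # a custom kernel called here (like gg.ops.gmm) would receive
    # a.dtype == bfloat16 but w1.dtype == float32, tripping a check
    # such as `b.scalar_type() == torch::kBFloat16`
print(a.dtype, w1.dtype)  # torch.bfloat16 torch.float32
```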
@jramapuram any chance you can provide a mini repro? Happy to look into it.