deepseek-ai / DeepSeek-Coder-V2

DeepSeek-Coder-V2: Breaking the Barrier of Closed-Source Models in Code Intelligence


Cannot fine-tune deepseek-coder-v2-lite via modeling_deepseek.py

chencyudel opened this issue

You may have a bug in dtype handling, and as a result the model cannot be fine-tuned via DeepSpeed (bf16 mixed precision):

File "/deepseek_v2/modeling_deepseek.py", line 1252, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/deepseek_v2/modeling_deepseek.py", line 953, in forward
q = self.q_proj(hidden_states)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
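
For context, the failure is not specific to DeepSeek's code: any bf16 Linear layer that receives a float32 activation fails in the same way. A minimal sketch in plain PyTorch (nothing here comes from modeling_deepseek.py):

import torch
import torch.nn as nn

# A bf16 Linear fed a float32 input reproduces the same class of dtype mismatch
# (the exact error message depends on the device/backend).
linear = nn.Linear(8, 8).to(torch.bfloat16)
x = torch.randn(2, 8, dtype=torch.float32)
try:
    linear(x)
except RuntimeError as e:
    print(e)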

@lihaoling I have found the reason: it is a problem in the forward method of the DeepseekV2MoE module. An easy, though maybe not optimal, fix is to cast the output back to the input dtype at the end of DeepseekV2MoE's forward method.
For example:

def forward(self, hidden_states):
    # save the input dtype before the MoE computation
    input_dtype = hidden_states.dtype
    identity = hidden_states
    orig_shape = hidden_states.shape
    topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    flat_topk_idx = topk_idx.view(-1)
    if self.training:
        hidden_states = hidden_states.repeat_interleave(
            self.num_experts_per_tok, dim=0
        )
        y = torch.empty_like(hidden_states)
        for i, expert in enumerate(self.experts):
            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
        y = y.view(*orig_shape)
        y = AddAuxiliaryLoss.apply(y, aux_loss)
    else:
        y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
    if self.config.n_shared_experts is not None:
        y = y + self.shared_experts(identity)
    # cast back so the output dtype matches the input dtype after the MoE forward
    return y.to(input_dtype)
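
If editing modeling_deepseek.py is inconvenient, the same cast can also be applied from the outside by wrapping each MoE block's forward at load time. A hedged sketch (the helper name patch_moe_output_dtype is made up; only the DeepseekV2MoE class name comes from the model code):

def patch_moe_output_dtype(model):
    # Wrap every DeepseekV2MoE forward so its output is cast back to the
    # dtype of the incoming hidden_states, without touching modeling_deepseek.py.
    for module in model.modules():
        if type(module).__name__ == "DeepseekV2MoE":
            orig_forward = module.forward
            def wrapped(hidden_states, _orig=orig_forward):
                return _orig(hidden_states).to(hidden_states.dtype)
            module.forward = wrapped
    return model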

@chencyudel Thanks! I have it running successfully.

I've also encountered this problem and solved it with the workaround above. I'm wondering if there is a more elegant way to solve it?

I guess this is due to mixed-precision training; for example, the Switch Transformers paper uses float32 precision in the router function and bfloat16 precision everywhere else. I'm not sure there is a more elegant way to avoid such a dtype conversion, though.
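
A rough sketch of that recipe (not the actual DeepSeek gate; the class and parameter names here are illustrative): keep only the routing computation in float32 and cast the routing weights back to the activation dtype before they multiply the bf16 expert outputs.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Fp32Router(nn.Module):
    # Illustrative router in the spirit of Switch Transformers: the softmax
    # runs in float32 for stability, and the result is cast back to the
    # activation dtype (e.g. bf16) so downstream matmuls keep a single dtype.
    def __init__(self, hidden_size, n_experts):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_experts, hidden_size))
        nn.init.kaiming_uniform_(self.weight)

    def forward(self, hidden_states):
        logits = F.linear(hidden_states.float(), self.weight.float())
        probs = logits.softmax(dim=-1)        # float32 routing probabilities
        return probs.to(hidden_states.dtype)  # cast back before mixing experts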

Thanks a lot! I have it running successfully, but it seems quite slow. I used DeepSpeed ZeRO stage 3 for distributed training; has anyone encountered a similar situation?

Regarding the float32 router point above: when I load the model in bfloat16 for inference, the error does not appear. It seems that during the inference forward pass, the fp32 activations in the router function are cast back to bf16 automatically.
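
For reference, a sketch of such an inference-only bf16 load using the standard transformers API (the checkpoint name below is an assumption; substitute the one you actually use):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint name for illustration; adjust to the actual model path.
model_id = "deepseek-ai/DeepSeek-Coder-V2-Lite-Base"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,   # all weights, including the gate, load in bf16
    trust_remote_code=True,
)
model.eval()  # inference path; as noted above, the dtype error does not appear here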