OpenBMB / CPM-Live

Live Training for Open-source Big Models

Puzzled by the mask operation

DeepTecher opened this issue

Thank you for your good work. However, I have some doubts about the following code (source: ant_torch.py#L144, which I hit when running the Ant model):

  1. What is the main logic of this part? I did not get it.
  2. During inference, context is all set to True and span is all set to 0 in _convert_to_tensors, and it seems the mask is all 1 after the following code. So what does this code actually do?
with torch.no_grad():
    device = input.device
    directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(
        seqlen, device=device
    ).view(-1, 1)
    attention_mask = context[:, None, :] | (
        context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
    )
    attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
    mask_1d = (
        torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
    )
    attention_mask = (
        mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
    )

Hi,

  1. As we use a unified Transformer architecture, the attention over the target part is unidirectional (a.k.a. causal), while the attention over the context part is bidirectional. That is the main logic of this code snippet.
  2. During inference, everything is context.

You can refer to our blog for more details of the model.
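
For concreteness, here is a small self-contained toy example (the tensor values below are made up for illustration; the mask lines are the same as in the snippet above) showing the resulting mask when the first three positions are context and the last two are target:

import torch

batch, seqlen = 1, 5
context = torch.tensor([[True, True, True, False, False]])  # positions 0-2: context, 3-4: target
span = torch.zeros(batch, seqlen, dtype=torch.long)          # everything in one span
length = torch.tensor([seqlen])                              # no padding

directional_mask_2d = torch.arange(seqlen) <= torch.arange(seqlen).view(-1, 1)
# a query may attend to a key if the key is context, or if the query is in the
# target part and the key is not in the query's future (causal attention)
attention_mask = context[:, None, :] | (
    context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
# additionally, attention is restricted to positions within the same span
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
# and to positions inside each sequence's true length (masks out padding)
mask_1d = torch.arange(seqlen)[None, :].repeat(batch, 1) < length[:, None]
attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask

print(attention_mask[0].int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)

When context is all True and span is all 0, as during inference, the first term alone already allows every position to attend to every other position, so only the length-based masking of padded positions has any effect; that is why the mask looks all-ones in that case.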

I found the mask operation to be a smart design when I dived into the code. However, there is still another problem with the padding in _process_text when the batch size is greater than 1.
Take this example:

我们在假期去了法国的埃菲尔铁塔, (We went to the Eiffel Tower in France during the holidays,)
今天天气真好 (The weather is really nice today)

After _process_text, input_id is as follows:

tensor([[   64,    65,    66,    67,    68,    69,    70,    71,    72,    73,
            74,    75,    76,    77,    78,    79,    80,    81,    82,    83,
            84,    85,    86,    87,    88,    89,    90,    91,    92,    93,
            94,    95,     6, 18039,   638, 10547,   454,   267, 21744,  1405,
          2064,  2452,   536,  2182,  2760,   245],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,    64,
            65,    66,    67,    68,    69,    70,    71,    72,    73,    74,
            75,    76,    77,    78,    79,    80,    81,    82,    83,    84,
            85,    86,    87,    88,    89,    90,    91,    92,    93,    94,
            95,     6,  9802, 14962,  2082,   831]], device='cuda:0',
       dtype=torch.int32)
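
For reference, the layout above appears to come from prepending the 32 prompt-position IDs (64-95) to each sequence and then left-padding the shorter sequence with zeros to the batch's maximum length. A rough reconstruction of that behavior (a hypothetical sketch, not the actual _process_text implementation; build_batch and the pad_id default are my own naming):

import torch

prompt_length = 32
prompt_ids = list(range(64, 64 + prompt_length))  # prompt-position IDs 64 .. 95

def build_batch(token_id_lists, pad_id=0):
    # prepend the prompt IDs to every sequence, then left-pad to the longest row
    rows = [prompt_ids + ids for ids in token_id_lists]
    max_len = max(len(r) for r in rows)
    padded = [[pad_id] * (max_len - len(r)) + r for r in rows]
    return torch.tensor(padded, dtype=torch.int32)

# token IDs of the two example sentences, taken from the tensor above
long_ids = [6, 18039, 638, 10547, 454, 267, 21744, 1405, 2064, 2452, 536, 2182, 2760, 245]
short_ids = [6, 9802, 14962, 2082, 831]
print(build_batch([long_ids, short_ids]))  # reproduces the padded batch shown above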

In the forward pass of ant_torch, I think the resulting prompt_states and hidden_states are then wrong, especially because the two embedding tables are not the same. (The problem exists in both forward and inference.)
The code is listed below (original here):

# the first prompt_length positions are treated as prompt IDs, the rest as text IDs
input_prompt = input[:, : self.prompt_length].contiguous()
input_ids = input[:, self.prompt_length :].contiguous()

# the two parts are looked up in separate embedding tables
prompt_states = self.prompt_embedding(input_prompt)
hidden_states = self.input_embedding(input_ids)

input_prompt

tensor([[64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
         82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0, 64, 65, 66, 67, 68, 69, 70, 71, 72,
         73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86]],
       device='cuda:0', dtype=torch.int32)

input_id

tensor([[    6, 18039,   638, 10547,   454,   267, 21744,  1405,  2064,  2452,
           536,  2182,  2760,   245],
        [   87,    88,    89,    90,    91,    92,    93,    94,    95,     6,
          9802, 14962,  2082,   831]], device='cuda:0', dtype=torch.int32)

Obviously, this data is incorrect, especially the input_id and input_prompt rows converted from the shorter text, i.e. input_id[1] and input_prompt[1] in this case. (I would expect the prompt to begin at 64, not 0, and the input to begin at 6, not at one of 64-95.)

I hope my understanding is correct and look forward to your answers.

You are right; there are some bugs related to prompt tokens in CPM-Ant. We have unified the prompt and input embeddings in CPM-Ant+ (see #148), which fixes this kind of issue.

Okay...
It would be great if we could fix it in the Ant model as well, since it affects the results in the different cases (batch=1 vs. batch>1).


One way I can think of to fix it is as follows:
do some data transformations (at a slight runtime cost) in model.inference and model.forward to re-align the sequences using the length variable, then flip them back afterwards, as sketched below.
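
A hedged sketch of that idea (untested against the actual model; it assumes left padding with pad ID 0, that length holds each sequence's true token count, and the helper names are my own):

import torch

def left_to_right_pad(input, length):
    # roll each left-padded row so its real tokens start at position 0 (right padding)
    seqlen = input.size(1)
    out = input.clone()
    for i in range(input.size(0)):
        pad = seqlen - int(length[i])
        out[i] = torch.roll(input[i], shifts=-pad, dims=0)
    return out

def right_to_left_pad(states, length):
    # inverse transform: move the padding back to the left, e.g. for logits or hidden states
    seqlen = states.size(1)
    out = states.clone()
    for i in range(states.size(0)):
        pad = seqlen - int(length[i])
        out[i] = torch.roll(states[i], shifts=pad, dims=0)
    return out

# usage inside forward/inference (sketch):
# input = left_to_right_pad(input, length)
# input_prompt = input[:, :prompt_length]     # now always the real prompt IDs
# input_ids = input[:, prompt_length:]        # now always the real text IDs
# ... run the model ...
# logits = right_to_left_pad(logits, length)  # flip back to the original left-padded layout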

Actually, left padding is only applied during inference.
It would be great if you could create a PR to fix the issue in the inference stage.