kaituoxu / Speech-Transformer

A PyTorch implementation of Speech Transformer, an end-to-end ASR model based on the Transformer network, for Mandarin Chinese.


question about the non_pad_mask

songtaoshi opened this issue · comments

Hi kaituo, thanks for sharing such a useful speech transformer.
I have run your code and successfully reproduced the reported results, but I have a question about the non_pad_mask. I notice that when calculating the loss, the padded part is already masked out. Since the loss computation is already restricted to non-padded positions, why do we also need to apply the mask after every sub-layer in the forward pass?

  • forward

```python
def forward(self, dec_input, enc_output, non_pad_mask=None,
            slf_attn_mask=None, dec_enc_attn_mask=None):

    dec_output, dec_slf_attn = self.slf_attn(
        dec_input, dec_input, dec_input, mask=slf_attn_mask)
    dec_output *= non_pad_mask

    dec_output, dec_enc_attn = self.enc_attn(
        dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
    dec_output *= non_pad_mask

    dec_output = self.pos_ffn(dec_output)
    dec_output *= non_pad_mask
```

  • loss calculation

```python
loss = loss.masked_select(non_pad_mask).sum() / n_word
```
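
For context, here is a minimal sketch of how such a mask can be derived from the padded targets and reused in the masked loss reduction; the pad id, shapes, and vocabulary size below are made up for illustration, and the repo's own helpers may differ:

```python
import torch
import torch.nn.functional as F

# Hypothetical pad id; the constant used in the repo may differ.
PAD_ID = 0

def get_non_pad_mask(seq, pad_id=PAD_ID):
    # (batch, time) -> (batch, time, 1): 1.0 at real tokens, 0.0 at padding.
    return seq.ne(pad_id).float().unsqueeze(-1)

# Toy batch of target ids: the second sequence is padded to length 4.
tgt = torch.tensor([[5, 7, 2, 9],
                    [3, 8, PAD_ID, PAD_ID]])
non_pad_mask = get_non_pad_mask(tgt)                  # (2, 4, 1)

# Per-token cross entropy, then the same masked reduction as the quoted line.
logits = torch.randn(2, 4, 12)                        # (batch, time, vocab)
loss = F.cross_entropy(logits.view(-1, 12), tgt.view(-1), reduction='none')
loss = loss.view(2, 4)
n_word = non_pad_mask.sum()
loss = loss.masked_select(non_pad_mask.squeeze(-1).bool()).sum() / n_word
```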

It's just to make sure the computation is correct for the padding part.
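
To illustrate the point (a toy example, not the repo's code, with a plain Linear layer standing in for the position-wise FFN): sub-layers with bias terms produce non-zero activations at padded time steps even when their input there is zero, so the explicit multiply keeps the padded positions at exactly zero between sub-layers.

```python
import torch
import torch.nn as nn

# Padded hidden states: the last two time steps of the second sequence
# are "padding" and start out as exact zeros.
x = torch.zeros(2, 4, 8)
x[0] = torch.randn(4, 8)
x[1, :2] = torch.randn(2, 8)

# Any sub-layer with a bias term (a stand-in for the position-wise FFN)
# yields non-zero outputs at the padded steps.
ffn = nn.Linear(8, 8)
y = ffn(x)
print(y[1, 2:].abs().sum())   # > 0: padded positions are no longer zero

# Multiplying by non_pad_mask after each sub-layer resets them to zero,
# so they never carry arbitrary values into later computations.
non_pad_mask = torch.tensor([[1., 1., 1., 1.],
                             [1., 1., 0., 0.]]).unsqueeze(-1)
y = y * non_pad_mask
print(y[1, 2:].abs().sum())   # exactly 0 again
```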

Got it, thanks for your reply!