ShannonAI / mrc-for-flat-nested-ner

Code for ACL 2020 paper `A Unified MRC Framework for Named Entity Recognition`


A modification of BertQueryNER forward function

smiles724 opened this issue

Hi, I noticed that the span matrix computation in the BertQueryNER forward function can be simplified.

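The original code is:

```python
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)  # [batch, seq_len, seq_len, hidden]
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)  # [batch, seq_len, seq_len, hidden]
span_matrix = torch.cat([start_extend, end_extend], 3)  # [batch, seq_len, seq_len, 2 * hidden]
```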
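To make it concrete what these three lines produce, here is a small sketch on a toy tensor. The sizes are arbitrary assumptions; the only requirement, implied by the unsqueeze/expand calls, is that sequence_heatmap has shape [batch, seq_len, hidden]. Using distinct values (arange) makes it easy to see which token ends up where.

```python
import torch

# Assumed toy sizes; sequence_heatmap stands in for the BERT output [batch, seq_len, hidden].
batch, seq_len, hidden = 2, 4, 3
sequence_heatmap = torch.arange(batch * seq_len * hidden, dtype=torch.float32).reshape(batch, seq_len, hidden)

# Same construction as in the original code.
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
span_matrix = torch.cat([start_extend, end_extend], 3)

print(span_matrix.shape)  # torch.Size([2, 4, 4, 6])

# Entry (i, j) pairs the vector of token i (candidate start)
# with the vector of token j (candidate end).
i, j = 1, 3
expected = torch.cat([sequence_heatmap[:, i], sequence_heatmap[:, j]], dim=-1)
print(torch.equal(span_matrix[:, i, j], expected))  # True
```

In other words, the last dimension of span_matrix holds two different token vectors, one indexed by the row i and one by the column j.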

However, since start_extend and end_extend are both built from the same tensor, I think this code can be changed to:

```python
span_matrix = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, 2)
```
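One way to test the proposed one-liner is a quick check on a toy tensor, again assuming sequence_heatmap is [batch, seq_len, hidden] with hidden > 1 (as it is for BERT outputs). Two things are worth checking: whether expand() accepts the requested shape at all, since it can only broadcast size-1 dimensions, and whether start_extend and end_extend actually hold the same values.

```python
import torch

# Assumed toy sizes; distinct values make any mix-up of tokens visible.
batch, seq_len, hidden = 2, 4, 3
sequence_heatmap = torch.arange(batch * seq_len * hidden, dtype=torch.float32).reshape(batch, seq_len, hidden)

# 1) Proposed one-liner: the last dimension has size `hidden`, not 1,
#    so expand() cannot turn it into size 2 and raises a RuntimeError.
try:
    sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, 2)
    proposal_ok = True
except RuntimeError as err:
    proposal_ok = False
    print("expand failed:", err)

# 2) The two halves of the original span_matrix differ:
#    start_extend[b, i, j] holds token i, end_extend[b, i, j] holds token j.
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
print(torch.equal(start_extend, end_extend))  # False
print(torch.equal(start_extend, end_extend.transpose(1, 2)))  # True
```

So, at least under these assumptions, the two expanded tensors are transposes of each other rather than copies, and the concatenation is what pairs a start token with an end token.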

Is my understanding correct?