A modification to the BertQueryNER forward function
smiles724 opened this issue
smiles724 commented
Hi, I noticed that the span-matrix computation in the BertQueryNER forward function can be simplified.
The original code is:

    start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
    end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
    span_matrix = torch.cat([start_extend, end_extend], 3)
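To make the shapes concrete, here is a small sketch that runs the original code on a random tensor. It assumes PyTorch is installed, and the sizes (batch 2, sequence length 4, hidden size 3) are arbitrary, not the model's real dimensions. It shows that entry (b, i, j) of span_matrix is the concatenation of token i's vector (the start) and token j's vector (the end).

```python
import torch

# Arbitrary small sizes chosen for illustration (not the model's config).
batch_size, seq_len, hidden = 2, 4, 3
sequence_heatmap = torch.randn(batch_size, seq_len, hidden)

# Same three lines as the original forward code.
# start_extend[b, i, j] == sequence_heatmap[b, i]  (varies along dim 1)
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
# end_extend[b, i, j] == sequence_heatmap[b, j]    (varies along dim 2)
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
span_matrix = torch.cat([start_extend, end_extend], 3)

print(span_matrix.shape)  # torch.Size([2, 4, 4, 6])

# Entry (b, i, j) pairs token i's vector with token j's vector.
b, i, j = 1, 0, 3
expected = torch.cat([sequence_heatmap[b, i], sequence_heatmap[b, j]])
print(torch.equal(span_matrix[b, i, j], expected))  # True
```

The two expand calls differ in which axis the unsqueeze inserts, so the start half varies with the row index i and the end half varies with the column index j.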
Since start_extend and end_extend are both built from the same tensor (sequence_heatmap), we can change this code to:
    span_matrix = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, 2)
Is my understanding correct?
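One way to test the equivalence is to run both versions on a small random tensor and compare them. This is a sketch assuming PyTorch, with arbitrary sizes (batch 2, sequence length 4, hidden size 3); it is not code from the repository.

```python
import torch

batch_size, seq_len, hidden = 2, 4, 3  # arbitrary illustrative sizes
sequence_heatmap = torch.randn(batch_size, seq_len, hidden)

# Reference: the original three-line computation.
start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1)
end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
span_matrix = torch.cat([start_extend, end_extend], 3)

# The proposed one-liner. expand() can only enlarge dimensions of size 1,
# so requesting size 2 for the last dimension (size `hidden` = 3 here) is rejected.
proposed_error = None
try:
    sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, 2)
except RuntimeError as err:
    proposed_error = err
print(proposed_error is not None)  # True for these sizes

# Keeping the last dimension (-1) avoids the error, but this is exactly
# end_extend: only the end-token half, without the start-token half.
same_tensor_only = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1)
print(same_tensor_only.shape)  # torch.Size([2, 4, 4, 3])
print(span_matrix.shape)       # torch.Size([2, 4, 4, 6])
print(torch.equal(same_tensor_only, end_extend))  # True
```

With these sizes the one-liner raises a RuntimeError. Replacing the 2 with -1 runs, but the result has half the width of span_matrix and reproduces only end_extend, because start_extend and end_extend expand sequence_heatmap along different axes even though they share the same source tensor.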