span loss的计算中关于mask的问题
dancinghui opened this issue · comments
dancinghui commented
在数据处理代码mrc_ner_dataset.py中,实体的end_position是真实end_position的后一位,即end_position+1,那么在对span做mask的时候对角线位置也应该mask掉的,即开始位置应该小于结束位置而不应该包含等于,这行代码
match_label_mask = torch.triu(match_label_mask, 0) # start should be less equal to end
是不是应该改为
match_label_mask = torch.triu(match_label_mask, 1) # start should be less to end
Yuxian Meng commented
@dancinghui 如果你指的是这一行,需要注意这里并不是最终返回的end_position,而是为了后续在context前加入query之后生成新的end_position的中间变量。最终返回的start/end labels都是真实的end_positions,并没有向后错一位。