CRF 的几点问题
ZacBi opened this issue · comments
这里是不是应该写成这样?
# Form proposed in this issue: write -10000 into every row of the columns
# that come before START_TAG's index. The assignment goes through .detach(),
# which shares storage with self.transitions, so the parameter itself is
# modified without the write being tracked by autograd.
self.transitions.detach()[:, :self.tag_dictionary[self.START_TAG]] = -10000
还有这里, 既然 tags 已经是包含 [CLS] 和 [SEP] 的标签序列了, 为什么还要分别在左边和右边 cat 上 [CLS] 和 [SEP]? 有点不解.
def _score_sentence(self, feats, tags, lens_):
    """Score each gold tag sequence in the batch (transitions + emissions).

    For sentence ``i`` with length ``n = lens_[i]`` the score is the sum of

    * ``self.transitions[to_tag, from_tag]`` over the ``n + 1`` steps
      ``START -> t_0 -> ... -> t_{n-1} -> STOP``, and
    * ``feats[i, k, t_k]`` for ``k`` in ``0 .. n - 1``.

    START is prepended to one copy of ``tags`` and STOP is appended to
    another so that position ``k`` of the two copies lines up as
    (previous tag, current tag); indexing ``self.transitions`` with the
    pair then yields every transition of the path in a single lookup.

    Args:
        feats: Emission scores, indexed as ``feats[sentence, position, tag]``.
        tags: Integer tag ids, indexed as ``tags[sentence, position]``.
            Entries at or beyond ``lens_[i]`` do not contribute to the score.
        lens_: Per-sentence lengths (Python ints or 0-dim integer tensors).

    Returns:
        A float32 tensor of shape ``(feats.shape[0],)`` on ``self.device``
        holding one score per sentence.
    """
    # Look the boundary tag ids up once instead of on every use.
    start_idx = self.tag_dictionary[self.START_TAG]
    stop_idx = self.tag_dictionary[self.STOP_TAG]
    batch_size = tags.shape[0]

    # One boundary column per sentence, created directly on the target device.
    start = torch.full(
        (batch_size, 1), start_idx, dtype=torch.long, device=self.device
    )
    stop = torch.full(
        (batch_size, 1), stop_idx, dtype=torch.long, device=self.device
    )

    # pad_start_tags[i, k] is the tag *before* step k,
    # pad_stop_tags[i, k] is the tag *at* step k.
    pad_start_tags = torch.cat([start, tags], 1)
    pad_stop_tags = torch.cat([tags, stop], 1)

    # For sentences shorter than the padded width, the step right after the
    # last real tag must be STOP rather than whatever padding id `tags` holds.
    for i, length in enumerate(lens_):
        pad_stop_tags[i, int(length):] = stop_idx

    # Zero-initialised (not uninitialised) so every entry is well defined.
    score = torch.zeros(feats.shape[0], dtype=torch.float, device=self.device)
    for i in range(feats.shape[0]):
        length = int(lens_[i])
        positions = torch.arange(length, device=self.device)

        # length + 1 transitions: START -> t_0, ..., t_{length-1} -> STOP.
        transition_score = torch.sum(
            self.transitions[
                pad_stop_tags[i, : length + 1], pad_start_tags[i, : length + 1]
            ]
        )
        # One emission score per real (unpadded) position.
        emission_score = torch.sum(feats[i, positions, tags[i, :length]])

        score[i] = transition_score + emission_score
    return score
不好意思走错了