Question on get_knn_log_prob
jiaqing23 opened this issue · comments
Tan Jia Qing commented
Hi, I am reading the code. In knnlm.py, there is a line (https://github.com/jxhe/efficient-knnlm/blob/main/fairseq/knnlm.py#L267):
index_mask = torch.eq(torch.from_numpy(self.vals[knns]).long().cuda().squeeze(-1), tgt[knn_mask].unsqueeze(-1)).float()
May I know the purpose of this line? Does tgt mean the prediction target tokens? If so, why is the target available during testing?
Thanks!
Tan Jia Qing commented
Closing this question. Found the answer after reading the code again.
This line computes the kNN probability of the ground-truth label, and it is only used for validation. The metric is perplexity, which depends only on the probability assigned to the ground-truth token, so the target is all that is needed here.
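
For anyone who finds this later, here is a minimal sketch of what the masking does for a single position. It is plain Python, not the repository's PyTorch code. It assumes the neighbor weights are a softmax over negative distances, as in the kNN-LM formulation. The function name, its arguments, and the -10000 mask value are illustrative assumptions, not the repository's API.

```python
import math


def knn_log_prob_of_target(dists, neighbor_tokens, target, mask_value=-10000.0):
    """Log-probability that the kNN distribution assigns to `target`.

    dists           -- distances to the k retrieved neighbors (hypothetical input)
    neighbor_tokens -- token id stored with each neighbor (the role of self.vals[knns])
    target          -- ground-truth token id for this position (the role of tgt)
    """
    # Softmax over negative distances: closer neighbors get more weight.
    scores = [-d for d in dists]
    top = max(scores)
    log_z = top + math.log(sum(math.exp(s - top) for s in scores))
    log_weights = [s - log_z for s in scores]

    # Counterpart of index_mask: add 0 where the neighbor's token equals the
    # target, and a large negative number elsewhere so those terms vanish.
    masked = [
        lw + (0.0 if tok == target else mask_value)
        for lw, tok in zip(log_weights, neighbor_tokens)
    ]

    # log-sum-exp over the surviving neighbors gives log p_knn(target).
    top = max(masked)
    return top + math.log(sum(math.exp(x - top) for x in masked))


# Example: neighbors 0 and 2 store the target token 5, neighbor 1 does not.
log_p = knn_log_prob_of_target([1.0, 2.0, 3.0], [5, 7, 5], target=5)
```

Only the neighbors whose stored token equals the target contribute, which is why the target has to be known. That is fine for perplexity evaluation, where the ground truth is given. Generation would need the full kNN distribution over the vocabulary.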