jxhe / efficient-knnlm

PyTorch implementation of the paper "Efficient Nearest Neighbor Language Models" (EMNLP 2021)

Question on get_knn_log_prob

jiaqing23 opened this issue

Hi, I am reading the code. In knnlm.py, there is a line (https://github.com/jxhe/efficient-knnlm/blob/main/fairseq/knnlm.py#L267):

index_mask = torch.eq(torch.from_numpy(self.vals[knns]).long().cuda().squeeze(-1), tgt[knn_mask].unsqueeze(-1)).float()

What is the purpose of this line? Does tgt mean the target tokens to be predicted? If so, why is the target available during testing?

Thanks!
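Mechanically, the quoted line is a broadcast equality test. Assuming self.vals stores one token id per datastore entry (so self.vals[knns] has shape N x K x 1 and squeeze(-1) makes it N x K), and that tgt[knn_mask] selects one ground-truth token per query (unsqueezed to N x 1), torch.eq compares every retrieved neighbor's stored token with that query's target. The pure-Python sketch below mimics only this comparison; the function name index_mask and the toy token ids are hypothetical, and the sketch does not use PyTorch, CUDA, or the repository's code.

```python
def index_mask(neighbor_vals, targets):
    """Pure-Python analogue of the broadcast torch.eq(...).float() line.

    neighbor_vals: N rows, each a list of K token ids stored for the K
        retrieved neighbors of one query (hypothetical stand-in for
        self.vals[knns] after squeeze(-1)).
    targets: N ground-truth token ids, one per query (stand-in for
        tgt[knn_mask]).
    Returns an N x K list with 1.0 where a neighbor's stored token equals
    that query's target token and 0.0 otherwise.
    """
    return [
        [1.0 if v == t else 0.0 for v in row]
        for row, t in zip(neighbor_vals, targets)
    ]


# Two queries with K = 3 neighbors each (made-up token ids).
vals = [[5, 9, 5], [2, 2, 7]]
tgt = [5, 7]
print(index_mask(vals, tgt))  # [[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
```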

Closing this question; I found the answer after reading the code again.

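This line is used to compute the kNN probability of the ground-truth token, and it is only used for validation. The evaluation metric is perplexity, which depends only on the probability assigned to the ground-truth token at each position, so computing the kNN probability of that one token is enough.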
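To illustrate why the ground-truth token's probability is all that evaluation needs, here is a minimal pure-Python sketch (no PyTorch or Faiss). It assumes the kNN distribution is a softmax over negative neighbor distances, that a match mask like index_mask keeps only neighbors whose stored token equals the target, and that kNN and LM probabilities are combined by the usual kNN-LM linear interpolation with a weight lam. The function names, MASK_PENALTY, and the example lam value are hypothetical and not taken from knnlm.py.

```python
import math

# Hypothetical large negative value used to mask out neighbors whose stored
# token is not the target (an assumption, not necessarily the repo's constant).
MASK_PENALTY = -10000.0


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def knn_target_log_prob(dists, neighbor_vals, target):
    """log p_kNN(target) for one query.

    dists: distances to the K retrieved neighbors (smaller means closer).
    neighbor_vals: the K token ids stored for those neighbors.
    target: the ground-truth next-token id.
    """
    # Softmax over negative distances gives a distribution over neighbors.
    logits = [-d for d in dists]
    log_z = logsumexp(logits)
    log_probs = [x - log_z for x in logits]
    # Same role as index_mask: keep only neighbors whose value is the target,
    # then sum their probabilities (in log space).
    masked = [lp + (0.0 if v == target else MASK_PENALTY)
              for lp, v in zip(log_probs, neighbor_vals)]
    return logsumexp(masked)


def interpolate(lm_log_prob, knn_log_prob, lam):
    """log(lam * p_kNN + (1 - lam) * p_LM) for the target token only."""
    return logsumexp([math.log(lam) + knn_log_prob,
                      math.log(1.0 - lam) + lm_log_prob])


def perplexity(target_log_probs):
    """exp of the mean negative log-probability of the ground-truth tokens."""
    return math.exp(-sum(target_log_probs) / len(target_log_probs))


# Usage on one made-up position: 3 neighbors, target token id 5.
lp_knn = knn_target_log_prob([1.0, 2.0, 1.0], [5, 9, 5], target=5)
lp = interpolate(math.log(0.4), lp_knn, lam=0.25)
print(math.exp(lp_knn), perplexity([lp]))
```

Nothing above ever needs probabilities for tokens other than the observed one, which is why the target can be used at evaluation time on held-out text. Generating text, where no target exists, would instead require a kNN distribution over the whole vocabulary.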