princeton-nlp / LM-BFF

[ACL 2021] LM-BFF: Better Few-shot Fine-tuning of Language Models https://arxiv.org/abs/2012.15723


Average the logits of 16 demonstrations

Raibows opened this issue

Hi!
Thanks a lot for your work. I have a question about prediction at inference time with demonstrations.
As the paper mentions, the final prediction logits are obtained by averaging the results over 16 demonstrations.
But I cannot find such an averaging operation in the code. I only see that you augment the dataset to 16 times its original size, meaning every input is paired with each of 16 demonstrations; there should then be an operation that gathers these 16 concatenated texts (input + demo) to produce the final prediction logits. Pseudocode for what I expect is below:

import torch

num_classes = 2  # e.g. a binary classification task
# Input: logits for the 16 concatenated pairs (text_a, demo1) ... (text_a, demo16)
predictions = torch.randn(16, num_classes)  # placeholder for the model's per-pair logits
# Output: the final prediction logits for text_a, averaged over demonstrations
final_logits = predictions.mean(dim=0)  # shape [num_classes]

Can you help me find this operation in the code? Or maybe I have some misunderstanding...
Thanks!

Hi,

The averaging operation is here:

LM-BFF/run.py, line 495 at commit 1bbdc42:

def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
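
For readers tracing the code, the core idea inside that function can be sketched as follows. This is a minimal illustration, not the repo's exact implementation: it assumes the augmented eval set lays the copies out copy-major (all examples paired with the first sampled demonstration set, then all examples with the second, and so on), and the helper name average_demo_logits is hypothetical.

import numpy as np

def average_demo_logits(predictions: np.ndarray, num_sample: int = 16) -> np.ndarray:
    # predictions: [num_sample * num_examples, num_classes] -- one row of
    # logits per (input, demonstration) pair from the augmented eval set.
    # Assumed copy-major layout: the first num_examples rows use the first
    # sampled demonstration set, the next num_examples rows the second, etc.
    num_classes = predictions.shape[-1]
    logits = predictions.reshape(num_sample, -1, num_classes)
    # Average over the demonstration axis, collapsing the 16 copies of each
    # example back to a single row of logits per original example.
    return logits.mean(axis=0)  # [num_examples, num_classes]

Note that the averaging happens in logit space before the argmax, so each of the 16 sampled demonstration sets contributes equally to the final prediction; after this step the per-example logits can be scored exactly as in the no-demonstration setting.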