princeton-nlp / LM-BFF

[ACL 2021] LM-BFF: Better Few-shot Fine-tuning of Language Models https://arxiv.org/abs/2012.15723

Is there a way to deal with label words with multiple tokens?

ryangawei opened this issue

Hi,

It seems like the model mainly deals with English, where most labels contain only one token. However, in Chinese tasks it's quite common for labels to contain multiple tokens.

In https://github.com/princeton-nlp/LM-BFF/blob/main/src/models.py#L75, the code reads:

sequence_output, pooled_output = outputs[:2]
sequence_mask_output = sequence_output[torch.arange(sequence_output.size(0)), mask_pos]

Here mask_pos has shape [batch_size]. Is there a way I can give mask_pos shape [batch_size, label_word_length] and use it to compute the loss for multi-token labels?
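Something like the following is what I have in mind (just a sketch continuing from the snippet above; batch_idx is a name I made up, and I assume mask_pos has shape [batch_size, label_word_length]):

# Advanced indexing: broadcast a [batch_size, 1] batch index against
# the [batch_size, label_word_length] mask positions.
batch_idx = torch.arange(sequence_output.size(0)).unsqueeze(-1)
sequence_mask_output = sequence_output[batch_idx, mask_pos]
# sequence_mask_output.shape == [batch_size, label_word_length, hidden_dim]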

Hi,

Good question. By design, LM-BFF can only handle a single token in the vocabulary as the label word (simultaneously predicting the probabilities of several consecutive tokens is theoretically beyond the scope of BERT/RoBERTa). Although this might cause some trouble in languages like Chinese, there should still be words consisting of several characters in the vocabulary, and you can use those as label words.
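For example, you can check whether a candidate word maps to a single token in the vocabulary like this (just a sketch, not part of the LM-BFF code; the checkpoint name and candidate words are only placeholders):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # placeholder checkpoint

for word in ["体育", "财经", "科技"]:  # hypothetical candidate label words
    tokens = tokenizer.tokenize(word)
    # A word can be used directly as a label word only if it maps to exactly one token.
    print(word, tokens, "single-token" if len(tokens) == 1 else "multi-token")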

@gaotianyu1350 Thank you for the reply. I made some modifications to the loss calculation, assuming each label word has 2 tokens:

# If mask_pos.shape == [batch_size, label_length], e.g. iflytek [batch_size, 2]
if len(mask_pos.shape) > 1:
    sequence_mask_output = torch.stack([sequence_output[i, index, :] for i, index in enumerate(mask_pos)])
    # sequence_mask_output.shape == [batch_size, 2, hidden_dim]
    prediction_mask_scores = self.cls(sequence_mask_output)
    # prediction_mask_scores.shape == [batch_size, 2, vocab_size]
    logits = prediction_mask_scores[:, 0, self.label_word_list[:, 0]] * prediction_mask_scores[:, 1, self.label_word_list[:, 1]]

And the logits become the joint probability of two [MASK] predictions. Do you think it makes sense?

Hi,

You should use the sum instead of the multiplication of the two scores (prediction_mask_scores stores the logits, not the actual probabilities). Also, MLM pre-trained models are not designed (or pre-trained) for predicting the joint probability of two consecutive masked tokens, though it might still work after fine-tuning.
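Concretely, based on your snippet above, something like:

# Sum the two per-position logits instead of multiplying them.
logits = (prediction_mask_scores[:, 0, self.label_word_list[:, 0]]
          + prediction_mask_scores[:, 1, self.label_word_list[:, 1]])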

@gaotianyu1350 Thank you for the suggestion.

You've made a good point that MLM pre-trained models aren't trained to predict joint probabilities, but since I'm following the other baselines in CLUEbenchmark/FewCLUE, I'll simply mimic their solution to this issue for now.

For the two scores, since PyTorch's CrossEntropyLoss applies a log after the softmax, would it make sense to use multiplication, given that log(p(t1) * p(t2)) = log(p(t1)) + log(p(t2))?

Thank you.

Hi,

The last formula is right, but prediction_mask_scores here is not the p in your formula; it is the logits before softmax. You can either sum them up, or first apply softmax to each of them separately, then multiply the resulting probabilities, and then take the cross entropy (without another softmax).
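The second option could look roughly like this (a sketch only; logits_0, logits_1, and labels are placeholder names for the two per-position score slices over the label words and the gold label indices):

import torch
import torch.nn.functional as F

# Softmax each position separately, multiply the probabilities,
# then take cross entropy (negative log-likelihood) without another softmax.
probs = F.softmax(logits_0, dim=-1) * F.softmax(logits_1, dim=-1)
loss = F.nll_loss(torch.log(probs + 1e-12), labels)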

@gaotianyu1350 I'll try the solution you proposed. Thank you very much for your suggestions and your time!

Best,
Guoao