google-research / bert

TensorFlow code and pre-trained models for BERT

Home Page: https://arxiv.org/abs/1810.04805

plan to release SWAG code?

eveliao opened this issue

Hi, I just want to know if you plan to release fine-tuning and evaluation code for the SWAG dataset.
If not, I wonder whether the training procedure is the same as for MRPC (more specifically, label 0 for distractors and 1 for the gold ending).

For maintainability reasons we don't plan on releasing more code than what we've already released (except for the gradient accumulation code that we've promised). You could train it as binary classification, but we actually did something different: we softmax over the logits from the different examples (i.e., the 4 endings of each SWAG question). This only requires a few lines of code, but it does require changing the input processing.
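
For concreteness, here is a minimal sketch of that grouped-softmax loss, assuming TF 1.x as used in this repo. The names swag_loss and pair_logits are only illustrative; the input and model changes it relies on are spelled out below.

```python
import tensorflow as tf  # TF 1.x, as used elsewhere in this repo


def swag_loss(pair_logits, labels, num_choices=4):
  """Illustrative only: softmax over the 4 endings of each example.

  pair_logits: [batch_size * num_choices] -- one score per (context, ending) pair.
  labels:      [batch_size]               -- index of the gold ending, in [0, 3].
  """
  grouped = tf.reshape(pair_logits, [-1, num_choices])      # [batch_size, 4]
  log_probs = tf.nn.log_softmax(grouped, axis=-1)           # [batch_size, 4]
  one_hot_labels = tf.one_hot(labels, depth=num_choices, dtype=tf.float32)
  per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
  return tf.reduce_mean(per_example_loss)
```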

Let's assume your batch size is 8 and your sequence length is 128. Each SWAG example has 4 entries: the correct ending and 3 incorrect ones.

  • Instead of your input_fn returning an input_ids of size [128], it should return one of size [4, 128]. Same for the input mask and segment ids. So for each example, you will generate the sequences predicate ending0, predicate ending1, predicate ending2, predicate ending3. Also return a label scalar, an integer in the range [0, 3], to indicate which ending is the gold one (see the feature-construction sketch after this list).

  • After batching, your model_fn will get an input of shape [8, 4, 128]. Reshape these to [32, 128] before passing them into BertModel. I.e., BERT will consider all of these independently.

  • Compute the logits as in run_classifier.py, but your "classifier layer" will just be a vector of size [768] (or whatever your hidden size is).

  • Now you have a set of logits of size [32]. Reshape these back to [8, 4] and compute tf.nn.log_softmax() over the 4 endings of each example. This gives you log probabilities of shape [8, 4] and a label tensor of shape [8], so compute the loss exactly as you would for a classification problem (see the model sketch below).
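
A rough sketch of the feature construction from the first bullet. The helper name convert_swag_example and its signature are made up, and truncation is left out for brevity (see _truncate_seq_pair in run_classifier.py); the tokenizer is assumed to be this repo's tokenization.FullTokenizer.

```python
def convert_swag_example(context, endings, label, tokenizer, max_seq_length=128):
  """Illustrative only: one SWAG example -> features of shape [4, max_seq_length].

  context: the shared first sentence / predicate.
  endings: list of the 4 candidate endings (1 gold, 3 distractors).
  label:   integer in [0, 3], the index of the gold ending.
  """
  all_input_ids, all_input_mask, all_segment_ids = [], [], []
  for ending in endings:
    tokens_a = tokenizer.tokenize(context)
    tokens_b = tokenizer.tokenize(ending)
    # NOTE: real code must truncate so that len(tokens) <= max_seq_length.
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    padding = [0] * (max_seq_length - len(input_ids))  # zero-pad to max_seq_length
    all_input_ids.append(input_ids + padding)
    all_input_mask.append(input_mask + padding)
    all_segment_ids.append(segment_ids + padding)
  # input_ids / input_mask / segment_ids each have shape [4, max_seq_length];
  # label stays a scalar, so a batch of 8 yields [8, 4, 128] inputs and [8] labels.
  return all_input_ids, all_input_mask, all_segment_ids, label
```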
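
And a matching sketch of the model side covering the remaining bullets. create_swag_model is a made-up name, but the pieces follow create_model in run_classifier.py and modeling.BertModel from this repo.

```python
import tensorflow as tf  # TF 1.x
import modeling  # this repo's modeling.py


def create_swag_model(bert_config, is_training, input_ids, input_mask,
                      segment_ids, labels, num_choices=4):
  """Illustrative only: inputs arrive with shape [batch_size, 4, seq_length]."""
  seq_length = input_ids.shape[-1].value  # assumes a static sequence length

  # Flatten [8, 4, 128] -> [32, 128] so that BERT scores every
  # (context, ending) pair independently.
  flat_input_ids = tf.reshape(input_ids, [-1, seq_length])
  flat_input_mask = tf.reshape(input_mask, [-1, seq_length])
  flat_segment_ids = tf.reshape(segment_ids, [-1, seq_length])

  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=flat_input_ids,
      input_mask=flat_input_mask,
      token_type_ids=flat_segment_ids)

  output_layer = model.get_pooled_output()      # [batch_size * 4, hidden_size]
  hidden_size = output_layer.shape[-1].value

  # The "classifier layer" is just a single vector of size [hidden_size].
  output_weights = tf.get_variable(
      "output_weights", [1, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = tf.get_variable(
      "output_bias", [1], initializer=tf.zeros_initializer())
  logits = tf.matmul(output_layer, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)  # [batch_size * 4, 1]

  # Regroup the scores and softmax over the 4 endings of each example.
  logits = tf.reshape(logits, [-1, num_choices])            # [batch_size, 4]
  log_probs = tf.nn.log_softmax(logits, axis=-1)
  one_hot_labels = tf.one_hot(labels, depth=num_choices, dtype=tf.float32)
  per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
  loss = tf.reduce_mean(per_example_loss)
  return loss, per_example_loss, logits
```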