plan to release SWAG code?
eveliao opened this issue
Hi, I just want to know if you plan to release the fine-tuning and evaluation code for the SWAG dataset.
If not, I wonder whether the training procedure is the same as for MRPC (more specifically, label 0 for the distractors and 1 for the gold ending).
For maintainability reasons we don't plan on releasing more code than what we've already released (except for the gradient accumulation code that we've promised). You could train it as binary classification, but we actually did something different: you softmax over the logits of the different endings within each example. This only requires a few lines of code but does require changing the input processing.
Let's assume your batch size is 8 and your sequence length is 128. Each SWAG example has 4 entries, the correct one and 3 incorrect ones.
- Instead of your `input_fn` returning an `input_ids` of size `[128]`, it should return one of size `[4, 128]`. The same goes for the input mask and segment ids. So for each example, you will generate the sequences `predicate ending0`, `predicate ending1`, `predicate ending2`, `predicate ending3`. Also return a label scalar, which is an integer in the range `[0, 3]`, to indicate which ending is the gold one (see the input-processing sketch after this list).
- After batching, your `model_fn` will get an input of shape `[8, 4, 128]`. Reshape these to `[32, 128]` before passing them into `BertModel`. I.e., BERT will consider all of these sequences independently.
- Compute the logits as in `run_classifier.py`, but your "classifier layer" will just be a vector of size `[768]` (or whatever your hidden size is).
- Now you have a set of logits of size `[32]`. Reshape these back into `[8, 4]` and then compute `tf.nn.log_softmax()` over the 4 endings for each example. Now you have log probabilities of shape `[8, 4]` over the 4 endings and a label tensor of shape `[8]`, so compute the loss exactly as you would for a classification problem (a model-side sketch follows below).
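For the input side, here is a rough sketch of what the per-example feature construction could look like. `convert_swag_example`, its arguments, and the naive truncation are my own assumptions rather than code from this repo; the point is only the output shapes (`[4, 128]` per feature plus a scalar label), built with the repo's `tokenization.FullTokenizer`.

```python
import tokenization  # from this repo

# Assumption: a vocab file and lower-casing setting that match your checkpoint.
tokenizer = tokenization.FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)


def convert_swag_example(context, endings, label, max_seq_length=128):
  """Returns features of shape [4, max_seq_length] plus a scalar label."""
  context_tokens = tokenizer.tokenize(context)
  all_input_ids, all_input_mask, all_segment_ids = [], [], []
  for ending in endings:  # the 4 candidate endings
    ctx = list(context_tokens)
    ending_tokens = tokenizer.tokenize(ending)
    # Naive truncation so [CLS] context [SEP] ending [SEP] fits in max_seq_length.
    while len(ctx) + len(ending_tokens) > max_seq_length - 3:
      ctx.pop()
    tokens = ["[CLS]"] + ctx + ["[SEP]"] + ending_tokens + ["[SEP]"]
    segment_ids = [0] * (len(ctx) + 2) + [1] * (len(ending_tokens) + 1)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    padding = [0] * (max_seq_length - len(input_ids))
    all_input_ids.append(input_ids + padding)
    all_input_mask.append(input_mask + padding)
    all_segment_ids.append(segment_ids + padding)
  return {
      "input_ids": all_input_ids,      # [4, max_seq_length]
      "input_mask": all_input_mask,    # [4, max_seq_length]
      "segment_ids": all_segment_ids,  # [4, max_seq_length]
      "label": label,                  # integer in [0, 3]: index of the gold ending
  }
```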
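Putting the reshaping, the single-vector "classifier layer", and the softmax over endings together, here is a minimal TF 1.x sketch written in the style of `run_classifier.py`. `create_swag_model` and `num_choices` are names I've made up; the `modeling.BertModel` call mirrors the one in `run_classifier.py`, but treat this as an illustration of the steps above, not the exact code we used.

```python
import tensorflow as tf

import modeling  # from this repo


def create_swag_model(bert_config, is_training, input_ids, input_mask,
                      segment_ids, labels, num_choices=4):
  """Multiple-choice loss: input tensors are [batch_size, num_choices, seq_length],
  e.g. [8, 4, 128]; labels is [batch_size] with values in [0, num_choices)."""
  seq_length = input_ids.shape[-1].value

  # Flatten [8, 4, 128] -> [32, 128] so BERT scores every ending independently.
  flat_input_ids = tf.reshape(input_ids, [-1, seq_length])
  flat_input_mask = tf.reshape(input_mask, [-1, seq_length])
  flat_segment_ids = tf.reshape(segment_ids, [-1, seq_length])

  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=flat_input_ids,
      input_mask=flat_input_mask,
      token_type_ids=flat_segment_ids)

  pooled_output = model.get_pooled_output()      # [32, hidden_size]
  hidden_size = pooled_output.shape[-1].value

  # The "classifier layer" is a single vector of size [hidden_size].
  output_weights = tf.get_variable(
      "output_weights", [1, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = tf.get_variable(
      "output_bias", [1], initializer=tf.zeros_initializer())

  logits = tf.matmul(pooled_output, output_weights, transpose_b=True)  # [32, 1]
  logits = tf.nn.bias_add(logits, output_bias)
  logits = tf.reshape(logits, [-1, num_choices])                       # [8, 4]

  # Softmax over the 4 endings of each example, then a standard NLL loss.
  log_probs = tf.nn.log_softmax(logits, axis=-1)                       # [8, 4]
  one_hot_labels = tf.one_hot(labels, depth=num_choices, dtype=tf.float32)
  per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
  loss = tf.reduce_mean(per_example_loss)
  return loss, per_example_loss, logits
```

At eval time you would take `tf.argmax(logits, axis=-1)` per example and compare it against the label to compute accuracy, just as in an ordinary 4-way classification problem.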