Simplest way to modify owt pretraining to use ReformerLM as generator and discriminator
DarrenAbramson opened this issue
I tried the obvious thing and modified pretraining/openwebtext/pretrain.py, following the README's suggestion of using the Reformer implementation for the generator and discriminator, as shown below. Note the removal of the LogitsAdapter
wrapper around the generator and discriminator. This doesn't work, but I'd be interested in suggestions. The error is below.
from transformers import AutoConfig, ElectraForMaskedLM, ElectraForPreTraining
from reformer_pytorch import ReformerLM
#generator = ElectraForMaskedLM(AutoConfig.from_pretrained(args.model_generator))
#discriminator = ElectraForPreTraining(AutoConfig.from_pretrained(args.model_discriminator))
generator = ReformerLM(
    num_tokens = 30522,
    emb_dim = 128,
    dim = 256,          # smaller hidden dimension
    heads = 4,          # fewer heads
    ff_mult = 2,        # smaller feed-forward intermediate dimension
    dim_head = 64,
    depth = 12,
    max_seq_len = 768
)
discriminator = ReformerLM(
    num_tokens = 30522,
    emb_dim = 128,
    dim = 1024,
    dim_head = 64,
    heads = 16,
    depth = 12,
    ff_mult = 4,
    max_seq_len = 768
)
# (2) weight tie the token and positional embeddings of generator and discriminator
generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb
# tie_weights(generator, discriminator)
model = to_distributed_model(Electra(
    generator,
    discriminator,
    num_tokens = vocab_size,
    mask_token_id = mask_token_id,
    pad_token_id = pad_token_id,
    mask_prob = args.model_mask_prob,
    mask_ignore_token_ids = [tokenizer.vocab['[CLS]'], tokenizer.vocab['[SEP]']],
    random_token_prob = 0.0).to(device))
Error:
Traceback (most recent call last):
  File "pretraining/openwebtext/pretrain_reformer.py", line 348, in <module>
    main()
  File "pretraining/openwebtext/pretrain_reformer.py", line 344, in main
    train(rank=args.gpu, args=args)
  File "pretraining/openwebtext/pretrain_reformer.py", line 228, in train
    loss, loss_mlm, loss_disc, acc_gen, acc_disc, disc_labels, disc_pred = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)
  File ".../tf2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File ".../lucidrains/electra-pytorch/electra_pytorch/electra_pytorch.py", line 210, in forward
    disc_logits = disc_logits.reshape_as(disc_labels)
RuntimeError: shape '[4, 768]' is invalid for input of size 93763584
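For what it's worth, the mismatched size factors as 4 × 768 × 30522 (batch × sequence length × vocab size), so the discriminator appears to be returning full-vocabulary logits while Electra expects a single logit per position. If I'm reading the README's Reformer example correctly, it avoids this by constructing the discriminator with return_embeddings = True (so ReformerLM skips its final vocab projection) and passing discr_dim / discr_layer to Electra so the wrapper attaches its own one-logit-per-token head. A rough, untested sketch of that configuration, keeping the rest of my script the same (the 'reformer' layer name is copied from the README example):

discriminator = ReformerLM(
    num_tokens = 30522,
    emb_dim = 128,
    dim = 1024,
    dim_head = 64,
    heads = 16,
    depth = 12,
    ff_mult = 4,
    max_seq_len = 768,
    return_embeddings = True    # return hidden states rather than vocab logits
)

model = to_distributed_model(Electra(
    generator,
    discriminator,
    num_tokens = vocab_size,
    discr_dim = 1024,           # hidden dimension of the discriminator
    discr_layer = 'reformer',   # discriminator submodule whose output feeds the per-token head
    mask_token_id = mask_token_id,
    pad_token_id = pad_token_id,
    mask_prob = args.model_mask_prob,
    mask_ignore_token_ids = [tokenizer.vocab['[CLS]'], tokenizer.vocab['[SEP]']],
    random_token_prob = 0.0).to(device))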
@DarrenAbramson Hey Darren again! I actually would recommend going with full attention instead of LSH attention for pretraining. Perhaps the example I have in the readme is misleading. Are you working with sequence lengths > 2048?
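Since your max_seq_len is only 768, you don't need LSH buckets at all; you can keep ReformerLM and just turn on full attention. A rough, untested sketch (the use_full_attn flag in reformer-pytorch makes every layer use full self-attention instead of LSH):

generator = ReformerLM(
    num_tokens = 30522,
    emb_dim = 128,
    dim = 256,
    heads = 4,
    ff_mult = 2,
    dim_head = 64,
    depth = 12,
    max_seq_len = 768,
    use_full_attn = True    # full self-attention in every layer, no LSH hashing
)

and the same flag on the discriminator.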