TypeError: forward() got an unexpected keyword argument 'mask' when training DALL-E through the rainbow_dalle notebook
aalbersk opened this issue
Hi, I ran into an issue similar to the one reported in lucidrains/DALLE2-pytorch#167.
When running the example training in the rainbow_dalle notebook, I was able to generate a dataset and train the VAE, but I got an error when attempting to train DALL-E:
TypeError Traceback (most recent call last)
/tmp/ipykernel_19683/156728592.py in <cell line: 2>()
1 dalle_model_file = "data/rainbow_dalle.model"
2 if not os.path.exists(dalle_model_file):
----> 3 dalle, loss_history = fit(dalle, opt, None, scheduler,
4 (captions_array[train_idx, ...], all_image_codes[train_idx, ...], captions_mask[train_idx, ...]), None, 200, 256,
5 dalle_model_file, train_dalle_batch,
/tmp/ipykernel_19683/3286150046.py in fit(model, opt, criterion, scheduler, train_x, train_y, epochs, batch_size, model_file, trainer, n_train_samples)
14 model.train()
15 opt.zero_grad()
---> 16 loss = trainer(model, train_x, train_y, rnd_idx[batch_idx:(batch_idx + batch_size)], criterion)
17 loss.backward()
18 losses.append(loss.item())
/tmp/ipykernel_19683/2510503843.py in train_dalle_batch(vae, train_data, _, idx, __)
1 def train_dalle_batch(vae, train_data, _, idx, __):
2 text, image_codes, mask = train_data
----> 3 loss = dalle(text[idx, ...], image_codes[idx, ...], mask=mask[idx, ...], return_loss=True)
4 return loss
~/python/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
...
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
TypeError: forward() got an unexpected keyword argument 'mask'
My environment is as follows:
OS: Ubuntu 18.04
Python: 3.9.13
Torch: 1.12.0
torchvision: 0.13.0
Have you observed a similar issue? What could fix it?
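The error comes from how nn.Module dispatches calls: `_call_impl` ends in `forward_call(*input, **kwargs)` (visible in the traceback), so any keyword that the model's `forward()` does not declare is rejected by Python itself with a TypeError. The sketch below reproduces this without PyTorch. `ToyModule`, `OldDALLE`, `NewDALLE` and `accepts_kwarg` are hypothetical names used only for illustration: the two model classes stand in for a DALLE whose `forward()` lacks or accepts a `mask` parameter, and `accepts_kwarg` uses `inspect.signature` to check the installed model before training.

```python
import inspect


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: calling it forwards to forward()."""

    def __call__(self, *args, **kwargs):
        # Like nn.Module._call_impl, pass keywords straight through, so an
        # undeclared keyword raises TypeError from forward() itself.
        return self.forward(*args, **kwargs)


class OldDALLE(ToyModule):
    # Hypothetical forward() that does NOT declare a `mask` parameter.
    def forward(self, text, image, return_loss=False):
        return 0.0


class NewDALLE(ToyModule):
    # Hypothetical forward() that accepts an optional `mask`.
    def forward(self, text, image, mask=None, return_loss=False):
        return 0.0


def accepts_kwarg(fn, name):
    """Return True if `fn` can be called with the keyword argument `name`."""
    for p in inspect.signature(fn).parameters.values():
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # a **kwargs catch-all accepts any keyword
        if p.name == name and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


if __name__ == "__main__":
    old, new = OldDALLE(), NewDALLE()
    print(accepts_kwarg(old.forward, "mask"))  # False
    print(accepts_kwarg(new.forward, "mask"))  # True
    try:
        old("text", "codes", mask="m", return_loss=True)
    except TypeError as err:
        # Message ends with: got an unexpected keyword argument 'mask'
        print(err)
```

Because a real nn.Module's `forward` is an ordinary method, the same check could presumably be pointed at `dalle.forward` in the notebook to see whether the installed package version accepts `mask` before starting a long training run.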
Hey, yes, training works now.
Btw, to make the whole example work correctly, I needed to move all_image_codes to the CPU with all_image_codes = all_image_codes.cpu()
before calculating the accuracy at the end.
But now everything works. Thanks for the quick fix!
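The thread does not show the notebook's accuracy code, so the following is only a sketch of why a `.cpu()` call can be needed there. It assumes (not confirmed by the thread) that the predictions end up as a NumPy array, while `all_image_codes` is a tensor that may live on the GPU; `Tensor.numpy()` only works for CPU tensors, so the targets are moved first. `code_accuracy` is a hypothetical helper name.

```python
import numpy as np
import torch


def code_accuracy(pred_codes: np.ndarray, true_codes: torch.Tensor) -> float:
    """Fraction of positions where predicted image codes match the targets.

    Hypothetical helper: `pred_codes` is assumed to be a NumPy array, and
    `true_codes` may be on the GPU (as all_image_codes apparently was).
    """
    # .numpy() fails on CUDA tensors, so detach and move to the CPU first.
    true_np = true_codes.detach().cpu().numpy()
    return float((pred_codes == true_np).mean())


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    all_image_codes = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device)
    preds = np.array([[1, 2, 0], [4, 5, 6]])
    print(code_accuracy(preds, all_image_codes))  # 0.8333333333333334
```

Calling `.cpu()` inside the helper, rather than reassigning the global tensor as in the comment above, leaves the original tensor on its device in case it is still needed for further GPU work.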