lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

TypeError: forward() got an unexpected keyword argument 'mask' when training DALL-e through rainbow_dalle notebook

aalbersk opened this issue

Hi, I ran into an issue similar to the one reported in lucidrains/DALLE2-pytorch#167.

When running the example training in the rainbow_dalle notebook, I managed to generate a dataset and train the VAE, but I got an error when attempting to train DALL-E:

TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_19683/156728592.py in <cell line: 2>()
      1 dalle_model_file = "data/rainbow_dalle.model"
      2 if not os.path.exists(dalle_model_file):
----> 3     dalle, loss_history = fit(dalle, opt, None, scheduler, 
      4                               (captions_array[train_idx, ...], all_image_codes[train_idx, ...], captions_mask[train_idx, ...]), None, 200, 256,
      5                               dalle_model_file, train_dalle_batch,

/tmp/ipykernel_19683/3286150046.py in fit(model, opt, criterion, scheduler, train_x, train_y, epochs, batch_size, model_file, trainer, n_train_samples)
     14             model.train()
     15             opt.zero_grad()
---> 16             loss = trainer(model, train_x, train_y, rnd_idx[batch_idx:(batch_idx + batch_size)], criterion)
     17             loss.backward()
     18             losses.append(loss.item())

/tmp/ipykernel_19683/2510503843.py in train_dalle_batch(vae, train_data, _, idx, __)
      1 def train_dalle_batch(vae, train_data, _, idx, __):
      2     text, image_codes, mask = train_data
----> 3     loss = dalle(text[idx, ...], image_codes[idx, ...], mask=mask[idx, ...], return_loss=True)
      4     return loss

~/python/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
...
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() got an unexpected keyword argument 'mask'

My environment is as follows:
OS: ubuntu 18.04
Python 3.9.13
Torch 1.12.0
torchvision 0.13.0

Have you seen a similar issue? What could fix it?

@aalbersk Hi Anna, the notebook was actually made on an earlier version of DALLE-pytorch. The mask has since been removed

Let me know if daf30d0 corrects the issue!
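
For anyone running an older notebook against a newer install, here is a minimal sketch of one way to guard against this kind of error: inspect the callable's signature and drop keyword arguments it no longer accepts. The helper `call_with_supported_kwargs` and the `OldModel` / `NewModel` classes are hypothetical stand-ins written for illustration, not part of DALLE-pytorch, and this is not what the fix commit does.

```python
import inspect


def call_with_supported_kwargs(fn, *args, **kwargs):
    """Call fn, dropping keyword arguments its signature does not accept.

    Hypothetical helper for illustration only.
    """
    params = inspect.signature(fn).parameters
    # If fn declares **kwargs it accepts any keyword, so pass everything through.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(*args, **kwargs)
    # Otherwise keep only the keywords that appear in the signature.
    supported = {k: v for k, v in kwargs.items() if k in params}
    return fn(*args, **supported)


class OldModel:
    # Stand-in for a version whose forward still takes `mask` (assumption).
    def forward(self, text, image, mask=None, return_loss=False):
        return ("old", mask is not None, return_loss)


class NewModel:
    # Stand-in for a version where `mask` has been removed (assumption).
    def forward(self, text, image, return_loss=False):
        return ("new", return_loss)


# The same call site works against both versions; `mask` is silently dropped
# when the target no longer accepts it.
old_result = call_with_supported_kwargs(
    OldModel().forward, "text", "codes", mask=[1, 1], return_loss=True
)
new_result = call_with_supported_kwargs(
    NewModel().forward, "text", "codes", mask=[1, 1], return_loss=True
)
```

Silently dropping an argument can hide real behavior changes, so in a notebook it is usually cleaner to just delete `mask=...` from the call once you know the installed version no longer takes it.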

Hey, yes, training works now.

By the way, to make the whole example work, I had to move all_image_codes to the CPU with all_image_codes = all_image_codes.cpu() before calculating accuracy at the end.
But now everything works. Thanks for the quick fix!

@aalbersk good to hear, and cute corgi!