lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

Adding support to images with transparency

markuschue opened this issue

Hi, I'm training a VAE and DALLE with a custom dataset which must contain .png images with transparency.
What should I change in the code so that the model can learn and then generate images in this format?
I'm a little bit lost and any tip would help, thanks!

Oh, looks like the train_vae.py script doesn't handle this yet. I'll get around to it next week!

@alu0101130507 Hey Markus, do you want to give 1.6.0 a try (commit bebc280)? You'll have to train your own custom DiscreteVAE with the --transparent flag.

Thanks a lot! Unfortunately, I'm still having some problems: when I run the train_vae.py script, I get the following warning and error:

C:\Users\Marku\AppData\Local\Programs\Python\Python38\lib\site-packages\PIL\Image.py:945: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(
Traceback (most recent call last):
  File ".\train_vae.py", line 234, in <module>
    loss, recons = distr_vae(
  File "C:\Users\Marku\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Marku\Documents\DALLE-1.6.0\dalle_pytorch\dalle_pytorch.py", line 222, in forward
    img = self.norm(img)
  File "C:\Users\Marku\Documents\DALLE-1.6.0\dalle_pytorch\dalle_pytorch.py", line 189, in norm
    images.sub_(means).div_(stds)
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1

Maybe it's because the images are not getting converted to RGBA, but I couldn't find out how to fix it.
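For readers hitting the same warning, here is a minimal sketch of converting a palette image with transparency to RGBA using Pillow. It is an illustration only, not the code used in train_vae.py; the helper name to_rgba and the tiny in-memory test image are assumptions made for the example.

```python
from PIL import Image


def to_rgba(img: Image.Image) -> Image.Image:
    """Return a 4-channel RGBA version of `img` (hypothetical helper)."""
    return img if img.mode == "RGBA" else img.convert("RGBA")


# Build a small palette ("P") image in memory so the example needs no files.
# Palette index 0 is red and is marked as the transparent index.
palette_img = Image.new("P", (4, 4), color=0)
palette_img.putpalette([255, 0, 0] + [0, 0, 0] * 255)
palette_img.info["transparency"] = 0

rgba = to_rgba(palette_img)
print(palette_img.mode, "->", rgba.mode, rgba.getbands())
```

In a dataset pipeline, a conversion like this would typically run on each image before it is turned into a tensor, so that every sample has the same four channels.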

@alu0101130507 Ohh yes, there were more things I did not consider (like normalization and validation in the DALL-E class).
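The RuntimeError in the traceback comes from subtracting 3-channel statistics from a 4-channel image batch. Below is a minimal sketch of per-channel normalization that checks the channel count first. The function name normalize_per_channel and the 0.5 statistics are placeholders for illustration and are not the values or the code used in DALLE-pytorch.

```python
import torch


def normalize_per_channel(images: torch.Tensor, means, stds) -> torch.Tensor:
    """Normalize a (batch, channels, height, width) tensor channel by channel.

    Hypothetical helper: raises a clear error when the number of statistics
    does not match the number of image channels.
    """
    channels = images.shape[1]
    if len(means) != channels or len(stds) != channels:
        raise ValueError(
            f"got {len(means)} means and {len(stds)} stds for {channels} channels"
        )
    # Reshape to (1, C, 1, 1) so the statistics broadcast over batch and pixels.
    means_t = torch.as_tensor(means, dtype=images.dtype, device=images.device)
    stds_t = torch.as_tensor(stds, dtype=images.dtype, device=images.device)
    return (images - means_t.view(1, -1, 1, 1)) / stds_t.view(1, -1, 1, 1)


# A fake batch of two 8x8 RGBA images with values in [0, 1).
rgba_batch = torch.rand(2, 4, 8, 8)

# Placeholder statistics: one mean and one std per channel, alpha included.
normalized = normalize_per_channel(rgba_batch, (0.5,) * 4, (0.5,) * 4)
print(normalized.shape)
```

Passing three means and stds to this function with the same batch raises a ValueError up front instead of failing inside the in-place subtraction.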

Ok! Try 1.6.1 (commit a6776c8).

Yes, now it works! Just one more little thing: the TRANSPARENT constant wasn't defined in the train_dalle.py file, so I added the following line:

TRANSPARENT = True if CHANNELS == 4 else False

Also, in the generate.py file there was a problem because the images were being saved as JPEG, which has no alpha channel. I changed it to PNG and everything is now working pretty well, thank you so much! :)
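As a small illustration of why the output format matters, the sketch below uses Pillow directly: PNG keeps the alpha channel of an RGBA image, while Pillow refuses to write an RGBA image as JPEG. This is not the code from generate.py; the helper name encode and the in-memory test image are assumptions for the example.

```python
import io

from PIL import Image


def encode(img: Image.Image, fmt: str) -> bytes:
    """Encode `img` in the given format and return the raw bytes (hypothetical helper)."""
    buffer = io.BytesIO()
    img.save(buffer, format=fmt)
    return buffer.getvalue()


# A small half-transparent red image.
rgba = Image.new("RGBA", (4, 4), (255, 0, 0, 128))

# PNG is lossless and supports alpha, so the pixel values survive a round trip.
restored = Image.open(io.BytesIO(encode(rgba, "PNG")))
print(restored.mode, restored.getpixel((0, 0)))

# JPEG cannot store an alpha channel; Pillow raises OSError for RGBA input.
try:
    encode(rgba, "JPEG")
    jpeg_ok = True
except OSError:
    jpeg_ok = False
print("JPEG accepted RGBA:", jpeg_ok)
```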

@alu0101130507 got it, thank you Markus! feel free to close the issue once you find everything satisfactory :)