lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

DALLE trained on FashionGen Dataset RESULTS 💯

alexriedel1 opened this issue

DALLE on FashionGen

  • I trained DALL-E + VQGAN on the FashionGen dataset (https://arxiv.org/abs/1806.08317) on Google Colab and got decent results.
  • Without training the VQGAN on the FashionGen dataset, DALL-E is really bad at generating faces, which makes the generated clothing look extremely strange.

Text to image generation and re-ranking by CLIP

Best 16 of 48 generations ranked by CLIP
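For reference, a minimal sketch of this kind of CLIP re-ranking with the openai/CLIP package; the prompt, the `images` tensor, and the resize/normalization details are placeholder assumptions, not the exact setup used here:

```python
import torch
import torch.nn.functional as F
import clip  # https://github.com/openai/CLIP

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = clip.load('ViT-B/32', device=device)

# `images`: hypothetical stand-in for 48 DALL-E generations in [0, 1]
images = torch.rand(48, 3, 256, 256, device=device)

# CLIP expects 224x224 inputs with its own normalization constants
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
std  = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)
inp = (F.interpolate(images, size=224, mode='bilinear', align_corners=False) - mean) / std

with torch.no_grad():
    img_feat = model.encode_image(inp).float()
    txt_feat = model.encode_text(clip.tokenize(['a red wool coat']).to(device)).float()

img_feat = F.normalize(img_feat, dim=-1)
txt_feat = F.normalize(txt_feat, dim=-1)
scores = (img_feat @ txt_feat.T).squeeze(-1)  # cosine similarity per generation
best16 = scores.topk(16).indices              # keep the best 16 of 48
```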

Generations from the training set (including their ground truths)

[5 attached images]

Generations based on custom prompts (without their ground truths)

[5 attached images]

Model specifications

VAE
Trained VQGAN for 1 epoch on the FashionGen dataset
Embeddings: 1024
Batch size: 5

DALLE
Trained DALLE for 1 epoch on the FashionGen dataset
dim = 312
text_seq_len = 80
depth = 36
heads = 12
dim_head = 64
reversible = 0
attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')
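Putting these hyperparameters together, the model construction looks roughly like this sketch (assuming the `vae` from the section above; num_text_tokens = 49408 matches the default BPE vocabulary discussed further down):

```python
from dalle_pytorch import DALLE

dalle = DALLE(
    vae = vae,                # the FashionGen-finetuned VQGAN from above
    dim = 312,
    num_text_tokens = 49408,  # default BPE vocab size of this repo's tokenizer
    text_seq_len = 80,
    depth = 36,
    heads = 12,
    dim_head = 64,
    reversible = False,
    attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')
)
```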

Optimization
Optimizer: Adam
Learning rate: 4.5e-4
Gradient Clipping: 0.5
Batch size: 7
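A minimal training step matching these settings might look like the following sketch; the `dataloader` is a hypothetical stand-in that yields batches of (tokenized text, image) pairs:

```python
import torch
from torch.optim import Adam

opt = Adam(dalle.parameters(), lr = 4.5e-4)

# hypothetical dataloader: text (7, 80) token ids, images (7, 3, 256, 256)
for text, images in dataloader:
    loss = dalle(text, images, return_loss = True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(dalle.parameters(), 0.5)  # gradient clipping at 0.5
    opt.step()
    opt.zero_grad()
```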

[attached image]

Hi, can you share the Colab link and checkpoints?

> Hi, can you share the Colab link and checkpoints?

You'll find the trained Dall-E weights here: https://drive.google.com/uc?id=1kEHTTZH2YbbHZjY6fTWuPb5_D-7nQ866
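A sketch for downloading and inspecting that checkpoint; the key layout noted in the comment is what DALLE-pytorch's train_dalle.py usually saves, so treat it as an assumption:

```python
import gdown
import torch

url = 'https://drive.google.com/uc?id=1kEHTTZH2YbbHZjY6fTWuPb5_D-7nQ866'
gdown.download(url, 'dalle_fashiongen.pt', quiet = False)

ckpt = torch.load('dalle_fashiongen.pt', map_location = 'cpu')
# train_dalle.py checkpoints typically contain 'hparams', 'vae_params' and 'weights';
# the 'hparams' dict is where the text_seq_len difference discussed below shows up
print(ckpt.get('hparams'))
```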

@alexriedel1
Thank you!
Also, I'm wondering which vocab you used; I only have the bpe_simple_vocab_16e6 supplied by OpenAI.

I downloaded the weights, but it seems that its parameters are different.
[attached screenshot]

Yes, that's right: the text sequence length is 120. Is this a problem for you?

No, it's just different from the model description above.
[attached screenshot]
I'm wondering which BPE file you used, and why num_text_tokens is so large.

I also used the default tokenizer in this project, which uses the bpe_simple_vocab_16e6 byte pair encoder: https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/tokenizer.py. Its vocabulary size is 49408 by default.
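For illustration, a quick sketch of the default tokenizer in use (the caption string is made up):

```python
from dalle_pytorch.tokenizer import tokenizer

caption = 'long sleeve wool sweater with ribbed crewneck collar'
tokens = tokenizer.tokenize([caption], context_length = 120, truncate_text = True)

print(tokenizer.vocab_size)  # 49408
print(tokens.shape)          # torch.Size([1, 120])
```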

I increased the text sequence length to 120 because the FashionGen dataset has quite long text descriptions for its images.

Thanks a lot!