lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

Geek Repo

Github PK Tool

Getting stuck on generate_images

m0nologuer opened this issue · comments

Here's my code -- no idea what's happening

import torch
from dalle_pytorch import DiscreteVAE, DALLE

from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# Hyperparameters for this toy run.
BATCH_SIZE = 4
IMAGE_SIZE = 64   # both the VAE and DALLE below operate on 64x64 images
IMAGE_PATH = "."  # root for ImageFolder; must contain one subdirectory per class
EPOCHS = 1

# Discrete VAE that tokenizes images. With num_layers = 2 the 64x64 input is
# presumably downsampled to a 16x16 feature map, i.e. 256 visual tokens per
# image drawn from a 1024-entry codebook — TODO confirm against DiscreteVAE docs.
vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = 2,           # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    num_tokens = 1024,        # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 256,       # codebook dimension
    hidden_dim = 32,          # hidden dimension
    num_resnet_blocks = 1,    # number of resnet blocks
    temperature = 0.9,        # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False, # straight-through for gumbel softmax. unclear if it is better one way or the other
)

##Train on images
# NOTE(review): the original assigned `images = torch.randn(...)` here, but that
# value was overwritten by the training loop before ever being read — dead code,
# removed.
dataset = ImageFolder(
    IMAGE_PATH,
    T.Compose([
        # ImageFolder yields PIL images; force 3-channel RGB so ToTensor
        # produces a (3, H, W) tensor even for grayscale/RGBA inputs.
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),      # resize so the shorter side is IMAGE_SIZE
        T.CenterCrop(IMAGE_SIZE),  # square IMAGE_SIZE x IMAGE_SIZE crop
        T.ToTensor()               # PIL image -> float tensor in [0, 1]
    ])
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

##Run training for several epochs
# BUG FIX (review): the original called loss.backward() with no optimizer at
# all — gradients accumulated across batches and the VAE's weights were never
# updated, so the VAE never actually trained. An Adam optimizer with a
# zero_grad/step per batch turns this into a real training loop.
vae_opt = torch.optim.Adam(vae.parameters(), lr = 3e-4)
count = 0
for epoch in range(EPOCHS):
    # `iter(dataloader)` was redundant: the for statement calls iter() itself.
    for images, labels in dataloader:
        vae_opt.zero_grad()
        loss = vae(images, return_loss = True)
        loss.backward()
        vae_opt.step()
        print(count)  # crude progress indicator
        count = count + 1


#Train on text to images
dalle = DALLE(
    dim = 1024,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 1000,    # vocab size for text
    text_seq_len = 16,         # text sequence length
    depth = 12,                 # should aim to be 64
    heads = 16,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
)

# BUG FIX (review): as with the VAE loop above, calling loss.backward() with
# no optimizer never updates the model — this single illustrative step now
# performs an actual parameter update.
dalle_opt = torch.optim.Adam(dalle.parameters(), lr = 3e-4)

# Dummy batch: token ids in [0, num_text_tokens) and random "images".
text = torch.randint(0, 1000, (BATCH_SIZE, 16))
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)

dalle_opt.zero_grad()
loss = dalle(text, images, return_loss = True)
loss.backward()
dalle_opt.step()

# do the above for a long time with a lot of data ... then

# NOTE(review): generation is presumably autoregressive — sampling one image
# token at a time — so on CPU this call can take a long time and may merely
# LOOK stuck; confirm against DALLE.generate_images before assuming a hang.
images = dalle.generate_images(text)
img1 = images[0]
save_image(img1, 'img1.png')

# NOTE(review): with IMAGE_SIZE = 64 this should print (4, 3, 64, 64),
# not (4, 3, 256, 256) as originally claimed — TODO confirm.
print(images.shape)