Getting stuck on generate_images
m0nologuer opened this issue · comments
Sakunthala Panditharatne commented
Here's my code. It gets stuck at `generate_images` and I have no idea what's happening:
import torch
from dalle_pytorch import DiscreteVAE, DALLE
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
# Run configuration shared by the VAE stage and the DALL-E stage below.
BATCH_SIZE: int = 4    # samples per batch (DataLoader and the random tensors)
IMAGE_SIZE: int = 64   # side length, in pixels, of the square training images
IMAGE_PATH: str = "."  # root directory handed to ImageFolder
EPOCHS: int = 1        # number of passes over the dataloader
# Discrete VAE used as the image tokenizer for the DALL-E model built later.
vae = DiscreteVAE(
    image_size=IMAGE_SIZE,
    # Number of downsamples. The library example is 256 / (2 ** 3) = 32; by the
    # same formula this script gets 64 / (2 ** 2) = a 16 x 16 feature map.
    num_layers=2,
    # Number of visual tokens. The paper used 8192; smaller is fine for
    # downsized projects.
    num_tokens=1024,
    # Codebook dimension.
    codebook_dim=256,
    # Hidden dimension.
    hidden_dim=32,
    # Number of resnet blocks.
    num_resnet_blocks=1,
    # Gumbel softmax temperature; the lower it is, the harder the
    # discretization.
    temperature=0.9,
    # Straight-through for gumbel softmax; unclear whether it helps.
    straight_through=False,
)
## Train on images
# ImageFolder expects IMAGE_PATH to contain one sub-directory per class
# (root/<class_name>/<image>); images placed directly in IMAGE_PATH are not
# picked up. The class labels it yields are ignored by the training loop.
#
# The random `images` tensor the script used to create here was dropped: it
# was never read before being overwritten by the batches from the dataloader.
dataset = ImageFolder(
    IMAGE_PATH,
    T.Compose([
        # Guarantee three channels so every sample matches the VAE input.
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        # Resize then center-crop to get IMAGE_SIZE x IMAGE_SIZE squares.
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor(),
    ]),
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
## Run training for several epochs
# The original loop only called loss.backward(): gradients were computed (and
# accumulated across batches) but never applied, so the VAE was never actually
# trained. An optimizer is required to update the weights.
# NOTE(review): lr=1e-3 is a starting point, not a tuned value.
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

count = 0  # number of batches processed so far; printed as a progress marker
for epoch in range(EPOCHS):
    # ImageFolder yields (image_batch, class_labels); labels are not used.
    for images, _labels in dataloader:
        vae_optimizer.zero_grad()  # clear gradients left from the last batch
        loss = vae(images, return_loss=True)
        loss.backward()
        vae_optimizer.step()       # apply the update
        print(count)
        count += 1
# Train on text to images
# The vocabulary size and text length are needed both by the model and by the
# code that builds the token tensor; naming them keeps the two in sync.
NUM_TEXT_TOKENS = 1000  # vocab size for text
TEXT_SEQ_LEN = 16       # text sequence length

dalle = DALLE(
    dim=1024,
    vae=vae,  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=NUM_TEXT_TOKENS,
    text_seq_len=TEXT_SEQ_LEN,
    depth=12,          # should aim to be 64
    heads=16,          # attention heads
    dim_head=64,       # attention head dimension
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1,    # feedforward dropout
)

# Placeholder batch: random token ids and random "images". This is a single
# forward/backward smoke test; no optimizer step is taken, so the DALL-E
# weights are still at their initial values afterwards.
text = torch.randint(0, NUM_TEXT_TOKENS, (BATCH_SIZE, TEXT_SEQ_LEN))
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)
loss = dalle(text, images, return_loss=True)
loss.backward()
# do the above for a long time with a lot of data ... then
# NOTE(review): generation presumably samples the image tokens one at a time
# through the full transformer; with dim=1024 and depth=12 on CPU this call can
# run for a long while without printing anything - confirm before treating it
# as a hang.
images = dalle.generate_images(text)
img1 = images[0]  # first sample of the batch
save_image(img1, 'img1.png')
# Expected (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE) = (4, 3, 64, 64) for this
# script; the (4, 3, 256, 256) in the library example assumes 256-pixel images.
print(images.shape)