lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

Custom-trained DALL-E 2 decoder generates random noise

rahulmoorthy19 opened this issue

Hi,

Thank you for creating this repository; it is really helpful. I trained a decoder for my custom task, but when I generate images with it, the outputs are random noise. A sample generated image is shown below:
[Image "trial": sample output generated by the decoder]

The decoder inference code is as follows:
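```python
import torch
image_decoder = torch.load(image_decoder_path).cuda()  # full Decoder saved with torch.save
# image_proj: embedding from the trained CLIP, shape (batch, image_embed_dim)
image_generated = image_decoder.sample(image_embed = image_proj, cond_scale = 2.)
```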
`image_proj` is a processed image embedding from a trained CLIP model.
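Passing `cond_scale = 2.` turns on classifier-free guidance at sampling time. The sketch below shows the standard guidance combination on plain lists, assuming the decoder mixes a conditional prediction (given the image embedding) with an unconditional one (given a null embedding) in this way; `guided_prediction` is a hypothetical helper, not a library function.

```python
from typing import List


def guided_prediction(cond: List[float], uncond: List[float], cond_scale: float) -> List[float]:
    """Classifier-free guidance: extrapolate from the unconditional prediction
    toward (and, for cond_scale > 1, past) the conditional one.

    cond_scale = 1.0 returns the conditional prediction unchanged.
    """
    if len(cond) != len(uncond):
        raise ValueError("predictions must have the same length")
    return [u + (c - u) * cond_scale for c, u in zip(cond, uncond)]


# Toy noise predictions for a 3-pixel "image"
cond = [0.5, -0.2, 0.1]   # prediction given the CLIP image embedding
uncond = [0.3, 0.0, 0.1]  # prediction given the null (dropped) embedding
print(guided_prediction(cond, uncond, cond_scale=2.0))
```

The unconditional branch is only meaningful if the network was trained with some conditioning dropout, since that is how it learns to predict without the embedding.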

These are the Unet and Decoder settings I used for training:

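```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 8,                  # base channel width
    image_embed_dim = 256,    # must match the size of the CLIP image embedding passed in
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()
```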
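The `dim` and `dim_mults` arguments set how wide the Unet is. A minimal sketch of the arithmetic, assuming each resolution stage gets `dim * mult` channels (our reading of how DALLE2-pytorch builds its stages); `stage_widths` is a hypothetical helper used only for illustration.

```python
from typing import List, Tuple


def stage_widths(dim: int, dim_mults: Tuple[int, ...]) -> List[int]:
    """Channel count of each Unet resolution stage, assuming width = dim * mult."""
    return [dim * mult for mult in dim_mults]


print(stage_widths(8, (1, 2, 4, 8)))    # this issue's config -> [8, 16, 32, 64]
print(stage_widths(128, (1, 2, 4, 8)))  # a wider config, for comparison -> [128, 256, 512, 1024]
```

Under that assumption, `dim = 8` leaves even the deepest stage with only 64 channels for 224×224 images.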

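```python
from dalle2_pytorch import Decoder

decoder = Decoder(
    unet = unet,
    image_size = 224,
    timesteps = 100,
    image_cond_drop_prob = 0.1,   # per-sample probability of dropping the image embedding
    text_cond_drop_prob = 0.5,
    learned_variance = False
).cuda()
```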
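`image_cond_drop_prob = 0.1` controls conditioning dropout. A minimal sketch, assuming it works as standard classifier-free-guidance training: each sample's embedding is swapped for a null embedding with that probability. In practice the null embedding is typically a learned parameter; here it is a fixed zero vector, and `drop_conditioning` is a hypothetical helper.

```python
import random
from typing import List, Optional


def drop_conditioning(
    embeds: List[List[float]],
    null_embed: List[float],
    drop_prob: float,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Replace each sample's embedding with null_embed with probability drop_prob.

    Training on both kinds of samples lets one network learn the conditional
    and the unconditional predictions that guidance later combines.
    """
    rng = rng or random.Random()
    return [list(null_embed) if rng.random() < drop_prob else list(e) for e in embeds]


rng = random.Random(0)
batch = [[1.0, 2.0] for _ in range(10_000)]
null = [0.0, 0.0]
dropped = drop_conditioning(batch, null, drop_prob=0.1, rng=rng)
frac = sum(e == null for e in dropped) / len(dropped)
print(f"fraction dropped: {frac:.3f}")  # close to 0.1
```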

The training loss comes out to 0.051178544054353195.
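To put that number in context, here is a minimal sketch of a typical diffusion decoder objective, assuming the model predicts the added noise, is trained with an L2 (mean squared error) loss, and uses a cosine noise schedule over `timesteps = 100`. These are assumptions about the setup, and `cosine_alphas_cumprod`, `noise_prediction_loss`, and the toy predictor are hypothetical names, not library functions.

```python
import math
import random
from typing import Callable, List


def cosine_alphas_cumprod(timesteps: int, s: float = 0.008) -> List[float]:
    """alpha_bar_t (fraction of signal kept) for t = 0..timesteps-1, cosine schedule."""
    def f(t: float) -> float:
        return math.cos((t / timesteps + s) / (1 + s) * math.pi / 2) ** 2

    # beta_t = 1 - f(t+1) / f(t), clipped to 0.999 to avoid a singular last step
    betas = [min(1 - f(t + 1) / f(t), 0.999) for t in range(timesteps)]
    alphas_cumprod, prod = [], 1.0
    for beta in betas:
        prod *= 1 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def noise_prediction_loss(
    x0: List[float],
    predict_noise: Callable[[List[float], int], List[float]],
    alphas_cumprod: List[float],
    rng: random.Random,
) -> float:
    """One training-step loss: MSE between the true and predicted noise at a random t."""
    t = rng.randrange(len(alphas_cumprod))
    ab = alphas_cumprod[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    # Forward (noising) process: x_t = sqrt(ab) * x0 + sqrt(1 - ab) * eps
    x_t = [math.sqrt(ab) * x + math.sqrt(1 - ab) * e for x, e in zip(x0, eps)]
    eps_hat = predict_noise(x_t, t)
    return sum((p - e) ** 2 for p, e in zip(eps_hat, eps)) / len(eps)


schedule = cosine_alphas_cumprod(100)  # timesteps = 100, as in the config above
rng = random.Random(0)
x0 = [0.0] * 16  # toy "image" with pixels scaled to [-1, 1]


def zero_predictor(x_t: List[float], t: int) -> List[float]:
    return [0.0] * len(x_t)


losses = [noise_prediction_loss(x0, zero_predictor, schedule, rng) for _ in range(2000)]
print(f"zero predictor loss: {sum(losses) / len(losses):.3f}")  # close to 1.0
```

A predictor that always outputs zeros scores about 1.0, so a loss near 0.05 means the network recovers most of the noise on average. Because the loss is averaged over random timesteps, a low value does not by itself show that sampling from pure noise works.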
Any help would be much appreciated. Thank you!

This relates to DALL-E 2, so I have opened the issue in the DALL-E 2 repository instead.