openai / glide-text2im

GLIDE: a diffusion-based text-conditional image synthesis model



Larger batch size to generate images in text2im.ipynb?

brijow opened this issue

Hi, in the example notebook text2im.ipynb, I'm not clear on how to use a batch size larger than 1, or on the recommended way to generate many images.

I'd like to play around with the model, generate several thousand images for some captions I have collected, and evaluate the overall quality of the results. However, I'm not sure of the best way to do this other than something along the lines of the pseudo-code below:

for each caption in my dataset:
    tokens = encode(caption)
    model_kwargs = {...}
    sample = diffusion.p_sample_loop(...)
    save_sample(sample)

Is there a faster way to do this than (more or less) following the recipe above?
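For context on the batch-size part of the question: the notebook builds its conditioning batch by repeating a single prompt's tokens `batch_size` times and appending `batch_size` empty (unconditional) rows for classifier-free guidance, so raising `batch_size` there gives more samples of the same caption, not samples of different captions. The plain-Python sketch below mimics both layouts and the guidance arithmetic performed by `model_fn`. It is an illustration under assumptions: `pad`, `TEXT_CTX` and `PAD_ID` are hypothetical stand-ins (the real tokenizer takes its length from `options['text_ctx']` and pads with its own token id), and lists of ints and floats replace the tensors the model actually consumes.

```python
from typing import List, Tuple

# Hypothetical values for illustration only; GLIDE reads the context length
# from options['text_ctx'] and pads with its tokenizer's own token id.
TEXT_CTX = 8
PAD_ID = 0


def pad(tokens: List[int], ctx: int = TEXT_CTX) -> Tuple[List[int], List[bool]]:
    """Stand-in for glide.tokenizer.padded_tokens_and_mask: truncate or pad
    to `ctx` ids and return a mask that is True only for real tokens."""
    tokens = tokens[:ctx]
    n_pad = ctx - len(tokens)
    return tokens + [PAD_ID] * n_pad, [True] * len(tokens) + [False] * n_pad


def notebook_layout(prompt_tokens: List[int], batch_size: int):
    """Notebook style: the same prompt repeated batch_size times, followed by
    batch_size empty (unconditional) rows, so every sample shares one caption."""
    toks, mask = pad(prompt_tokens)
    empty_toks, empty_mask = pad([])
    return ([toks] * batch_size + [empty_toks] * batch_size,
            [mask] * batch_size + [empty_mask] * batch_size)


def per_prompt_layout(prompts_tokens: List[List[int]]):
    """Per-caption style: one conditional row per caption, followed by one
    empty row per caption, so the i-th sample belongs to the i-th caption."""
    padded = [pad(t) for t in prompts_tokens]
    empty_toks, empty_mask = pad([])
    n = len(padded)
    return ([t for t, _ in padded] + [empty_toks] * n,
            [m for _, m in padded] + [empty_mask] * n)


def guided_eps(eps_rows: List[float], scale: float) -> List[float]:
    """Scalar version of the guidance step in model_fn: the first half of the
    batch is conditional, the second half unconditional. The guided result is
    written to both halves, so keeping only rows [:n] loses nothing."""
    n = len(eps_rows) // 2
    cond, uncond = eps_rows[:n], eps_rows[n:]
    half = [u + scale * (c - u) for c, u in zip(cond, uncond)]
    return half + half
```

Because both halves end up with the same guided prediction, sampling `2 * n` rows and slicing off the first `n` images is enough, whichever layout is used.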


Check the code of the notebook at:

Not sure if this is too late to be helpful, but the following is about twice as fast as looping over a list of captions and seems to be what you want. I've cleaned it up from my own file, so I haven't had a chance to run it and there may be an error lurking somewhere. The basic idea is to replace the tiling of a single prompt's tokens (see the multiplications by batch_size in the original notebook code) with one set of tokens per prompt, so that each row of the batch is conditioned on a different caption.

from PIL import Image
from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)
import torch
import matplotlib.pyplot as plt

has_cuda = torch.cuda.is_available()
device = torch.device('cpu' if not has_cuda else 'cuda')
# Create the base GLIDE model.
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
glide, diffusion = create_model_and_diffusion(**options)
glide.eval()
if has_cuda:
    glide.convert_to_fp16()
glide.to(device)
glide.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in glide.parameters()))

guidance_scale = 3.0
upsample_temp = 0.997

# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    half = x_t[: len(x_t) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = glide(combined, ts, **kwargs)
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)

def glide_generate(prompts):
    """Returns a tensor of images where the ith image is generated from the ith prompt in [prompts].
    
    Args:
    prompts    -- list of string prompts
    """
    batch_size = len(prompts)
    tokens = [glide.tokenizer.encode(p) for p in prompts]
    tokens_and_masks = [glide.tokenizer.padded_tokens_and_mask(t, options['text_ctx']) for t in tokens]
    tokens = [t for t,_ in tokens_and_masks]
    masks = [m for _,m in tokens_and_masks]

    # Create the classifier-free guidance tokens (empty)
    full_batch_size = batch_size * 2
    uncond_tokens, uncond_mask = glide.tokenizer.padded_tokens_and_mask([], options['text_ctx'])

    # Pack the tokens together into glide kwargs.
    model_kwargs = dict(
        tokens=torch.tensor(tokens + [uncond_tokens] * batch_size, device=device),
        mask=torch.tensor(masks + [uncond_mask] * batch_size, dtype=torch.bool, device=device))

    glide.del_cache()
    samples = diffusion.p_sample_loop(
        model_fn,
        (full_batch_size, 3, options["image_size"], options["image_size"]),
        device=device,
        clip_denoised=True,
        progress=True,
        model_kwargs=model_kwargs,
        cond_fn=None,
    )[:batch_size]
    glide.del_cache()

    # Uncomment what's below to validate the function
    # scaled = ((samples + 1)*127.5).round().clamp(0,255).to(torch.uint8).cpu()
    # for s in scaled:
    #     plt.imshow(s.permute(1, 2, 0))
    #     plt.show()

    return (samples + 1) / 2
    
# Uncomment what's below to validate the function
# glide_generate(["a painting of a blue bird", "a painting of a red cat", "a painting of a purple apple"])
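To cover the original goal of several thousand images, `glide_generate` can be called on fixed-size chunks of captions instead of one caption at a time. The sketch below is one way to organize that, under stated assumptions: `chunked` and `generate_all` are hypothetical helpers, `generate_fn` is expected to behave like `glide_generate` (one image per caption, in order), and `save_fn` is whatever writes an image to disk. The usable chunk size is limited by GPU memory, since each chunk is sampled as a batch of twice its size (conditional plus unconditional rows).

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")  # whatever generate_fn returns per caption (e.g. a tensor)


def chunked(items: Sequence[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive chunks of at most `size` captions, preserving order."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield list(items[start:start + size])


def generate_all(
    captions: Sequence[str],
    generate_fn: Callable[[List[str]], Sequence[T]],
    save_fn: Callable[[int, str, T], None],
    batch_size: int = 8,
) -> int:
    """Run generate_fn once per chunk and hand every result to save_fn.

    save_fn receives a running index, the caption and its image, so output
    files can be matched back to captions. Returns the number of images saved.
    """
    saved = 0
    for chunk in chunked(captions, batch_size):
        images = generate_fn(chunk)  # e.g. glide_generate(chunk)
        if len(images) != len(chunk):
            raise RuntimeError(f"expected {len(chunk)} images, got {len(images)}")
        for caption, image in zip(chunk, images):
            save_fn(saved, caption, image)
            saved += 1
    return saved
```

With the code above loaded, a call might look like `generate_all(captions, glide_generate, save_png, batch_size=16)`, where the batch size is only an example to tune by trial and `save_png(index, caption, image)` is a hypothetical saver that converts the [0, 1], channels-first tensor with `(image.permute(1, 2, 0) * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()` and writes it with `Image.fromarray(...).save(f"{index:05d}.png")`. Passing the generator and saver in as arguments keeps the batching logic testable without a GPU.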