borisdayma / dalle-mini

DALL·E Mini - Generate images from a text prompt

Home Page: https://www.craiyon.com

How to generate reproducible results

mandanraf opened this issue

I want to generate the same set of images every time I run the inference script locally with the same set of prompts. Is there a way to do this?
Thanks.

I just altered the seed code so that I can either enter a custom seed or use a random one.
Note that if you use a static seed as in the code below, every pass through the loop produces the same images.

for i in trange(max(n_predictions // jax.device_count(), 1)):
    # create a random key (default logic)
    # seed = random.randint(0, 2**32 - 1)
    
    # Use a specific seed to recreate the same results
    seed = 1618614716
    
    key = jax.random.PRNGKey(seed)
    
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(key),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )

You can also append the seed to the image file name so you know which seed was used.

    img.save(os.path.join(outdir, f"{base_count:04}-{seed}.png"))
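The custom-or-random seed choice and the seed-tagged file name could be wrapped in two small helpers. This is only a sketch: `pick_seed` and `image_name` are hypothetical names, not part of dalle-mini, and the `outputs` directory is just a placeholder.

```python
import os
import random


def pick_seed(custom_seed=None):
    """Return the custom seed if one is given, otherwise a random 32-bit seed."""
    if custom_seed is not None:
        return custom_seed
    return random.randint(0, 2**32 - 1)


def image_name(outdir, base_count, seed):
    """Build an output path whose file name records the seed that was used."""
    return os.path.join(outdir, f"{base_count:04}-{seed}.png")


# Fixed seed for reproducible runs; call pick_seed() with no argument for a random one.
seed = pick_seed(1618614716)
print(image_name("outputs", 3, seed))
```

Keeping the seed in the file name means a run that used a random seed can still be repeated later.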

Also keep in mind that you'll have to use the same batch size (i.e., the number of prompts in the list you pass to processor()) and the same number of GPUs/TPUs (shard_prng_key() splits the key by jax.device_count()); otherwise you'll get divergent random number sequences.
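To make the device dependence concrete, here is a minimal sketch of deriving one key per device from a single seed. It assumes shard_prng_key() behaves roughly like a plain jax.random.split() with one key per device; `split_per_device` is a hypothetical helper, not the actual implementation.

```python
import jax


def split_per_device(seed, num_devices):
    """Derive one PRNG key per device from a single integer seed."""
    key = jax.random.PRNGKey(seed)
    # One row of the result goes to each device.
    return jax.random.split(key, num_devices)


keys = split_per_device(1618614716, jax.device_count())
print(keys.shape)  # leading axis has one entry per device
```

Since the per-device keys are derived from both the seed and the split, it seems safest to record the device count alongside the seed.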

You can make those sequences independent of batch size by giving each item its own seed. Either patch generate to use a jax.vmap()-ed version of jax.random.categorical()/jax.random.split(), or jax.vmap() over the key parameter on the outside and add an empty dimension to your tokenized prompts. Due to numerical instability, though, you still won't reliably get the same results.
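The per-item seed idea can be sketched in isolation with a vmap-ed jax.random.categorical(). This is a toy example on hand-written logits, not a patch to generate; `sample_per_item` and the seed values are made up for illustration.

```python
import jax
import jax.numpy as jnp


def sample_per_item(seeds, logits):
    """Sample one category per row, with each row using its own seed."""
    keys = jnp.stack([jax.random.PRNGKey(s) for s in seeds])
    # vmap maps over the leading axis of both keys and logits,
    # so row i is sampled with key i only.
    return jax.vmap(jax.random.categorical)(keys, logits)


logits = jnp.log(jnp.array([[0.1, 0.2, 0.7],
                            [0.3, 0.3, 0.4],
                            [0.5, 0.25, 0.25],
                            [0.2, 0.6, 0.2]]))

small = sample_per_item([11, 22], logits[:2])
large = sample_per_item([11, 22, 33, 44], logits)
print(small, large[:2])  # the first two draws should not depend on the batch size
```

This only decouples the random draws from the batch size. In the real model the logits themselves can still differ slightly between runs, which is the instability mentioned above.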

I suspect this instability might also make results irreproducible across different hardware and/or library/driver versions, but I haven't tested that.

If anyone figures out how to avoid that instability, I'd be very interested!

I tried all of the above suggestions, but none of them let me reproduce the exact same results :( Running on CPU is reproducible, but running on GPU is not.
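One thing that may be worth trying for the GPU case is XLA's deterministic-ops option. The sketch below assumes the installed jaxlib recognizes the `--xla_gpu_deterministic_ops` flag and that the variable is set before jax is imported. `with_xla_flag` is a hypothetical helper, and whether this removes the nondeterminism seen here is untested.

```python
import os


def with_xla_flag(existing, flag):
    """Append flag to an XLA_FLAGS string unless it is already present."""
    parts = existing.split()
    if flag not in parts:
        parts.append(flag)
    return " ".join(parts)


# Assumption: this XLA option exists in the installed jaxlib build.
# Set it before `import jax`, otherwise it has no effect on the session.
os.environ["XLA_FLAGS"] = with_xla_flag(
    os.environ.get("XLA_FLAGS", ""), "--xla_gpu_deterministic_ops=true"
)
```

Deterministic kernels can be slower, so this is probably best kept for runs where exact reproducibility matters.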

Oh, I didn't try it, but I would expect that running the notebook twice with the same seed value gives the same images, no?

I just reverted my changes to test the vanilla implementation, and it seems to be working as expected. Though I'm not sure the seed is ever logged anywhere for the user?

My implementation only helps if you want one seed per batch anyway.