lucidrains / recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in Pytorch

RIN with 100k steps, still noisy images?

Samin100 opened this issue · comments

I'm trying to reproduce the On the Importance of Noise Scheduling for Diffusion Models paper's benchmark on ImageNet 256 x 256 with the RIN. However, after training for ~100k steps, the sampled images still look like noise. I applied the fix mentioned here prior to training.

What I found interesting is that the noisy images do change - it's almost as if the noise gets "smoother" as training progresses, with a checkerboard pattern appearing in between (possibly due to the image patches).

Start (100 steps):
media_images_samples_101_83989238bf614c33120e

10k steps:
media_images_samples_9999_dfc00accb4ad002526df

30k steps:
media_images_samples_29997_36ffcfa0844d90919794

50k steps:
media_images_samples_49995_f556fbc2b8743d825e1b

90k steps:
media_images_samples_87062_7e6d97b7ee20825ab96b

The training loss also seems to be very noisy.
CleanShot 2023-03-02 at 18 16 28

Any pointers on how to go about debugging this would be much appreciated.

What are the hyperparameters you're using? I'm running into similar problems on larger images: at 32x32 the model converges, but higher resolutions don't converge at all. I think batch size is of critical importance for pixel-level models: in this paper they use bs=1024, and the recent simple diffusion paper uses a batch size of 2048, which the authors note is important for convergence.

I’m using a batch size of 256 (as much as fits into memory) with a learning rate of 2e-3 and the LAMB optimizer (the same optimizer and learning rate used in the paper).

@Lamikins Could you share more details around the hyperparameters that got your 32 x 32 model to converge? Were you able to sample from it with good results?

I’m trying to figure out if the model just needs to be trained for longer or if there’s a bug in my implementation.

@Samin100 hmm, i was definitely able to make it converge for my toy oxford flowers dataset with a batch size of 32, image size 128, and under 100k steps.. i'll run some experiments on Sunday

@Samin100 could you share your network hyperparameters?

lol, maybe this is the issue 😮‍💨 , if you were using the epsilon objective

@lucidrains Here's the RIN's hyperparameters. I'm mostly trying to stick to the hyperparameters recommended in the paper for 64x64 images.

from rin_pytorch import RIN, GaussianDiffusion

model = RIN(
    dim = 256,                  # model dimensions
    image_size = img_size,      # image size
    patch_size = 8,             # patch size
    depth = 4,                  # depth
    num_latents = 128,          # number of latents. they used 256 in the paper
    dim_latent = 768,           # can be greater than the image dimension (dim) for greater capacity
    latent_self_attn_depth = 6, # number of latent self attention blocks per recurrent step, K in the paper
).cuda()

And the diffusion wrapper:

diffusion = GaussianDiffusion(
    model,
    timesteps = 400,
    train_prob_self_cond = 0.9,
    scale = 1.,
    objective = 'eps',
    use_ddim = True,
    noise_schedule = 'linear',
).cuda()

I also just started a new training run using the latest fix from #8 on 32x32 images.

It's only 3k steps in, but the model still seems to only produce noise. This is using batch size 1536 with lr 1e-4, Adam, and the above hyperparams.

Here's a sample from 100 steps in:
media_images_samples_101_3f2900bee7d5098da2af

3k steps:
media_images_samples_2828_38d8624e17de0e6a06e0

yeah that looks good, 3k steps is nothing, just keep training

@Samin100 are you seeing it now?

@Samin100 haha, i see you live in San Francisco, happy to show you in person this weekend if you can't get it working

@lucidrains Gotcha! I wasn't sure if the samples were supposed to look like that at 3k steps so I stopped the training run before going to bed. I just restarted it and will share more samples later today.

I'm also doing a second training run with the new time conditioning as mentioned in #9 while keeping all other hyperparams identical, so hopefully that's a good ablation study. I'll share the results for that as well.

And I appreciate the help! If I can't get either of these runs to work then I might have to take you up on that offer ☕

@Samin100 ok, let me know! i was just doing a run with the latent token time conditioning, but it wasn't working so hot

i've started a run with that off (latent_token_time_cond = False) just to verify it can still converge with the latest fixes

@Samin100 hey Sharif, was running some experiments and noticed that the epsilon objective indeed converges slowly

i've added the v objective from the Progressive Distillation paper (Salimans et al.). you should see more immediate results
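For reference, a rough sketch of what the v objective regresses, with assumed names and the usual alpha**2 + sigma**2 = 1 parameterization (not the repo's exact code):

# v-prediction (Salimans & Ho): with x_t = alpha * x_start + sigma * noise,
# the network is trained to predict v = alpha * noise - sigma * x_start

def v_to_x_start(x_noised, pred_v, alpha, sigma):
    # recover the predicted clean image from the predicted v
    return alpha * x_noised - sigma * pred_v

def v_to_noise(x_noised, pred_v, alpha, sigma):
    # recover the predicted noise from the predicted v
    return sigma * x_noised + alpha * pred_v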

@lucidrains Awesome! I'll start a new run now with it. How many steps in are you starting to see the samples change from noise with the v objective?

Also, here are a few samples from the two 32x32 training runs with the eps objective at 10k steps (batch size 1,792).

with latent_token_time_cond = False
media_images_samples_8989_53232513e8409cf79c6d

with latent_token_time_cond = True
media_images_samples_9292_67b716df531f8cbfba5a

Training loss:
CleanShot 2023-03-04 at 17 44 55

The samples look more diffuse than earlier, but still no noteworthy changes yet. Gonna give the new v objective a try now.

@Samin100 yea, i agree that is strange. usually you see something much earlier, but i don't know how this perceiver-like architecture behaves with the noise objective. now i recall that the convergence i first saw was with the predict x0 objective. then someone raised the issue that the paper used epsilon, so i added that without testing

i think it is safe to use the predict v objective though! at least, they also used it for 'simple diffusion' with great results

using the v objective will also lend itself well to progressive distillation down the road, if you ever get a significant model trained

@Samin100 hey, thanks for opening this issue

i was digging into the ddim code this morning trying to figure out why epsilon objective samples so poorly. i think there was an issue surrounding the clipping of predicted x0 (from the predicted noise). once i rederived the predicted noise from the derived predicted x0, it seems to behave a lot better. in general, the clipping of x0 during sampling is not well documented in the literature, or even the original paper 😓
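Roughly, the change amounts to something like the following inside the DDIM loop (a sketch with assumed names, not the exact diff):

# one deterministic DDIM step (eta = 0) with the eps objective,
# assuming x_t = alpha * x_start + sigma * noise at each timestep
def ddim_step(x_t, pred_noise, alpha, sigma, alpha_next, sigma_next):
    # x0 implied by the predicted noise, clipped to the valid pixel range
    x_start = ((x_t - sigma * pred_noise) / alpha).clamp(-1., 1.)
    # re-derive the noise from the clipped x0 (the fix described above)
    pred_noise = (x_t - alpha * x_start) / sigma
    # step to the next, less noisy timestep
    return alpha_next * x_start + sigma_next * pred_noise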

could you retry the predicted noise objective in the latest version of the package and see if it looks any better?

Hey @lucidrains @Samin100,

I just stumbled across this issue while I'm trying to get RINs working for video generation with my own codebase as well.
Up to this point I haven't been able to reach anything near convergence for resolutions higher than 64x64, which I thought was possibly due to bugs, but it pretty much matches this observation.

But so far I've only experimented with the noise prediction objective, and based on your comments I'm definitely going to check out the v parametrization now.

Also, one thing I noticed regarding the latent time embedding: shouldn't it be latents = latents[:, :-1]? But I'm still unsure about the advantage of the time embedding as a latent token over the more traditional way.
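For reference, a minimal sketch of what that latent-token time conditioning looks like, with assumed names (the time embedding is concatenated as the last latent token, hence the [:, :-1] slice afterwards):

import torch
from einops import rearrange

def add_time_token(latents, time_emb):
    # latents: (batch, num_latents, dim_latent), time_emb: (batch, dim_latent)
    return torch.cat((latents, rearrange(time_emb, 'b d -> b 1 d')), dim = 1)

def remove_time_token(latents):
    # drop the appended time token again after the latent self-attention blocks
    return latents[:, :-1]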

But anyway, thanks for open-sourcing and sharing results - I'll mention it if I find out anything in my upcoming runs.

@DanielSiersleben ah shoot, yes you are correct

at a park with doggo, but let me see if I can fix this on my phone lol. welcome to the 21st century

@DanielSiersleben ok done, thank you 🙏

@DanielSiersleben i can eventually optimize the positional embeddings for video, if you let me know how your experiments turn out!

@lucidrains Here are the results from a 128x128 run using the v objective with latent_token_time_cond = False and dual_patchnorm = False. Seems like it's not converging yet. I just pulled the latest changes including the fix mentioned by @DanielSiersleben and will start a new training run now and report back later today.

Results from 128x128 model:

Start (100 steps):
media_images_samples_101_8cbb609fc9f8246543ad

20k steps:
media_images_samples_19998_1e064715b0f2d5dd2bfd

40k steps:
media_images_samples_42117_16efadddd5640f1cacc1

What's interesting is that the smaller 32x32 model seemed to show signs of learning at just around 3k steps in, unlike the larger 128x128 model. All things were kept constant between the two models, except the image size and the scale argument, which was changed from 1 to 0.6 for the 128px model.
samples_3333_436aeb013610eda2fd39
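(For context, the scale argument corresponds to the input-scaling strategy from the noise scheduling paper. Roughly, and with assumed names, it scales the clean image before noising, which lowers the signal-to-noise ratio at every timestep - something that matters more at higher resolutions:)

def scaled_noising(x_start, noise, alpha, sigma, scale = 0.6):
    # scale the clean image before adding noise, shifting the SNR down by scale**2
    x_noised = alpha * scale * x_start + sigma * noise
    # renormalize so the noised input keeps (roughly) unit variance
    return x_noised / (alpha ** 2 * scale ** 2 + sigma ** 2) ** 0.5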

@Samin100 how high is your learning rate

I see results at 1k mark when using predict v objective

@lucidrains Are you also using 128x128 images? I'm using 2e-3 lr with the LAMB optimizer for 128x128 and adjusting accordingly from this table in the paper depending on image size. What LR are you using?

CleanShot 2023-03-05 at 17 36 47

CleanShot 2023-03-05 at 17 35 46
It seems to converge faster than my other runs with ADAM at 1e-4. In this example the LAMB run used a larger batch size, but even with batch sizes equal LAMB is a lot faster.

sample-1

yup, 1k mark for predict v objective for oxford flowers

my settings are as follows:

from rin_pytorch import GaussianDiffusion, RIN, Trainer

model = RIN(
    dim = 256,
    image_size = 128,
    patch_size = 8,
    depth = 6,
    num_latents = 128,
    dim_latent = 512,
    latent_self_attn_depth = 4,
    dual_patchnorm = True
).cuda()

diffusion = GaussianDiffusion(
    model,
    timesteps = 400,
    objective = 'v',                  # the new predict-v objective
    train_prob_self_cond = 0.9,
    scale = 1.
).cuda()

trainer = Trainer(
    diffusion,
    '/path/to/oxford-flowers',
    results_folder = './results-pred-v',
    num_samples = 16,
    train_batch_size = 4,
    gradient_accumulate_every = 8,    # effective batch size of 32
    train_lr = 3e-4,
    save_and_sample_every = 1000,
    train_num_steps = 700000,
    ema_decay = 0.995
)

trainer.train()

@Samin100 i wouldn't recommend lamb. also a good paper

the only optimizer i could recommend that isn't adam would be lion, but it is still being evaluated by the public, and seems to be not general purpose

@lucidrains Interesting! And let me try a run with those hyperparams and ADAM right now.

sample-2

2k mark, yea it looks good

@lucidrains It's working! I see flowers! This is 1.5k steps in:
media_images_samples_1515_5c1fd85e2767a62e4935

Now I'm curious why my previous runs weren't working. The biggest differences I can think of:

  • I was using a different dataset (ImageNet vs Oxford flowers)
  • I was using fp16 with AMP, but not anymore
  • LAMB optimizer instead of ADAM, with a different learning rate
  • the run that works uses scale = 1, whereas I was using lower values before (0.6, 0.5)

I'll share more details as I figure things out.

@Samin100 nice! yeah would be interested to know once you figure out the reason

@lucidrains Quick update! Here's 30k steps – the model seems to be learning pretty well. I wonder how long it will take for the seams between the image patches to disappear.

My next step is to adjust the RIN hyperparams (one at a time) and then compare the best RIN approach to a simple diffusion approach, mostly to gain an intuition around the total compute required to train these models at varying scales. I'll share more results as they come in. In the meantime I'll go ahead and close this issue.

image

sounds good! do re-raise an issue if you find out the cause and the bug is due to the repo

@DanielSiersleben how is predict v-space objective looking for your video training? thumbs up or down? i'll take silence meaning that you are going heads down trying not to get scooped for your paper or startup haha, as is often the case

Haha not quite, trying to get the RIN running as a starting point for my Master's thesis.

My first experiment with the v-parametrization also ran into problems, but I think that was due to using AMP float16.
The gradients looked fine, but I guess something with the clipping in the DDPM steps might break with AMP - not 100% sure yet.

Running with a similar model size to the one you suggested, I already get promising results quite early in training on the UCF dataset (64x64x8):
5af6333e-0a1b-4c9a-8130-6caf2ff9e4b7
Let's see where this is going. 😅

However, with model sizes comparable to the ones used in the paper for higher resolutions (256x256x8 on kinetics700), I don't seem to make any progress even after 20k steps with batch size ~64 and lr 5e-4:
906c42d5-7843-4cb3-830e-52d9ebf8cb39

Probably have to stick with smaller models...

@DanielSiersleben cool! How are you using the network? did you modify the positional embeddings or the patch function (for temporal patches) at all, given I only built it for image use?

If you are seeing a signal, I can spend some time and do an optimized RIN for video with some tricks I know from the lit

@lucidrains Here's the 128x128 model at 100k steps. It seems like the seams around the image patches don't want to go away. The model's hyperparams are the same ones here, except I'm using the linear noise schedule. Do you think it's a matter of the model needing more training to "learn" better image patches?

image

@Samin100 yea, you just need to train more. there's probably some architectural tweaks that can help with that artifact, but will stray from the paper

edit: oh, if you have the dual patchnorm turned on, try turning that off, as it isn't proven in a generative setting yet

@DanielSiersleben this would also need to be a 3d conv https://github.com/lucidrains/recurrent-interface-network-pytorch/blob/main/rin_pytorch/rin_pytorch.py#L234 are you just treating the video as a grid of images?
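Something like the following, as a sketch (assumed names, not the repo's exact module) - swapping the depthwise 2d conv used for positional information with a depthwise 3d conv over (time, height, width):

from torch import nn

class PEG3d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Conv3d(dim, dim, kernel_size = 3, padding = 1, groups = dim)

    def forward(self, x):
        # x: (batch, time, height, width, dim) patch tokens
        x = x.permute(0, 4, 1, 2, 3)        # to (batch, dim, time, height, width)
        x = x + self.proj(x)                # residual depthwise 3d conv
        return x.permute(0, 2, 3, 4, 1)     # back to channels-last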

@lucidrains I mainly implemented the model on my own and was just comparing individual components with your implementation to make sure I didn't mess up.
Right now, I am keeping things pretty close to the pseudocode from the paper, so only the learnable positional embedding is added after patchifying, and no PEG (which is from the ViT literature, right?).

Tomorrow, I'll maybe add a 3D PEG and see how that works out.

What else would you modify for video besides positional embedding and patching?

nice! I think 3d peg will work well, and yes indeed it is from ViT lit

the other trick I would recommend is a token shift along the time dimension, before each feed forward block for the patches (not latents). bound to improve results https://arxiv.org/abs/2108.02432
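As a sketch (assumed tensor layout, not a definitive implementation) - half of each patch token's feature dims get replaced by the previous frame's features before the feedforward:

import torch
import torch.nn.functional as F

def token_shift_time(patches):
    # patches assumed to be (batch, time, height * width, dim) video patch tokens
    shifted, rest = patches.chunk(2, dim = -1)
    # shift the first half of the feature dims forward one frame (frame t sees frame t - 1)
    shifted = F.pad(shifted, (0, 0, 0, 0, 1, -1))
    return torch.cat((shifted, rest), dim = -1)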

Nice, that sounds really promising.
I will look into that, thank you!

@lucidrains @Samin100 I was able to get these results at 101k steps
batch_size=128
v-pred

Results are decent, but it's not nearly as compute-efficient as convs (e.g. the Karras implementation).

media_images_sample_101000_4dfede6a67f947c3281e

@Lamikins Those results look great! Do you mind sharing the rest of the hyperparameters for this run? Specifically curious about the patch size and noise scale used. Do you think RINs might be more efficient only at higher resolutions compared to a model with convs? I'm curious if you have any more thoughts on this.

@lucidrains I am also running into a similar issue. I am training on 32x32 CIFAR-10 images. These are the generated images after 120k steps.

Screen Shot 2023-03-27 at 11 19 20 PM

I am using the velocity prediction objective, the Adam optimizer with a learning rate of 1e-4, and a batch size of 256. These are the other relevant hyperparameters:

model = RIN(
    dim = 128,
    image_size = 32,
    patch_size = 4,
    channels = 3,
    dim_latent = 256,
    depth = 6,
    num_latents = 256,
    latent_self_attn_depth = 4,
    latent_token_time_cond = False,
).cuda()

Training loss curve:
Screen Shot 2023-03-27 at 11 25 57 PM

This particular example uses the sigmoid noise schedule, but I have also tried other noise schedules, namely the linear and cosine schedules, without much success. Any suggestions on how to train this better?

For what it's worth, I've also seen these patchy artifacts with simple datasets.

@Lamikins do you mind sharing your hyper-parameters?