ilkerkesen / DRAW

Knet implementation of DRAW: A Recurrent Neural Network For Image Generation

gradient problem

denizyuret opened this issue · comments

@ilkerkesen

  1. I implemented a new utility function gcheck (you need to check out master AutoGrad and explicitly write using AutoGrad: gcheck to use it). It works with regular models and loss functions, e.g. gcheck(nll, model, x, y), and should work as long as any of the inputs has param components.

  2. I checked this on new versions of models such as lenet (from the tutorial), and it works.

  3. As expected, it failed on your DRAW model: include("debug.jl"); gcheck(loss, model, x).

  4. Then I fixed the unboxing problem by replacing value.(hidden) with hidden at Knet/src/rnn.jl:130.

  5. I tried gcheck again, but it still fails.

I was expecting (5) to pass. Did you try (4) and confirm the gradients work?
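For reference, the kind of finite-difference gradient check that gcheck performs can be sketched as follows. This is an illustrative Python/NumPy sketch, not AutoGrad's actual Julia implementation: the function names `numeric_grad` and `gradcheck` are hypothetical, and the real gcheck perturbs the param components inside a Knet model rather than a raw array.

```python
import numpy as np

def numeric_grad(loss, w, eps=1e-5):
    """Central-difference approximation of d(loss)/dw."""
    g = np.zeros_like(w)
    it = np.nditer(w, flags=["multi_index"])
    for _ in it:
        i = it.multi_index
        orig = w[i]
        w[i] = orig + eps
        lp = loss(w)
        w[i] = orig - eps
        lm = loss(w)
        w[i] = orig  # restore the entry before moving on
        g[i] = (lp - lm) / (2 * eps)
    return g

def gradcheck(loss, grad, w, rtol=1e-4, atol=1e-6):
    """Compare an analytic gradient against the numeric one."""
    gn = numeric_grad(loss, w)
    ga = grad(w)
    return np.allclose(ga, gn, rtol=rtol, atol=atol)
```

Note that this only makes sense when the loss is a deterministic function of the parameters: if the loss draws fresh random noise on every call, the two perturbed evaluations see different functions and the check fails spuriously.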

@denizyuret I performed the following checks on CPU:

  1. gcheck with the default setting defined in debug.jl: failed
  2. gcheck with only one timestep, otherwise the default setting: failed
  3. Modular RNN check: I converted my own ilkarman benchmark example (Knet_RNN.ipynb) to Julia code: gcheck passed

So, currently there is no problem with modular RNNs. However, I use the re-parametrization trick (as in VAEs); maybe that is the reason. Next, I'm going to try Carlo's VAE to see what happens.
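The re-parametrization trick mentioned above can be sketched like this (an illustrative Python/NumPy sketch of the generic technique, not the Knet code in this repo; the function name `reparameterize` is hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, logvar, eps=None):
    """z = mu + sigma * eps with eps ~ N(0, I).

    Gradients flow through mu and logvar; the randomness is
    confined to eps, which carries no gradient."""
    if eps is None:
        eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps
```

If eps is resampled inside the loss function, every evaluation of the loss sees different noise, which is exactly the situation that breaks a finite-difference gradient check; fixing eps (or sampling it outside the loss) restores determinism.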

The VAE also passes. I think I need to re-digest the model. Several issues might be causing harm, but yes, I tried (4) and the gradients are the same. What did I do?

  1. Transferred PyTorch model weights to my implementation.
  2. Used a single array of randomly sampled noise, shared across all timesteps, to see whether I get the same gradients.
  3. Checked gradients by eye and by calculating norms.
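Comparing gradients "by calculating norms", as in step 3, usually means a relative-difference measure along these lines (an illustrative sketch; the helper name `grad_diff` is hypothetical, not from the repo):

```python
import numpy as np

def grad_diff(g1, g2):
    """Relative difference between two gradient arrays:
    ||g1 - g2|| / (||g1|| + ||g2||), near 0 when they agree."""
    num = np.linalg.norm(g1 - g2)
    den = np.linalg.norm(g1) + np.linalg.norm(g2) + 1e-12
    return num / den
```

A value close to 0 means the two gradients match; a value near 1 means they disagree badly.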

@denizyuret here's what I did to make gradcheck pass on this network: use a single noise sample for all timesteps. Then it passes. However, I don't know why it passes on Carlo's VAE implementation. Anyway, here's what I'm going to do next:

  1. Train a PyTorch network, transfer its weights to my implementation, and try to generate something meaningful.
  2. Build a mechanism for noise sampling that is handled outside of the loss function (the in-loss sampling may be what causes the problem).
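The plan in step 2 amounts to pre-drawing all per-timestep noise before the loss is evaluated, so the loss becomes a deterministic function of the parameters. A minimal Python/NumPy sketch, assuming a hypothetical `sample_noise` helper and a toy `loss` (the real model is Julia/Knet and its loss is more involved):

```python
import numpy as np

rng = np.random.default_rng(42)

def sample_noise(T, batch, zdim):
    """Draw all T per-timestep noise arrays up front,
    outside the loss function."""
    return [rng.standard_normal((batch, zdim)) for _ in range(T)]

def loss(params, xs, noise):
    # Toy loss: noise enters as an argument and is never resampled,
    # so repeated calls with the same inputs give the same value.
    total = 0.0
    for x, eps in zip(xs, noise):
        z = params["mu"] + np.exp(0.5 * params["logvar"]) * eps
        total += float(((x - z) ** 2).mean())
    return total
```

Because `noise` is fixed before the call, perturbing `params` and re-evaluating compares the same function twice, which is what a finite-difference gradient check requires.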

I confirm that gcheck passes when we remove that value.() call and fails otherwise. You can use the latest master to test gcheck on my model.