gradient problem
denizyuret opened this issue
1. I implemented a new utility function `gcheck` (you need to check out master AutoGrad and explicitly write `using AutoGrad: gcheck` to use it). This works with regular models and loss functions, e.g. `gcheck(nll, model, x, y)`, where any of the inputs may have `Param` components.
2. I checked this on new versions of models like `lenet` (from the tutorial) and it works.
3. As expected, it failed on your draw: `include("debug.jl"); gcheck(loss, model, x)`.
4. Then I fixed the unboxing problem by replacing `value.(hidden)` with `hidden` at Knet/src/rnn.jl:130.
5. I tried gcheck again but it still fails.
I was expecting (5) to pass. Did you try (4) and confirm the gradients work?
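For reference, a `gcheck` call on a regular model looks like this (a minimal sketch: `Linear` and `loss` are made-up illustrations, not code from this issue; `gcheck` compares AutoGrad gradients against numeric finite differences):

```julia
using AutoGrad: Param, gcheck

# A toy model whose trainable arrays are wrapped in Param,
# so AutoGrad tracks gradients with respect to them.
struct Linear; w; b; end
Linear(nin, nout) = Linear(Param(randn(nout, nin)), Param(zeros(nout)))
(m::Linear)(x) = m.w * x .+ m.b

# A simple squared-error loss over a batch.
loss(m, x, y) = sum(abs2, m(x) .- y)

m = Linear(3, 2)
x, y = randn(3, 4), randn(2, 4)
gcheck(loss, m, x, y)   # checks AutoGrad gradients against numeric estimates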
@denizyuret I've performed the following on CPU:
- gcheck with the default setting defined in debug.jl: failed
- gcheck with only one timestep (different from the default setting): failed
- Check of a modular RNN: I converted my own ilkarman benchmark example (Knet_RNN.ipynb) to Julia code, and gcheck passed.
So, currently there is no problem with modular RNNs. However, I use the reparametrization trick (as in VAEs); maybe this is the reason. Right now, I'm going to try Carlo's VAE to see what happens.
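For context, the reparametrization trick mentioned above can be sketched as follows (`reparametrize` is a hypothetical name for illustration, not a function from the model):

```julia
# Reparametrization trick (as used in VAEs): sampling z ~ N(mu, sigma^2)
# directly is not differentiable in mu and sigma, so instead sample
# eps ~ N(0, 1) and compute z = mu + sigma .* eps, which is a
# deterministic, differentiable function of mu and sigma given eps.
reparametrize(mu, sigma, eps = randn(size(mu)...)) = mu .+ sigma .* eps
```

Because `eps` is an argument here, the same noise can be reused across calls; this matters for a finite-difference check like gcheck, since a loss that resamples noise internally computes a different function on every evaluation.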
The VAE also passes. I think I need to re-digest the model. Several issues might be causing harm, but yes, I've tried (4) and the gradients are the same. What did I do?
- Transferred PyTorch model weights to my implementation.
- Used just one array of randomly sampled noise in all the timesteps, to see whether I get the same gradients.
- Checked gradients by eye and by computing norms.
@denizyuret here's what I've done to make gradcheck pass on this network: use just a single noise sample for all timesteps. Then it passes. However, I don't know why it passes on Carlo's VAE implementation. Anyway, here's what I'm going to do:
- Train a PyTorch network, transfer its weights to my implementation, then try to generate something meaningful.
- Build a mechanism for noise sampling that is handled outside of the loss function (maybe sampling inside the loss is what causes the problem for the network).
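A minimal sketch of that second step, assuming per-timestep noise arrays (hypothetical names; the real model code is not shown here):

```julia
# Presample one noise array per timestep outside the loss, so that
# repeated evaluations of the loss (as gcheck performs) see identical noise.
make_noise(T, dims, nsteps) = [randn(T, dims...) for _ in 1:nsteps]

# The loss then takes the presampled noise as an explicit argument
# instead of calling randn internally, e.g.:
# loss(model, x, noise) = ...   # use noise[t] at timestep t
```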
I confirm that gcheck passes when we remove that `value.()` call and fails otherwise. You can use the latest master to test gcheck on my model.