awaelchli / pytorch-lightning-snippets

A collection of code snippets for my PyTorch Lightning projects

Some thoughts + questions

TylerYep opened this issue · comments

Hey, thank you so much for writing this implementation up! It is a feature I've wanted to see in pytorch-lightning for a long time but never got the chance to work on.

In the BatchGradient verifier, we pop the index of the batch we are testing. However, I think it would be preferable to also verify that the gradient of that popped batch is in fact non-zero, since a gradient of all zeros would pass our test but would not train the network at all. My example code is here: verify.py
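Roughly, a minimal sketch of what I have in mind (`check_batch_gradient` and its arguments are placeholder names from my own code, not the verifier's API):

```python
import torch

def check_batch_gradient(model, input_shape=(8, 16), test_idx=0):
    # Forward a random batch and backprop only through the output
    # of the sample under test.
    inputs = torch.randn(*input_shape, requires_grad=True)
    outputs = model(inputs)
    outputs[test_idx].sum().backward()

    grads = inputs.grad
    # All other samples must receive exactly zero gradient ...
    others = torch.cat([grads[:test_idx], grads[test_idx + 1:]])
    if not torch.all(others == 0):
        raise RuntimeError("gradient leaked across batch samples")
    # ... and the tested sample itself must get a non-zero gradient,
    # otherwise an all-zero gradient passes the test while the network
    # would not train at all.
    if not torch.any(grads[test_idx] != 0):
        raise RuntimeError("gradient of the tested sample is all zero")
```

For example, `check_batch_gradient(torch.nn.Linear(16, 4))` should pass, while any module that mixes samples across the batch dimension should fail the first check.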

Finally, in my own projects I wrote a few other verification functions, but I believe they are already handled by Lightning. Could you confirm this? I am referring to:

  • Issuing a warning if train() is turned on but all layers are frozen
  • Checking whether any NaN or Inf value is present in the gradients or weights (see the sketch after this list)
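For concreteness, here are simplified sketches of the two checks from my own project; the function names are mine, not Lightning APIs:

```python
import warnings

import torch
import torch.nn as nn

def warn_if_all_frozen(model: nn.Module):
    # train() is set, but no parameter requires grad: nothing can learn.
    if model.training and not any(p.requires_grad for p in model.parameters()):
        warnings.warn("model is in train() mode, but all layers are frozen")

def check_finite(model: nn.Module):
    # Any NaN/Inf in the weights or gradients indicates a broken run.
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            raise ValueError(f"non-finite value in weights of {name}")
        if param.grad is not None and not torch.isfinite(param.grad).all():
            raise ValueError(f"non-finite value in gradient of {name}")
```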

If I have some spare time I will try testing this code myself, but overall it looks really great! 👍

@TylerYep Thank you very much for the feedback. I only saw this message just now, sorry!

However, I think it would be preferable to also verify that the gradient of that popped batch is in fact non-zero, since a gradient of all zeros would pass our test but would not train the network at all.

Very good observation, I will include that!
EDIT: done here: 9527fba

Issuing a warning if train() is turned on but all layers are frozen

I am not aware of such a feature in Lightning :)

Checking whether any NaN or Inf value is present in the gradients or weights

Yes, but one needs to turn it on with a Trainer flag. Searching for these values on every iteration can impact performance, so it is not enabled by default.
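For reference, a minimal sketch of enabling it; the exact flag depends on the Lightning version (`terminate_on_nan` was later deprecated in favor of `detect_anomaly`):

```python
from pytorch_lightning import Trainer

# Older releases: abort training as soon as the loss or any weight
# becomes NaN/Inf (checked at the end of each training step).
trainer = Trainer(terminate_on_nan=True)

# Newer releases deprecate that flag in favor of torch.autograd's
# anomaly detection.
trainer = Trainer(detect_anomaly=True)
```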

I also see that your verify.py has a function that runs all tests at once; that seems very convenient. 👍