Classifier OOM when computing on test set.
AlexanderMath opened this issue · comments
Thanks for a great repository, the code works very well, is nicely documented and the overall structure is intuitive.
I found a minor issue which can easily be solved.
The function test(..) computes the loss on the test set without turning gradient computations off:
https://github.com/jhjacobsen/invertible-resnet/blob/master/models/utils_cifar.py#L194
One might think model.eval() turns off gradients, but it does not; see e.g. [1].
Instead, one needs something like

model.eval()
with torch.no_grad():
    # code from before
This does not usually cause OOM, but it does if one is training multiple classifiers at the same time on the same GPU (useful when, e.g., repeating experiments to get error bars).
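To illustrate the difference, here is a minimal sketch using a toy nn.Linear model (the model and shapes are placeholders, not code from this repository): model.eval() alone leaves gradient tracking on, while wrapping the forward pass in torch.no_grad() disables it, so no computation graph is stored.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the classifier
model.eval()             # switches layers like dropout/batchnorm to eval mode only
x = torch.randn(8, 4)

out = model(x)
print(out.requires_grad)   # still True: eval() does not disable autograd

with torch.no_grad():      # disables graph construction, saving memory
    out_ng = model(x)
print(out_ng.requires_grad)  # False: no graph is built or retained
```

The memory saving comes from not retaining intermediate activations for backward, which is exactly what accumulates across batches in the test loop.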
[1] https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615
Thank you, will add a fix shortly.