jhjacobsen / invertible-resnet

Official Code for Invertible Residual Networks

Classifier OOM when computing on test set.

AlexanderMath opened this issue

Thanks for a great repository: the code works very well, is nicely documented, and the overall structure is intuitive.
I found a minor issue that can easily be solved.

The function test(..) computes the loss on the test set without turning gradient computation off.

https://github.com/jhjacobsen/invertible-resnet/blob/master/models/utils_cifar.py#L194

One might think model.eval() turns off gradients, but it does not: it only switches layers such as dropout and batch normalization to their evaluation behavior, see e.g. [1].
Instead, one needs something like

model.eval()
with torch.no_grad(): 
    # code from before 

This usually does not cause an OOM, but it does if one is training multiple classifiers at the same time on the same GPU.
Training several classifiers at once is useful when, e.g., repeating an experiment to get error bars.
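
To make the difference concrete, here is a minimal sketch of an evaluation loop wrapped in torch.no_grad(). It is not the repository's test(..) function: the evaluate helper, the toy linear model, and the random data are hypothetical stand-ins chosen only so the snippet runs on its own. It also checks that model.eval() alone still produces outputs that track gradients.

```python
import torch
import torch.nn as nn


def evaluate(model, loader, device="cpu"):
    """Hypothetical helper: mean cross-entropy loss and accuracy over a loader."""
    criterion = nn.CrossEntropyLoss(reduction="sum")
    model.eval()  # only changes layer behavior (dropout, batch norm)
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():  # no autograd graph, so nothing is kept for backward
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, targets).item()
            correct += (logits.argmax(dim=1) == targets).sum().item()
            count += targets.size(0)
    return total_loss / count, correct / count


# Toy stand-ins (assumptions, not from the repository): a linear classifier
# on CIFAR-shaped random inputs, with a list of batches acting as the loader.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
loader = [(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(2)]

model.eval()
out_eval_only = model(loader[0][0])      # eval() alone: graph is still built
with torch.no_grad():
    out_no_grad = model(loader[0][0])    # no_grad(): no graph is built

eval_only_tracks_grad = out_eval_only.requires_grad
no_grad_tracks_grad = out_no_grad.requires_grad
mean_loss, accuracy = evaluate(model, loader)
```

In this sketch eval_only_tracks_grad is True and no_grad_tracks_grad is False, which is the behavior described in [1]: without no_grad() each test batch keeps its intermediate activations alive until the output tensor is released.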

[1] https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615

Thank you, will add a fix shortly.