rasbt / deeplearning-models

A collection of various deep learning architectures, models, and tips

Question regarding gradient checkpointing

Dxyk opened this issue

Hello,

I am trying to understand gradient checkpointing and found your explanation in gradient-checkpointing-nin.ipynb very helpful. I cloned the repo and tried rerunning the experiments. However, I was unable to reproduce the result mentioned in your conclusion.

When I run the notebook with the vanilla NiN, the memory consumption (current, peak) is 413527 and 154049604, with a runtime of 109.1s.
For the checkpointed version of the model (segments=1), the memory consumption is 402938 and 154064699, with a runtime of 110.14s.
In these tests, I did not observe the significant memory improvement that the notebook reports (22% memory improvement for a 14% runtime sacrifice).
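To put the gap in numbers, here is a minimal sketch that computes the relative change between the two runs from the figures quoted above. The helper name `relative_change` and the dictionary keys are only illustrative, and the memory units are left as the notebook prints them, since they are not stated here.

```python
# Figures copied from the two runs reported above.
# Memory units are whatever the notebook prints; they are not stated here.
vanilla = {"current": 413527, "peak": 154049604, "runtime_s": 109.1}
checkpointed = {"current": 402938, "peak": 154064699, "runtime_s": 110.14}


def relative_change(before, after):
    """Signed percentage change going from `before` to `after`."""
    return (after - before) / before * 100.0


# Hypothetical helper output: one percentage per reported metric.
changes = {
    key: relative_change(vanilla[key], checkpointed[key]) for key in vanilla
}

for key, pct in changes.items():
    print(f"{key}: {pct:+.3f}%")
```

By this calculation the current memory drops by roughly 2.6%, the peak memory is essentially unchanged (about 0.01% higher), and the runtime grows by about 1%, which is far from the 22% and 14% figures in the notebook.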

I also tried running with multiple seeds and different checkpoint segment sizes, but did not see a significant memory improvement in any of those runs either.

I'm not sure why this is and could use a bit of help. Could it be that the network is relatively small, so the effect is less obvious? Or could the checkpointing implementation in PyTorch have changed over the years? I would appreciate any insight you could provide into this.