chrischute / glow

Implementation of Glow in PyTorch



One-line change reduces memory from 11 GB to 2 GB.

AlexanderMath opened this issue

The original Glow code uses gradient checkpointing, an effective way to reduce peak memory consumption: instead of storing the intermediate activations of a sub-network during the forward pass, it recomputes them during the backward pass, trading extra compute for memory. The following single-line change adds gradient checkpointing and reduces memory consumption from 11 GB to 2 GB. It allowed me to increase the batch size from 64 to 256 with no issues. I think 512 is possible, maybe even 1024 if we use float16 for some of the layers.


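import torch
import torch.utils.checkpoint  # make sure the checkpoint submodule is loaded

def forward(self, x, ldj, reverse=False):
    x_change, x_id = x.chunk(2, dim=1)

    # st = self.nn(x_id)  # change this line to the one below
    # Activations inside self.nn are not stored; they are recomputed in backward.
    st = torch.utils.checkpoint.checkpoint(self.nn, x_id)
    s, t = st[:, 0::2, ...], st[:, 1::2, ...]
    s = self.scale * torch.tanh(s)
    # ... rest of forward unchanged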
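A minimal, self-contained sketch of the same idea, assuming a recent PyTorch (1.11 or later, for the `use_reentrant` argument). `ToyCoupling`, its layer sizes, and the toy loss are hypothetical stand-ins, not the repository's coupling module. It checks that checkpointing only changes how activations are stored, not the resulting gradients. `use_reentrant=False` is passed because the input `x` here does not require grad: with the older reentrant variant, parameters inside the checkpointed function would then receive no gradient.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyCoupling(nn.Module):
    """Hypothetical affine coupling layer, used only to illustrate checkpointing."""

    def __init__(self, channels, hidden=32, use_checkpoint=False):
        super().__init__()
        half = channels // 2
        # Small conv net mapping x_id to interleaved (s, t) for x_change.
        self.nn = nn.Sequential(
            nn.Conv2d(half, hidden, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, 2 * half, 3, padding=1),
        )
        self.scale = nn.Parameter(torch.ones(half, 1, 1))
        self.use_checkpoint = use_checkpoint

    def forward(self, x, ldj):
        x_change, x_id = x.chunk(2, dim=1)
        if self.use_checkpoint:
            # Activations of self.nn are discarded and recomputed during backward.
            st = torch.utils.checkpoint.checkpoint(self.nn, x_id, use_reentrant=False)
        else:
            st = self.nn(x_id)
        s, t = st[:, 0::2, ...], st[:, 1::2, ...]
        s = self.scale * torch.tanh(s)
        x_change = (x_change + t) * s.exp()
        ldj = ldj + s.flatten(1).sum(-1)
        return torch.cat((x_change, x_id), dim=1), ldj


def grads(use_checkpoint):
    # Same seed, so both runs get identical weights and inputs.
    torch.manual_seed(0)
    layer = ToyCoupling(4, use_checkpoint=use_checkpoint)
    x = torch.randn(2, 4, 8, 8)
    out, ldj = layer(x, torch.zeros(2))
    (out.pow(2).sum() - ldj.sum()).backward()
    return [p.grad.clone() for p in layer.parameters()]


plain, ckpt = grads(False), grads(True)
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))  # True
```

The memory saving comes at the cost of roughly one extra forward pass through `self.nn` per coupling layer during backward.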

Thanks for the suggestion. I added a reference to this issue in the README. If you'd like to add support for checkpointing as a command line argument, feel free to open a pull request and I'll happily review.
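One possible shape for such a command-line option, as a sketch only: the flag name `--use_checkpointing` and the helper `maybe_checkpoint` are hypothetical and do not exist in the repository. The helper runs a module normally or under `torch.utils.checkpoint.checkpoint` depending on the flag (assuming PyTorch 1.11 or later for `use_reentrant`).

```python
import argparse

import torch
import torch.nn as nn
import torch.utils.checkpoint


def maybe_checkpoint(module, *inputs, enabled=False):
    """Hypothetical helper: run `module` under checkpointing only when enabled."""
    if enabled:
        return torch.utils.checkpoint.checkpoint(module, *inputs, use_reentrant=False)
    return module(*inputs)


parser = argparse.ArgumentParser(description="Glow training (sketch)")
# Hypothetical flag; not an existing option of the repository's train script.
parser.add_argument("--use_checkpointing", action="store_true",
                    help="Recompute coupling-network activations in backward to save memory")
args = parser.parse_args(["--use_checkpointing"])  # stands in for sys.argv

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
x = torch.randn(4, 8)
y = maybe_checkpoint(net, x, enabled=args.use_checkpointing)
y.sum().backward()
print(args.use_checkpointing, net[0].weight.grad is not None)  # True True
```

How the flag reaches each coupling layer (for example as a constructor argument stored on the module) is a design choice this sketch leaves open.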

Will do when I get time after the ICML deadline.