NVlabs / FUNIT

Translate images to unseen domains at test time with a few example images.

Home Page: https://nvlabs.github.io/FUNIT/

BatchNorm layers cause training error when track_running_stats=True with DistributedDataParallel

antopost opened this issue

When using DDP (PyTorch 12.1), some of my BatchNorm layers cause training to fail due to an in-place operation, with the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [65]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The operation that failed was simply:

import torch.nn as nn

# `autopad` and `bn_params` are defined elsewhere in my model code; minimal
# stand-ins are shown here only so the snippet is self-contained.
def autopad(k, p=None):
    return k // 2 if p is None else p  # "same" padding for a square kernel

bn_params = {}  # extra BatchNorm2d kwargs (contents omitted here)

class Conv(nn.Module):
    # Standard convolution block: Conv2d -> BatchNorm2d -> activation
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2, **bn_params)
        self.act = nn.SiLU(inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)    # <-- this is the line that fails
        x = self.act(x)
        return x
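For context, the model containing this Conv block is wrapped with DistributedDataParallel in the usual way; a rough sketch of that setup is below (the process-group initialization and launch details are simplified and not exactly what my training script does):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model(model: torch.nn.Module, rank: int, world_size: int) -> DDP:
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # By default, DDP re-broadcasts module buffers (e.g. BatchNorm
    # running_mean / running_var) from rank 0 on every forward pass.
    return DDP(model.cuda(rank), device_ids=[rank])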

On a whim I tried passing
self.bn = nn.BatchNorm2d(c2, track_running_stats=False, **bn_params)
to all my BatchNorm layers, and training ran, but of course this is not a viable solution.
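Mitigations I have seen suggested for this kind of DDP / BatchNorm buffer conflict, but have not verified on this code, would look roughly like this:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_without_buffer_broadcast(model: torch.nn.Module, rank: int) -> DDP:
    # Option (a): keep BatchNorm2d as-is but stop DDP from re-broadcasting
    # buffers (running_mean / running_var) on every forward pass.
    return DDP(model.cuda(rank), device_ids=[rank], broadcast_buffers=False)

def wrap_with_sync_batchnorm(model: torch.nn.Module, rank: int) -> DDP:
    # Option (b): convert BatchNorm2d layers to SyncBatchNorm, which shares
    # batch statistics across processes instead of per-GPU running stats.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return DDP(model.cuda(rank), device_ids=[rank])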

For the record, I also tried cloning x and setting nn.SiLU(inplace=False), but got the same error.
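For clarity, those two attempts looked roughly like this (where exactly the clone goes is shown only for illustration):

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x.clone())   # cloning the BN input: same RuntimeError
        x = self.act(x)          # with self.act = nn.SiLU(inplace=False): same RuntimeError
        return x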