awaelchli / pytorch-lightning-snippets

A collection of code snippets for my PyTorch Lightning projects

Why am I getting valid=False in a simple Double Convolution?

mlagunas opened this issue

Hi!

Thanks for this great repo. :)

I was trying out the verification callbacks and methods and found that I always get valid=False, even for a very simple model like this one:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from verification.batch_gradient import BatchGradientVerification

class FakePLSystem(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Two Conv -> BatchNorm -> ReLU stages ("double convolution")
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.model(x)

model = FakePLSystem()
verification = BatchGradientVerification(model)
valid = verification.check(input_array=torch.randn(2, 3, 512, 512))
print(valid)  # prints False

Do you have any idea why this is happening?

Hi

I had the same problem. It seems that the BatchNorm layers cause this issue; comment them out to see if that's the reason.
I've submitted a bug report here: #3.
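A likely reason BatchNorm triggers this, sketched under assumptions: in training mode, nn.BatchNorm2d normalizes each channel with a mean and variance computed over the whole batch, so every sample's output depends on the other samples. A check that backpropagates from one sample's output and expects zero gradient on the remaining inputs will therefore report mixing. The helper batch_mixes_samples below is hypothetical, a minimal stand-in for that idea and not the repo's BatchGradientVerification code; in eval mode BatchNorm switches to its running statistics and the cross-sample gradient disappears.

```python
import torch
import torch.nn as nn


def batch_mixes_samples(model: nn.Module, inputs: torch.Tensor, sample_idx: int = 0) -> bool:
    """Hypothetical helper: True if output[sample_idx] depends on other samples.

    Illustrates the batch-gradient idea only; it is not the repo's implementation.
    """
    inputs = inputs.detach().clone().requires_grad_(True)
    outputs = model(inputs)
    # Backpropagate only from the chosen sample's output.
    outputs[sample_idx].sum().backward()
    grad = inputs.grad
    # Drop the chosen sample's own gradient; any nonzero remainder means mixing.
    others = torch.cat([grad[:sample_idx], grad[sample_idx + 1:]])
    return bool(others.abs().sum() > 0)


def make_double_conv() -> nn.Sequential:
    # Same layer stack as FakePLSystem in the question.
    return nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(32),
        nn.ReLU(),
    )


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)  # small spatial size keeps the demo fast
model = make_double_conv()

model.train()  # BatchNorm uses batch statistics
print(batch_mixes_samples(model, x))
model.eval()   # BatchNorm uses running statistics
print(batch_mixes_samples(model, x))
```

If this is the cause, valid=False here reflects how BatchNorm behaves in training mode rather than a reshape or permute bug in the model itself.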