Why am I getting valid=False in a simple Double Convolution?
mlagunas opened this issue · comments
Manuel commented
Hi!
Thanks for this great repo. :)
I was trying out the validation callbacks and methods and found that I always get valid=False
for a very simple model like this one:
import torch
import torch.nn as nn
import pytorch_lightning as pl

from verification.batch_gradient import BatchGradientVerification

class FakePLSystem(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.model(x)

model = FakePLSystem()
verification = BatchGradientVerification(model)
valid = verification.check(input_array=torch.randn(2, 3, 512, 512))
print(valid)
Do you have a clue on why this is happening?
Arman Ali Mohammadi commented
Hi,
I had the same problem. It seems the BatchNorm layers cause this issue; comment them out to see if that's the reason for you as well.
I've submitted a bug report here #3.
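For context, here is a minimal sketch in plain PyTorch (not the repo's verification code) of why BatchNorm trips a batch-gradient check: in training mode it normalizes with statistics computed over the whole batch, so perturbing one sample changes the outputs for the other samples, which is exactly the cross-sample dependency the verification flags.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3).train()  # training mode -> normalizes with batch statistics

x = torch.randn(2, 3, 4, 4)
out_sample1 = bn(x)[1]

# Perturb only sample 0; sample 1's input is untouched.
x_perturbed = x.clone()
x_perturbed[0] += 1.0
out_sample1_after = bn(x_perturbed)[1]

# Sample 1's output changed even though only sample 0 was modified:
# the samples are coupled through the shared batch mean/variance.
coupled = not torch.allclose(out_sample1, out_sample1_after)
print(coupled)  # True

# In eval mode BatchNorm uses fixed running statistics, so the coupling disappears.
bn.eval()
out1 = bn(x)[1]
out1_after = bn(x_perturbed)[1]
decoupled = torch.allclose(out1, out1_after)
print(decoupled)  # True
```

This is also why running such checks with the model in eval mode (or temporarily removing the BatchNorm layers) avoids the false alarm.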