Shape problems (x2)
TLESORT opened this issue.
Code to reproduce:
import numpy as np
from torch.utils.data import DataLoader
from nngeometry.layers import WeightNorm1d
from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag
from continuum.datasets import InMemoryDataset

# btw, this also fails with nn.Linear
classifier = WeightNorm1d(in_features=512, out_features=20)

# 20 random 512-dimensional inputs, one per class
random_x_data = np.random.randint(0, 255, size=(20, 512))
random_y_data = np.arange(20)
data = InMemoryDataset(random_x_data, random_y_data).to_taskset()
fisher_loader = DataLoader(data, batch_size=128, shuffle=True, num_workers=6)

fim = FIM(model=classifier,
          loader=fisher_loader,
          representation=PMatDiag,
          n_output=20,
          variant='classif_logits',
          device='cpu')
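For reference, a minimal NumPy sketch of the quantity that the `FIM(..., representation=PMatDiag, variant='classif_logits')` call above is meant to estimate, for a bias-free linear classifier like the `F.linear` call inside WeightNorm1d (weight normalization left out for brevity). It assumes the Fisher is averaged over the n examples, with labels sampled from the model's own softmax. `softmax` and `diag_fim_linear` are hypothetical helpers, not nngeometry's API or implementation. The main point is the shape contract: the inputs must arrive as a 2-D `(n, in_features)` array, here `(n, 512)`.

```python
import numpy as np


def softmax(logits):
    # Hypothetical helper: row-wise softmax, shifted for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def diag_fim_linear(weight, x):
    """Hypothetical helper (not nngeometry): diagonal of the Fisher information
    for logits = x @ weight.T (no bias), y ~ softmax(logits).
    Assumption: the Fisher is averaged over the n examples."""
    # Raises ValueError unless x is 2-D: (n, in_features).
    n, in_features = x.shape
    assert weight.shape[1] == in_features, "feature dimension mismatch"
    p = softmax(x @ weight.T)  # (n, out_features)
    # d log p_c / d W[k, j] = (1[k == c] - p_k) * x_j; weighting the square
    # by p_c and summing over c gives p_k * (1 - p_k) * x_j**2.
    return (p * (1 - p)).T @ (x ** 2) / n  # (out_features, in_features)
```

The closed form relies on sum_c p_c (1[k = c] - p_k)^2 = p_k (1 - p_k)^2 + (1 - p_k) p_k^2 = p_k (1 - p_k), so no loop over classes is needed. A brute-force loop over examples and classes gives the same numbers and makes a useful sanity check.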
If I work around the view problem by modifying the WeightNorm1d class, I get another error. I only changed the forward function of WeightNorm1d to:
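import torch
import torch.nn.functional as F
from torch import Tensor
def forward(self, input: Tensor) -> Tensor:
    # flatten the input to (batch_size, in_features) before the linear map
    input = input.view(-1, self.in_features)
    # squared L2 norm of each weight row, plus eps for numerical stability
    norm2 = (self.weight**2).sum(dim=1, keepdim=True) + self.eps
    return F.linear(input, self.weight / torch.sqrt(norm2))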
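To make the effect of the added `view` concrete, here is a NumPy mirror of the modified forward. It is a sketch for illustration only: `weightnorm1d_forward` is a hypothetical name, and the default `eps=1e-5` is an assumption, not necessarily nngeometry's value.

```python
import numpy as np


def weightnorm1d_forward(x, weight, eps=1e-5):
    """Hypothetical NumPy mirror of the modified forward: flatten the input,
    then apply a linear map whose weight rows are rescaled to ~unit L2 norm.
    Assumption: eps=1e-5 (nngeometry's default may differ)."""
    out_features, in_features = weight.shape
    # Equivalent of input.view(-1, self.in_features): merge all leading dims.
    x = x.reshape(-1, in_features)
    norm2 = (weight ** 2).sum(axis=1, keepdims=True) + eps  # (out_features, 1)
    return x @ (weight / np.sqrt(norm2)).T  # (rows of x, out_features)
```

With an input of shape `(4, 1, 512)` (for example, samples carrying an extra singleton dimension) and a `(20, 512)` weight, the output has shape `(4, 20)`. One caveat of this flatten: if each sample holds more than `in_features` elements, the reshape does not fail but changes the batch dimension (an input of shape `(4, 512, 3)` becomes 12 rows), so a mismatch can resurface later, for instance against the targets.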
Edit: my bad, the problem probably comes from the InMemoryDataset class.
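One way to rule the dataset wrapper in or out is to build the loader from plain PyTorch tensors instead of continuum's InMemoryDataset/TaskSet. This is a sketch, not a confirmed fix. It assumes the wrapper's per-sample processing is what changes the input shapes, and it drops `num_workers` for simplicity.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Same random data as in the report, converted directly to tensors
# (bypasses InMemoryDataset / to_taskset entirely).
random_x_data = np.random.randint(0, 255, size=(20, 512))
random_y_data = np.arange(20)

x = torch.from_numpy(random_x_data).float()  # (20, 512), float32 inputs
y = torch.from_numpy(random_y_data).long()   # (20,), int64 class labels

dataset = TensorDataset(x, y)
fisher_loader = DataLoader(dataset, batch_size=128, shuffle=True)

# Each batch is an (inputs, targets) pair; inputs already have the
# (batch_size, in_features) shape a linear classifier expects.
inputs, targets = next(iter(fisher_loader))
print(inputs.shape, targets.shape)  # torch.Size([20, 512]) torch.Size([20])
```

This `fisher_loader` can be passed as `loader=` to the `FIM` call in the reproduction code. If that call then succeeds with both WeightNorm1d and nn.Linear, the shape problem lies in how the samples come out of the continuum wrapper rather than in the layer.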