tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io


Shape problems (x2)

TLESORT opened this issue

Code:

import numpy as np
from torch.utils.data import DataLoader
from nngeometry.layers import WeightNorm1d
from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag
from continuum.datasets import InMemoryDataset

classifier = WeightNorm1d(in_features=512, out_features=20)  # also fails with nn.Linear
random_x_data = np.random.randint(0, 255, size=(20, 512))  # integer array, 20 samples x 512 features
random_y_data = np.arange(20)  # one label per sample
data = InMemoryDataset(random_x_data, random_y_data).to_taskset()
fisher_loader = DataLoader(data, batch_size=128, shuffle=True, num_workers=6)
fim = FIM(model=classifier,
          loader=fisher_loader,
          representation=PMatDiag,
          n_output=20,
          variant='classif_logits',
          device='cpu')
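For reference, here is a NumPy sketch of what the diagonal Fisher returned by a call like the one above should contain for a bias-free linear classifier. It assumes that variant='classif_logits' means the Fisher of a softmax over the logits, with labels drawn from the model's own predictive distribution and averaged over examples; the function names are hypothetical and not part of nngeometry. WeightNorm1d normalizes each weight row, so its Fisher is not this closed form, and nn.Linear would add bias entries.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def diag_fisher_linear(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Closed-form diagonal Fisher of a bias-free linear softmax classifier (hypothetical).

    x: (n_examples, in_features), weight: (out_features, in_features).
    Entry (i, j) is the mean over examples of p_i * (1 - p_i) * x_j ** 2.
    """
    p = softmax(x @ weight.T)  # (n_examples, out_features)
    return (p * (1.0 - p)).T @ (x ** 2) / x.shape[0]


def diag_fisher_explicit(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Same quantity, summing p(y|x) * (d log p(y|x) / dW) ** 2 over every label y."""
    n, out_features = x.shape[0], weight.shape[0]
    p = softmax(x @ weight.T)
    fisher = np.zeros_like(weight)
    for k in range(n):
        for y in range(out_features):
            # gradient of log softmax w.r.t. W is (onehot(y) - p) outer x
            grad = np.outer(np.eye(out_features)[y] - p[k], x[k])
            fisher += p[k, y] * grad ** 2
    return fisher / n


rng = np.random.default_rng(0)
x = rng.standard_normal((20, 512))
w = 0.01 * rng.standard_normal((20, 512))
d = diag_fisher_linear(x, w)
print(d.shape)  # (20, 512): one diagonal entry per weight
```

The explicit loop is only there to cross-check the closed form; the diagonal has out_features x in_features entries, which is the shape a PMatDiag for this layer would be expected to hold (plus bias entries for nn.Linear).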

Error:
(error screenshot not preserved)

If I work around the view problem by modifying the WeightNorm1d class, I get another error:
(error screenshot not preserved)

(I only changed the forward method of WeightNorm1d, as follows:)

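# needs: import torch; import torch.nn.functional as F; from torch import Tensor
def forward(self, input: Tensor) -> Tensor:
    # flatten leading dims so the input is (batch, in_features)
    input = input.view(-1, self.in_features)
    # squared L2 norm of each output row of the weight, plus eps for stability
    norm2 = (self.weight**2).sum(dim=1, keepdim=True) + self.eps
    # linear map with row-normalized weights (no bias)
    return F.linear(input, self.weight / torch.sqrt(norm2))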
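To make the shapes in this modified forward explicit, here is a NumPy mirror of it. The helper name is hypothetical, and the eps default of 1e-5 is an assumption for illustration, not necessarily what WeightNorm1d uses. Note that np.reshape, unlike torch's view, also works on non-contiguous arrays.

```python
import numpy as np


def weightnorm1d_forward(x: np.ndarray, weight: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Hypothetical NumPy mirror of the modified WeightNorm1d.forward (eps default assumed)."""
    in_features = weight.shape[1]
    # same role as input.view(-1, self.in_features): fold leading dims into the batch
    x = x.reshape(-1, in_features)
    # squared L2 norm of each output row, plus eps, as in norm2
    norm2 = (weight ** 2).sum(axis=1, keepdims=True) + eps
    # F.linear(input, W) computes input @ W.T (no bias here)
    return x @ (weight / np.sqrt(norm2)).T


rng = np.random.default_rng(0)
w = rng.standard_normal((20, 512))
batch = rng.standard_normal((8, 1, 512))  # e.g. an extra singleton dimension
out = weightnorm1d_forward(batch, w)
print(out.shape)  # (8, 20)
```

Because each weight row is divided by its own norm, rescaling the weights leaves the output essentially unchanged (up to eps), and an input equal to a unit-norm weight row gives an output close to 1 for that unit.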

My bad, the problem probably comes from the InMemoryDataset class.
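The thread does not say what exactly went wrong in InMemoryDataset. One thing worth ruling out in a repro like this (an assumption, not confirmed above): np.random.randint returns an integer array, while a linear layer expects floating-point inputs of shape (n, in_features). A hypothetical NumPy helper that checks and converts the features before wrapping them in any dataset:

```python
import numpy as np


def prepare_features(x: np.ndarray, in_features: int) -> np.ndarray:
    """Hypothetical helper: validate an (n, in_features) array and cast it to float32."""
    if x.ndim != 2 or x.shape[1] != in_features:
        raise ValueError(f"expected shape (n, {in_features}), got {x.shape}")
    # linear layers need floating-point inputs; integer arrays are cast, values unchanged
    return x.astype(np.float32)


raw = np.random.randint(0, 255, size=(20, 512))
feats = prepare_features(raw, 512)
print(feats.dtype, feats.shape)  # float32 (20, 512)
```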