tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io

RuntimeError: Shape is invalid for input of size

xand-stapleton opened this issue

I'm trying to use the FIM helper from the latest git version of NNGeometry to compute the Fisher metric of a trivial model. As a minimal example that reproduces my problem, I create a model with a single Linear layer and a single training sample. It solves the matrix equation Ax = b, where A is a 3x3 matrix and x, b are 3x1 column vectors.

Here's my code (it's not meant for anything functional, it's just to replicate my problem):

import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Net, self).__init__()

        self.linear = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        out = self.linear(x)
        return out

# Define the training data
A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])

b = torch.tensor([[52.],
                  [124.],
                  [196.]])
# Define the model and the optimizer
model = Net(input_dim=9, output_dim=3)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the model
for epoch in range(2000):
    optimizer.zero_grad()
    y_pred = model(A.view(9))
    print(A@y_pred)
    loss = nn.MSELoss(reduction='sum')(A@y_pred.view((3,1)), b)
    loss.backward()
    optimizer.step()

# Evaluate the model
with torch.no_grad():
    y_pred = model(A.reshape(9))
    print("Solution:\n", y_pred)

Now I create a simple dataloader containing that single training sample (just as a proof of concept):

from torch.utils.data import DataLoader, Dataset

class TrivialDataset(Dataset):
    def __init__(self):
        self.data = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).reshape(1,9)
    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

# Create the Dataloader
batch_size = 1
dataset = TrivialDataset()
loader = DataLoader(dataset, batch_size=batch_size)

Now if I try to find the FIM:

from nngeometry.metrics import FIM
from nngeometry.object import PMatDense

fisher_metric = FIM(model, loader, n_output=1, variant='regression', representation=PMatDense, device='cpu')

There's a runtime error:

File ~/miniconda3/envs/torch/lib/python3.10/site-packages/nngeometry/generator/jacobian/__init__.py:77, in Jacobian.get_covariance_matrix(self, examples)
     75 inputs.requires_grad = True
     76 bs = inputs.size(0)
---> 77 output = self.function(*d).view(bs, self.n_output) \
     78     .sum(dim=0)
     79 for i in range(self.n_output):
     80     self.grads.zero_()

RuntimeError: shape '[9, 1]' is invalid for input of size 3
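
The failing reshape can be reproduced without NNGeometry. The sketch below is not the library's code: it only mimics the `bs = inputs.size(0)` and `.view(bs, self.n_output)` steps visible in the traceback, using stand-in variables (`model`, `inputs`, `n_output`) and assuming `inputs` arrives as a flat tensor of shape `(9,)`, which is what a batch size of 9 in the message implies.

```python
import torch
import torch.nn as nn

# Stand-ins for the objects in the traceback (names are illustrative only).
model = nn.Linear(9, 3, bias=False)
inputs = torch.arange(9, dtype=torch.float32)  # one flat sample, shape (9,)
n_output = 1

bs = inputs.size(0)     # 9: the feature dimension is read as the batch size
output = model(inputs)  # shape (3,), i.e. 3 elements in total

try:
    output.view(bs, n_output)  # asks for 9 * 1 elements, but only 3 exist
except RuntimeError as err:
    message = str(err)
    print(message)
```

The 9 in the message is therefore the size of the first dimension of whatever tensor is treated as the input batch, and the 3 is the total number of elements in the model output.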

I think this comes about because FIM is trying to reshape the output based on the input size. Is this correct?

Thanks

You need to pass n_output=3 when instantiating your object using the FIM helper. That way, the generator will expect every minibatch example to produce an output of size 3.

Thanks for your quick reply! Unfortunately that doesn't change the error, except the 1 becomes a 3:

RuntimeError: shape '[9, 3]' is invalid for input of size 3

I think this comes from the fact that you are using a single example, instead of a minibatch of several examples. NNGeometry was designed to work with datasets with many examples.
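
To see which shapes the loader actually produces, it can help to inspect one batch directly. This is a sketch in plain PyTorch; `FlatDataset` and `shapes` are hypothetical names, and the dataset copies the layout of the original `TrivialDataset`, where each item has shape `(9,)`. Whether NNGeometry reads its inputs from the first element of the batch is an assumption here, suggested by `inputs.size(0)` and `self.function(*d)` in the traceback but not shown in it.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class FlatDataset(Dataset):
    """Same layout as the original TrivialDataset: each item has shape (9,)."""
    def __init__(self):
        self.data = torch.arange(1., 10.).reshape(1, 9)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

def shapes(*args):
    """Return the shape of every positional argument received."""
    return [tuple(a.shape) for a in args]

# The loader yields a bare tensor, not an (inputs, targets) tuple.
d = next(iter(DataLoader(FlatDataset(), batch_size=1)))

batch_shape = tuple(d.shape)     # (1, 9)
first_shape = tuple(d[0].shape)  # (9,): indexing strips the batch dimension
unpacked = shapes(*d)            # [(9,)]: * iterates over the batch dimension
```

If the batch is indexed or unpacked like a tuple, the single sample loses its batch dimension and its 9 features are read as 9 examples, which matches the `[9, ...]` in both error messages.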

Update: Yep, that's the problem. Making the following change, which gives the sample an extra leading dimension so that each item has shape (1, 9), solves it:

class TrivialDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(9, dtype=torch.float32).view(1,1,9)
    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
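
A quick shape check suggests why the extra dimension helps. As before, this is a sketch that mimics the reshape from the traceback with stand-in variables rather than calling NNGeometry, and it assumes `n_output=3` as suggested earlier in the thread.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

class TrivialDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(9, dtype=torch.float32).view(1, 1, 9)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

model = nn.Linear(9, 3, bias=False)  # stand-in for the Net above
n_output = 3

d = next(iter(DataLoader(TrivialDataset(), batch_size=1)))  # shape (1, 1, 9)
inputs = d[0]                          # shape (1, 9): batch dimension kept
bs = inputs.size(0)                    # 1
output = model(*d).view(bs, n_output)  # (1, 3): the reshape now succeeds
```

With each item shaped `(1, 9)`, stripping the outer dimension still leaves a batch of one sample, so the batch size and the output size agree.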

Thanks! :)