tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io


Error with float64 tensors

Xuzzo opened this issue

Hello, and thanks for your work.
EKFAC seems to have issues with models that use double precision. Here is a minimal script that reproduces it:

from nngeometry.metrics import FIM
from nngeometry.object import PMatEKFAC
import torch as th

dtype = th.float64

class SimpleModel(th.nn.Module):

    def __init__(
        self,
        n_input: int,
        n_output: int,
    ):
        super().__init__()
        self.fc1 = th.nn.Linear(n_input, n_output, bias=True, dtype=dtype)

    def forward(self, x):
        return th.nn.Softmax(dim=-1)(self.fc1(x))
if __name__ == "__main__":
    model = SimpleModel(10, 3)
    dataset = th.utils.data.TensorDataset(th.randn(100, 10, dtype=dtype), th.randint(0, 3, (100,), dtype=th.long))
    loader = th.utils.data.DataLoader(dataset, batch_size=10)
    F_ekfac = FIM(model, loader, PMatEKFAC, 3, variant='classif_logits')
    F_ekfac.update_diag(loader)

I get `RuntimeError: expected scalar type Double but found Float`.

Hi, thanks for pointing this out.

This PR: #71 should do it.

It still needs a bit more testing before it is merged into master, but you can use it in the meantime.

Works now! Thanks a lot.