Error with float64 tensors
Xuzzo opened this issue
Hello and thanks for your work.
EKFAC seems to have issues with models that run in double precision. Here is some code to reproduce it:
```python
from nngeometry.metrics import FIM
from nngeometry.object import PMatEKFAC
import torch as th

dtype = th.float64


class SimpleModel(th.nn.Module):
    def __init__(
        self,
        n_input: int,
        n_output: int,
    ):
        super().__init__()
        self.fc1 = th.nn.Linear(n_input, n_output, bias=True, dtype=dtype)

    def forward(self, x):
        return th.nn.Softmax(dim=-1)(self.fc1(x))


if __name__ == "__main__":
    model = SimpleModel(10, 3)
    dataset = th.utils.data.TensorDataset(
        th.randn(100, 10, dtype=dtype),
        th.randint(0, 3, (100,), dtype=th.long),
    )
    loader = th.utils.data.DataLoader(dataset, batch_size=10)
    F_ekfac = FIM(model, loader, PMatEKFAC, 3, variant='classif_logits')
    F_ekfac.update_diag(loader)
```
I get `RuntimeError: expected scalar type Double but found Float`.
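For context, this error is PyTorch's generic complaint when an op mixes `float64` and `float32` tensors — typically because some internal tensor was created with torch's default `float32` dtype. A minimal illustration of the failure mode (plain PyTorch, not nngeometry itself):

```python
import torch as th

# A float64 input multiplied by a buffer created with the default dtype
# (float32) raises the same RuntimeError the issue reports.
x = th.randn(4, 10, dtype=th.float64)
w = th.randn(10, 3)  # default dtype: float32

try:
    x @ w
except RuntimeError as e:
    print(e)  # the dtype-mismatch error from the issue
```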
Hi, thanks for pointing this out.
This PR: #71 should do it.
It still needs a bit more testing before it gets merged into master, but you can use it in the meantime.
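For readers hitting the same problem elsewhere: the usual fix for this class of bug is to allocate internal buffers with the dtype (and device) of the model's parameters instead of relying on torch's `float32` default. A hedged sketch of the pattern (`make_buffer` is a hypothetical helper, not nngeometry's actual code):

```python
import torch as th

def make_buffer(param: th.Tensor, shape) -> th.Tensor:
    # Match the parameter's dtype/device so float64 models keep float64 buffers.
    return th.zeros(shape, dtype=param.dtype, device=param.device)

p = th.nn.Linear(10, 3, dtype=th.float64).weight
buf = make_buffer(p, (3, 3))
print(buf.dtype)  # torch.float64
```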
Works now! Thanks a lot.