tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io


Make differentiating inputs optional

Xuzzo opened this issue · comments

Hello,

I am computing the EK-FAC representation on the linear layers of a model that has an embedding layer. The inputs are integer lookup indices and therefore should not be differentiated through. During the computation of the KFAC blocks in the Jacobian class, the inputs are labelled as requiring grad, which raises RuntimeError: only Tensors of floating point and complex dtype can require gradients. As far as I can tell, this is only done so that the hooks are triggered on all the layers, so a plain backward call would be enough.
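
For reference, the underlying PyTorch restriction can be reproduced in isolation (a minimal sketch, independent of nngeometry):

import torch

x = torch.randint(0, 10, (4,))   # integer lookup indices, as fed to an Embedding
x.requires_grad = True           # RuntimeError: only Tensors of floating point
                                 # and complex dtype can require gradients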

Here is a simple snippet to reproduce the error:

from torch.nn import Embedding, Linear

from nngeometry.metrics import FIM
from nngeometry.object import PMatEKFAC
import torch as th
from nngeometry.layercollection import LayerCollection


class EmbeddingModel(th.nn.Module):

    def __init__(
        self,
        n_input: int,
        n_output: int,
    ):
        super().__init__()
        self.embedding = Embedding(n_input, 10)
        self.fc1 = Linear(10, n_output)

    def forward(self, x):
        return self.fc1(self.embedding(x))


if __name__ == "__main__":
    model = EmbeddingModel(10, 1)

    # restrict the FIM computation to the linear layer only
    active_layers = LayerCollection()
    active_layers.add_layer_from_model(model, model.fc1)

    # lookup indices must be valid rows of the embedding table, i.e. in [0, 10)
    dataset = th.utils.data.TensorDataset(
        th.randint(0, 10, (100,)), th.randint(0, 1, (100,))
    )
    loader = th.utils.data.DataLoader(dataset, batch_size=10)

    # raises: RuntimeError: only Tensors of floating point and complex dtype
    # can require gradients
    F_ekfac = FIM(model, loader, PMatEKFAC, 1, variant='classif_logits',
                  layer_collection=active_layers)

Would it be possible to remove the inputs.requires_grad = True call, or to make it optional?
Thanks a lot for your help and for your very instructive library.

Hi, as you correctly guessed, requires_grad=True is mandatory for the hooks to be triggered. Not using xx.backward() was a design choice made some time ago, and I do not remember the exact reason. I see at least one drawback to using .backward(): it updates the .grad of all tensors, which can interfere in some situations, e.g. for natural gradient, where you want to compute both the FIM and the average gradient of the loss.
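
To illustrate that difference, a minimal sketch (independent of nngeometry): torch.autograd.grad returns the gradients and leaves .grad untouched, whereas .backward() accumulates into .grad on every parameter.

import torch
from torch.nn import Linear

model = Linear(5, 1)
x = torch.randn(3, 5)

# autograd.grad returns the gradients without touching .grad
loss = model(x).sum()
grads = torch.autograd.grad(loss, model.parameters())
print(model.weight.grad)        # None

# .backward() accumulates into .grad
loss = model(x).sum()
loss.backward()
print(model.weight.grad.shape)  # torch.Size([1, 5])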

I am afraid I don't have a satisfactory answer. Perhaps you could replace your embedding layer with a Linear layer and transform your inputs into one-hot vectors?
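
A minimal sketch of that workaround (the names here are illustrative, not nngeometry API): a bias-free Linear applied to one-hot encoded float inputs is equivalent to an embedding lookup (with the weight stored transposed), and float inputs can safely require grad.

import torch
import torch.nn.functional as F
from torch.nn import Linear

n_input, emb_dim = 10, 10

# bias-free Linear playing the role of Embedding(n_input, emb_dim);
# its weight is the transpose of the equivalent embedding table
embedding_as_linear = Linear(n_input, emb_dim, bias=False)

idx = torch.randint(0, n_input, (100,))            # integer lookup indices
x = F.one_hot(idx, num_classes=n_input).float()    # (100, n_input), float dtype
out = embedding_as_linear(x)                       # (100, emb_dim), like an Embedding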

I have run a few tests, and for simple linear and conv models, using backward or autograd.grad gives the same results. I think that even if differentiating the inputs is necessary for some applications, it could be made optional in the library. That way one would avoid the computational overhead and/or errors like the one I ran into.
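
For what it's worth, the kind of check I mean looks like this (a minimal sketch, not the exact tests I ran):

import torch
from torch.nn import Linear

model = Linear(5, 1)
x = torch.randn(3, 5)
loss = model(x).sum()

# gradients via autograd.grad (graph retained for the second pass)
g_autograd = torch.autograd.grad(loss, model.parameters(), retain_graph=True)

# gradients via .backward()
model.zero_grad()
loss.backward()
g_backward = [p.grad for p in model.parameters()]

assert all(torch.allclose(a, b) for a, b in zip(g_autograd, g_backward))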

Anyway, I will try to find a workaround (one-hot + Linear should do the job =) ). Thanks for your answer and for your work.