tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home page: https://nngeometry.readthedocs.io

Error: `I do not know what to do with layer Embedding(50304, 512)`

CarloNicolini opened this issue

First of all, great library! I've long been looking for ways to get Jacobians and Fisher information matrices for my PyTorch models.
While the library works fine with my vision models based on simple convolutional networks, I find it harder to use with pretrained Hugging Face models.
I believe the embedding layers are the culprit.

Here is my setup: the model, plus a dataloader that takes a list of strings and yields each batch as a dictionary-like BatchEncoding with the keys "input_ids" and "attention_mask", whose values are integer torch.Tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, BatchEncoding, GPTNeoXForCausalLM

model_name = "EleutherAI/pythia-70m-deduped"
revision = "step1000"
cache_dir = None  # or a local directory in which to cache the downloaded files

torch_model = GPTNeoXForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_name,
    revision=revision,
    cache_dir=cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name, revision=revision, cache_dir=cache_dir
)
# GPT-style tokenizers often define no padding token; reuse EOS so that
# padding="max_length" below does not raise
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


class FIMDataLoader(Dataset):
    """Map-style dataset (despite its name) that tokenizes one string per item."""

    def __init__(self, text_list, tokenizer, max_length=128):
        self.text_list = text_list
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.text_list)

    def __getitem__(self, idx):
        text = self.text_list[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",  # fixed length so that items can be stacked
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of size 1; drop it
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        return input_ids, attention_mask


def collate_fn(batch):
    # batch is a list of (input_ids, attention_mask) pairs; stack each field
    input_ids, attention_mask = zip(*batch)
    return BatchEncoding(
        {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_mask),
        }
    )


def create_dataloader(text_list, tokenizer, batch_size, max_length, shuffle=False):
    dataset = FIMDataLoader(text_list, tokenizer, max_length)
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn
    )
```

Then I instantiate the dataloader:

```python
texts_list = ["The cat is on the table", "Alice and Bob are friends"]
dataloader = create_dataloader(
    texts_list, tokenizer, batch_size=1, max_length=128, shuffle=False
)
```

With a total of 70M parameters, keeping the entire Fisher matrix in memory is prohibitive (it is a dense 70M × 70M matrix), so I chose to compute only its diagonal, whose storage is proportional to the number of parameters, using the PMatDiag representation you kindly provide in the library.
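For scale, here is a back-of-the-envelope storage count in float32, assuming a round 70M parameters (the exact Pythia-70m count differs a little):

```python
# Storage cost of the Fisher information matrix in float32,
# assuming a nominal 70M parameters (an approximation of Pythia-70m's size)
n_params = 70_000_000
bytes_per_float = 4  # float32

full_bytes = n_params**2 * bytes_per_float  # dense n x n matrix
diag_bytes = n_params * bytes_per_float     # only the n diagonal entries

print(f"full FIM: {full_bytes / 1e15:.1f} PB")
print(f"diagonal: {diag_bytes / 1e6:.0f} MB")
```

The full matrix comes to roughly 19.6 PB, while the diagonal fits in about 280 MB.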

This should give me the diagonal of the Fisher information matrix, right?
However, I run into an error that seems related to the creation of the LayerCollection. Here is the call:

```python
from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag

FIM(
    model=torch_model,
    loader=dataloader,
    representation=PMatDiag,
    n_output=1,
    device="cpu",
)
```

This raises the following error:

```
Exception                                 Traceback (most recent call last)
Cell In[93], line 6
      3 from nngeometry.metrics import FIM
      4 from nngeometry.object import PMatDiag
----> 6 F_ekfac = FIM(
      7     model=torch_model,
      8     loader=dataloader,
      9     representation=PMatDiag,
     10     n_output=1,
     11     device="cpu",
     12 )

File ~/opt/miniconda3/envs/pythia/lib/python3.10/site-packages/nngeometry/metrics.py:147, in FIM(model, loader, representation, n_output, variant, device, function, layer_collection)
    144         return model(d[0].to(device))
    146 if layer_collection is None:
--> 147     layer_collection = LayerCollection.from_model(model)
    149 if variant == 'classif_logits':
    151     def function_fim(*d):

File ~/opt/miniconda3/envs/pythia/lib/python3.10/site-packages/nngeometry/layercollection.py:50, in LayerCollection.from_model(model, ignore_unsupported_layers)
     48     elif not ignore_unsupported_layers:
     49         if len(list(mod.children())) == 0 and len(list(mod.parameters())) > 0:
---> 50             raise Exception('I do not know what to do with layer ' + str(mod))
     52 return lc

Exception: I do not know what to do with layer Embedding(50304, 512)
```

It looks like this error is caused by the Embedding layer that converts token ids from the vocabulary space (size 50304) to the latent space (size 512). The projection at the end of the model that maps back the other way (`embed_out` in GPTNeoX) is an `nn.Linear`, so the input embedding is the layer NNGeometry trips over.
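To list every layer type that would trip the same check, a small hypothetical helper can walk the model with the condition visible in the traceback above (a leaf module, i.e. one with no children, that owns parameters). It only lists candidates; it does not know which types NNGeometry actually supports. The toy model below is a stand-in, not Pythia.

```python
from collections import Counter

import torch.nn as nn


def leaf_param_module_types(model: nn.Module) -> Counter:
    """Hypothetical helper: count leaf modules that own parameters, by class name.

    Uses the same test as the LayerCollection.from_model excerpt in the
    traceback: no child modules, but at least one parameter.
    """
    counts = Counter()
    for mod in model.modules():
        if len(list(mod.children())) == 0 and len(list(mod.parameters())) > 0:
            counts[type(mod).__name__] += 1
    return counts


# Toy stand-in for a tiny causal LM (hypothetical sizes, not Pythia itself)
toy = nn.Sequential(
    nn.Embedding(100, 16),
    nn.LayerNorm(16),
    nn.Linear(16, 100, bias=False),
)
print(leaf_param_module_types(toy))
```

On the Pythia model, calling `leaf_param_module_types(torch_model)` should list `Embedding` alongside `LayerNorm` and `Linear`, which shows up front which layer types beyond `Linear` need handling.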

What should I do to get the diagonal of the FIM for all model parameters?
Many thanks, and again, great package.

Hi, the issue is that I have not yet found a satisfactory way to implement Embedding layers in NNGeometry.

One temporary workaround is to emulate embeddings with a Linear layer applied to a one-hot encoding of the token ids.
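The workaround relies on the fact that an embedding lookup is a matrix product with a one-hot vector: row i of the embedding table equals one_hot(i) @ table. A quick plain-PyTorch check of that identity on a small random table (the sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, dim = 10, 4
table = torch.randn(vocab_size, dim)  # embedding table, one row per token id
ids = torch.tensor([[1, 7, 3]])       # a batch of token ids, shape (1, 3)

via_lookup = F.embedding(ids, table)  # ordinary embedding lookup, shape (1, 3, 4)
one_hot = F.one_hot(ids, num_classes=vocab_size).float()  # shape (1, 3, 10)
via_matmul = one_hot @ table          # one-hot times table, shape (1, 3, 4)

same = torch.allclose(via_lookup, via_matmul)
print(same)
```

The price is memory: the one-hot tensor holds one float per vocabulary entry per token, so a single 128-token sequence over a 50304-token vocabulary needs 128 × 50304 ≈ 6.4M floats (about 26 MB in float32).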

Another possible workaround is to compute the FIM for all parameters except the Embedding layer's, by manually creating the LayerCollection object passed to FIM instead of relying on the default, which adds every PyTorch module that has parameters.
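Within NNGeometry, the traceback above shows the relevant hooks: `FIM` accepts a `layer_collection` argument, and `LayerCollection.from_model` takes an `ignore_unsupported_layers` flag that, judging from the code excerpt, skips unsupported leaf modules instead of raising. That path is not exercised here. To pin down what "the diagonal restricted to non-Embedding parameters" means independently of the library, here is a plain-PyTorch sketch (not NNGeometry's implementation; the `fisher_diagonal` helper and the toy model are hypothetical) computing the exact Fisher diagonal of a softmax classifier, (1/N) Σ_x Σ_y p(y|x) (∂ log p(y|x)/∂θ)², over every parameter not owned by an `nn.Embedding`:

```python
import torch
import torch.nn as nn


def fisher_diagonal(model, inputs, skip_types=(nn.Embedding,)):
    """Hypothetical helper: exact Fisher diagonal of a softmax classifier,
    (1/N) sum_x sum_y p(y|x) * (d log p(y|x) / d theta)**2,
    for every parameter not owned by a module of one of `skip_types`."""
    skipped = {
        id(p)
        for m in model.modules()
        if isinstance(m, skip_types)
        for p in m.parameters(recurse=False)
    }
    named = [(n, p) for n, p in model.named_parameters() if id(p) not in skipped]
    params = [p for _, p in named]
    diag = [torch.zeros_like(p) for p in params]

    for x in inputs:  # one example at a time, so gradients are per-example
        log_probs = torch.log_softmax(model(x.unsqueeze(0)), dim=-1)[0]
        probs = log_probs.exp().detach()
        for c in range(log_probs.numel()):  # exact expectation over labels y
            grads = torch.autograd.grad(log_probs[c], params, retain_graph=True)
            for d, g in zip(diag, grads):
                d += probs[c] * g.pow(2)

    return {n: d / len(inputs) for (n, _), d in zip(named, diag)}


# Toy classifier standing in for the real model (hypothetical sizes)
torch.manual_seed(0)
toy = nn.Sequential(nn.Embedding(20, 8), nn.Flatten(), nn.Linear(3 * 8, 5))
ids = torch.randint(0, 20, (4, 3))  # 4 examples of 3 token ids each

fim_diag = fisher_diagonal(toy, ids)
print({name: tuple(t.shape) for name, t in fim_diag.items()})
```

The label loop costs one backward pass per class, so this only scales to small outputs; it is meant to make the quantity concrete, not to replace NNGeometry on a 50304-way language model.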

Could you expand on the first workaround?
I tried to follow your suggestion and came up with the snippet below, but I am not sure it is what you intended. Could you please check it?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OneHotLinearEmbedding(nn.Module):
    """Emulates an nn.Embedding with a bias-free nn.Linear applied to one-hot inputs."""

    def __init__(self, pretrained_embedding_layer):
        super().__init__()

        # nn.Embedding stores its table as (num_embeddings, embedding_dim)
        num_embeddings, embedding_dim = pretrained_embedding_layer.weight.size()

        # nn.Linear(in, out) stores its weight as (out, in), so the layer must map
        # vocabulary size -> embedding dim and receive the transposed table
        self.embedding_layer = nn.Linear(num_embeddings, embedding_dim, bias=False)

        # Copy the pretrained weights, transposed to match nn.Linear's layout
        with torch.no_grad():
            self.embedding_layer.weight.copy_(pretrained_embedding_layer.weight.t())

    def forward(self, input_indices):
        # One-hot encode the token ids over the full vocabulary
        one_hot = F.one_hot(
            input_indices, num_classes=self.embedding_layer.in_features
        ).to(self.embedding_layer.weight.dtype)

        # nn.Linear computes one_hot @ weight.T = one_hot @ table,
        # i.e. it selects each token's embedding row
        return self.embedding_layer(one_hot)
```

However, after replacing the input Embedding layer of my model with this module, the FIM computation breaks on LayerNorm. Is that another layer type that still needs to be implemented in NNGeometry?

Hi, just a quick update: I should be able to find some time to fix this later this week or next week.

You are right that LayerNorm is not implemented yet. At a quick glance it works very similarly to BatchNorm, so it should not be too difficult to add to NNGeometry.
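For the learnable parameters the similarity is concrete: both BatchNorm and LayerNorm end in an elementwise affine map y = x̂ ⊙ weight + bias, and they differ in the axis over which x̂ is normalized (the batch, and spatial positions, for BatchNorm; the feature dimension for LayerNorm). A plain-PyTorch sketch (not NNGeometry's code) checking that LayerNorm's per-example parameter gradients can be formed from x̂ and the output gradient alone, which is the ingredient a diagonal Fisher needs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
ln = nn.LayerNorm(6)
nn.init.normal_(ln.weight)  # non-trivial weight so that y differs from x_hat
x = torch.randn(3, 4, 6)    # (batch, sequence, features)
g = torch.randn(3, 4, 6)    # arbitrary upstream gradient dL/dy

# Reference: autograd gradients of L = <y, g> w.r.t. the affine parameters
y = ln(x)
gw_ref, gb_ref = torch.autograd.grad((y * g).sum(), [ln.weight, ln.bias])

# Manual: y = x_hat * weight + bias, where x_hat is x normalized over the last dim
x_hat = F.layer_norm(x, (6,), eps=ln.eps)  # normalization only, no affine params
per_example_gw = (g * x_hat).sum(dim=1)    # (batch, features): sum over positions
per_example_gb = g.sum(dim=1)              # (batch, features)
# A diagonal Fisher would accumulate the squares of these per-example rows

ok_w = torch.allclose(per_example_gw.sum(dim=0), gw_ref, atol=1e-5)
ok_b = torch.allclose(per_example_gb.sum(dim=0), gb_ref, atol=1e-5)
print(ok_w, ok_b)
```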

Your implementation of the Linear-layer workaround for the embedding layer looks correct to me.

I will keep you updated as soon as I make progress!

Hello, for some reason implementing LayerNorm broke the tests for other layer types (Cosine and WeightNorm), so I cannot merge it into master yet. It also needs some more cleanup.

In the meantime you can use this branch, which should do the job for your use case. Otherwise, could you provide the simplest architecture for which it fails?

Best