tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home page: https://nngeometry.readthedocs.io

[suggestion] Ignore unsupported modules option

fredguth opened this issue

Is it possible to just ignore unsupported modules (like nn.BatchNorm2d)? I want to use nngeometry with timm models.

Hi, thanks for reaching out.

BatchNorm2d is actually supported with PMatDiag, PMatBlockDiag, or PMatDense. I am guessing that you meant to use it with (E)KFAC? The issue is that KFAC does not really make sense for these layers, since its Kronecker factorization is designed for the weight matrices of linear and convolution layers, and it would bring little benefit anyway, because BatchNorm layers typically have far fewer parameters than convolution or fully connected layers. In that case, the simplest way to ignore some layers is to manually create a LayerCollection and add only the relevant layers, leaving out those for which KFAC doesn't make sense: https://nngeometry.readthedocs.io/en/latest/api/layercollection.html
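
As a rough illustration of which layers would end up in such a hand-built collection, here is a minimal pure-PyTorch sketch. The helper `split_modules_for_kfac` and the choice of `nn.Linear` and `nn.Conv2d` as the "KFAC-compatible" types are assumptions made for this example, not part of nngeometry. It also counts parameters, to show how small the skipped BatchNorm layers usually are.

```python
import torch.nn as nn

# Assumption for this sketch: only these module types are treated as KFAC-compatible.
KFAC_TYPES = (nn.Linear, nn.Conv2d)


def split_modules_for_kfac(model):
    """Hypothetical helper: split modules that directly own parameters into
    KFAC-compatible ones and the rest (e.g. BatchNorm2d).

    Returns two lists of (name, module, n_params) tuples.
    """
    kfac, skipped = [], []
    for name, module in model.named_modules():
        # Only consider a module's own parameters, so containers such as
        # nn.Sequential are not counted twice.
        own_params = list(module.parameters(recurse=False))
        if not own_params:
            continue
        n_params = sum(p.numel() for p in own_params)
        target = kfac if isinstance(module, KFAC_TYPES) else skipped
        target.append((name, module, n_params))
    return kfac, skipped


if __name__ == "__main__":
    # Toy model, sized for 3x8x8 inputs.
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(16 * 8 * 8, 10),
    )
    kfac, skipped = split_modules_for_kfac(model)
    for name, _, n in kfac:
        print(f"KFAC  {name}: {n} parameters")
    for name, _, n in skipped:
        print(f"skip  {name}: {n} parameters")
```

On this toy model the convolution has 448 parameters and the linear layer 10250, versus 32 for the BatchNorm2d layer (its running statistics are buffers, not parameters). Each module in the first list could then be added to a hand-built collection, for example with `LayerCollection.add_layer_from_model(model, module)`. That method name is assumed here from the LayerCollection docs linked above, so check the exact signature for your nngeometry version.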

I chose to raise an exception rather than silently skipping unsupported layers, so as not to mislead users into believing that they are computing the FIM for all layers.

In the future, when I have time, I plan to implement a mixed representation that uses KFAC for supported layers and PMatDiag or PMatBlockDiag otherwise.
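
One way to picture such a mixed representation, assuming it stays block-diagonal across layers (a hypothetical layout, not nngeometry's implementation): layers 1 to k are KFAC-supported (linear or convolution) and layers k+1 to L are the others, reordered this way only for readability. Here a_{l-1} are the inputs to layer l, g_l the gradients of the log-likelihood with respect to its outputs, and ℓ the log-likelihood. The order of the Kronecker factors depends on how the weights are vectorized, and D_l could be replaced by the full block E[∇ℓ ∇ℓ^T] for a block-diagonal variant.

```latex
F \approx \operatorname{blockdiag}\big(
  A_1 \otimes G_1,\ \dots,\ A_k \otimes G_k,\ D_{k+1},\ \dots,\ D_L \big),
\qquad
A_l = \mathbb{E}\big[a_{l-1} a_{l-1}^\top\big],\quad
G_l = \mathbb{E}\big[g_l\, g_l^\top\big],\quad
D_l = \operatorname{diag}\big(\mathbb{E}\big[\nabla_{\theta_l}\ell \odot \nabla_{\theta_l}\ell\big]\big)
```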

Hope this helps!

I can use a layer_collection instead of a model? I didn't know that.
Thanks!