[suggestion] Ignore unsupported modules option
fredguth opened this issue
Is it possible to just ignore unsupported modules (like nn.BatchNorm2d)? I want to use nngeometry with timm models.
Hi, thanks for reaching out.
BatchNorm2d is actually supported with PMatDiag, PMatBlockDiag, or PMatDense. I am guessing that you meant to use it with (E)KFAC? The issue is that KFAC does not really make sense on these layers, and there is little to gain anyway, since they typically have fewer parameters than convolutional or fully connected layers. In that case, the simplest way to ignore some layers is to create a LayerCollection manually and add only the relevant layers, leaving out those for which KFAC doesn't make sense: https://nngeometry.readthedocs.io/en/latest/api/layercollection.html
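A minimal sketch of that approach, assuming an nngeometry version where `LayerCollection` exposes `add_layer_from_model(model, module)` (check the linked docs for your install). The helper names `KFAC_TYPES`, `kfac_modules`, `count_params`, and `build_kfac_layer_collection` are hypothetical, and the small `nn.Sequential` is only a stand-in for a timm model:

```python
import torch.nn as nn

# Layer types for which a Kronecker-factored (KFAC/EKFAC) block is meaningful.
# This tuple is an illustrative choice, not something nngeometry defines.
KFAC_TYPES = (nn.Linear, nn.Conv2d)


def kfac_modules(model):
    """Return the submodules of `model` that KFAC should cover, in order."""
    return [m for m in model.modules() if isinstance(m, KFAC_TYPES)]


def count_params(modules):
    """Total number of parameters directly owned by the given modules."""
    return sum(p.numel() for m in modules for p in m.parameters(recurse=False))


def build_kfac_layer_collection(model):
    """Build a LayerCollection holding only the KFAC-compatible layers."""
    # Imported lazily so the helpers above work without nngeometry installed.
    # Assumes LayerCollection.add_layer_from_model(model, module) is available.
    from nngeometry.layercollection import LayerCollection

    lc = LayerCollection()
    for module in kfac_modules(model):
        lc.add_layer_from_model(model, module)
    return lc


# Toy stand-in for a timm model: conv + batch norm + classifier head.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 6 * 6, 10),
)

selected = kfac_modules(model)
skipped = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]

# Compare how many parameters KFAC would cover versus how many are left out.
print(len(selected), count_params(selected), count_params(skipped))
```

The collection returned by `build_kfac_layer_collection(model)` would then be handed to the metric computation in place of the default "all layers" behaviour, presumably through a `layer_collection` argument. That parameter name is an assumption here, so verify it against the signature in your installed version.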
I chose to raise an Exception rather than silently skip these layers, so as not to mislead users into believing that they are computing the FIM for all layers.
In the future, when I have time, I plan to implement a mixed representation that uses KFAC for supported layers and PMatDiag or PMatBlockDiag otherwise.
Hope this helps!
I can use a layer_collection instead of a model? I didn't know that.
Thanks!