tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io

compute FIM of partial parameters

Jiaxiangren opened this issue

First, thanks for the amazing work!
I want to compute the FIM over only part of the parameters, i.e. only some of the model's parameters require gradients. Is that possible?

Yes, sure!

You need to specify a LayerCollection object that represents the structure of the parameter space you want to analyse with your FIM. By default, the FIM helpers create a LayerCollection object that includes all parameters of the model, but you can instead instantiate a LayerCollection object manually. See for instance this line from the repository:

layer_collection.add_layer(*lc_full.layers.popitem())
Here the new LayerCollection object comprises only a single layer, popped from lc_full, the collection covering the full model.

Then pass that LayerCollection object to FIM through its layer_collection argument.