tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home page: https://nngeometry.readthedocs.io


How can I handle a model with BatchNorm layers when using the PMatEKFAC representation to compute the FIM?

johnrachwan123 opened this issue.

Currently, this raises a NotImplementedError.

I have a ResNet18 model whose FIM trace I would like to compute. However, this model has many BatchNorm layers, and using the standard PMatDiag runs my 8 GB CUDA GPU out of memory. How would you recommend solving this memory issue?

Is there an easy way to modify the code so that the Linear and Conv layers are approximated with PMatEKFAC and the BatchNorm layers with PMatBlockDiag?
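A first step that does not depend on nngeometry at all is partitioning the model's layers by type. The sketch below only does that partitioning in plain PyTorch; how each group is then handed to PMatEKFAC or PMatBlockDiag is not covered in this thread, so check the library's layer-collection documentation for that part. The function name `split_by_layer_type` and the group names are hypothetical.

```python
import torch.nn as nn

# Hypothetical grouping: layer types that KFAC/EKFAC-style factorizations target,
# and BatchNorm variants that would go to a block-diagonal representation instead.
KFAC_TYPES = (nn.Linear, nn.Conv2d)
BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def split_by_layer_type(model: nn.Module) -> dict:
    """Group the names of parameter-owning modules by layer type."""
    groups = {"kfac": [], "batchnorm": [], "other": []}
    for name, module in model.named_modules():
        # Skip containers and activations: they own no parameters directly.
        if not list(module.parameters(recurse=False)):
            continue
        if isinstance(module, KFAC_TYPES):
            groups["kfac"].append(name)
        elif isinstance(module, BATCHNORM_TYPES):
            groups["batchnorm"].append(name)
        else:
            groups["other"].append(name)
    return groups
```

For a ResNet18 this would put every convolution and the final fully connected layer in `"kfac"` and every BatchNorm in `"batchnorm"`. Note that BatchNorm running statistics are buffers, not parameters, so they never appear in a Fisher representation.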

I think I figured it out, but I have a small additional question. What would you recommend as the fastest way to get the trace of the FIM? I am only interested in the trace.
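When only the trace is needed, no matrix has to be stored: for a softmax classifier, tr(F) = E_x E_{y ~ p(y|x)} ||∇_θ log p(y|x)||², i.e. an average of per-example squared gradient norms. The plain-PyTorch sketch below (not nngeometry's method) estimates it by sampling one label per example from the model's own predictive distribution. Per-example gradients are required because the squared norm of a mini-batch gradient, ||Σ g_i||², is not Σ ||g_i||². The function name `fim_trace` and the `n_samples` parameter are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fim_trace(model: nn.Module, inputs: torch.Tensor, n_samples: int = 1) -> float:
    """Monte Carlo estimate of the trace of the (true) Fisher of a classifier.

    Labels are sampled from the model's predictive distribution; using the
    dataset labels instead would give the empirical Fisher, a different quantity.
    Call model.eval() first so BatchNorm uses running statistics.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    total = 0.0
    for x in inputs:  # one example at a time, to get per-example gradients
        log_probs = F.log_softmax(model(x.unsqueeze(0)), dim=1)
        for _ in range(n_samples):
            # Sample y ~ p(y|x); detach because sampling is not differentiated.
            y = torch.multinomial(log_probs.detach().exp(), 1).item()
            grads = torch.autograd.grad(
                log_probs[0, y], params, retain_graph=True, allow_unused=True
            )
            # Squared norm of the full per-example gradient vector.
            total += sum(g.pow(2).sum().item() for g in grads if g is not None)
    return total / (len(inputs) * n_samples)
```

The Python loop over examples is the slow part; vectorizing the per-example gradients (for instance with `torch.func`) is one option, and evaluating on a small fixed batch of inputs is usually enough to follow how the trace evolves.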

Yes, I thought of using those, but they become extremely slow. I want the trace in order to detect when training reaches the forgetting phase (from https://arxiv.org/abs/1711.08856). If detecting this phase is very expensive, it defeats the purpose, since I use it to improve another aspect of network training.
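Because the goal is only to detect a phase change, the trace can be evaluated cheaply (a small fixed subset, every few epochs) and monitored for a peak. The sketch below assumes, as I read the linked paper, that the forgetting phase shows up as the FIM trace declining after an early peak. The function `forgetting_phase_started` and its `drop_fraction` and `patience` thresholds are hypothetical knobs, not values from the paper.

```python
def forgetting_phase_started(trace_history, drop_fraction=0.1, patience=2):
    """Return True once the trace has stayed at least `drop_fraction` below
    its running maximum for the last `patience` evaluations."""
    if not trace_history:
        return False
    peak = max(trace_history)
    peak_idx = trace_history.index(peak)
    threshold = (1.0 - drop_fraction) * peak
    # Length of the trailing run of evaluations (after the peak) below threshold.
    run = 0
    for value in reversed(trace_history[peak_idx + 1:]):
        if value < threshold:
            run += 1
        else:
            break
    return run >= patience
```

Requiring several consecutive low values guards against reacting to noise in a trace estimated on a small subset.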

Thanks a lot for your help! I tried using only a subset of the training set, but the trace then became much larger, so I may have made a small mistake. All I did was break out of the loop over the data loader and set n_examples to the number of loop iterations performed before the break. Is there anything else I should do?
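One possible cause, assuming each loop iteration processes a whole mini-batch: if n_examples is set to the number of iterations rather than the number of examples seen, the average is inflated by roughly the batch size. The sketch below counts examples from the batch shapes, which also handles a smaller last batch. `subset_trace` and the `per_example_sq_norms` callable (returning one squared gradient norm per example) are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def subset_trace(loader, per_example_sq_norms, max_batches):
    """Average per-example squared gradient norms over the first `max_batches` batches."""
    total, n_examples = 0.0, 0
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        x = batch[0]  # works for (x,) and (x, y) batches
        total += per_example_sq_norms(x).sum().item()
        n_examples += x.shape[0]  # count examples, not loop iterations
    if n_examples == 0:
        raise ValueError("no examples were processed")
    return total / n_examples, n_examples
```

If the loader shuffles, each call sees a different subset; using a fixed subset (e.g. a `torch.utils.data.Subset` with a non-shuffling loader) keeps values comparable across epochs.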

Thanks a lot for your help and thank you for this great library!