tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io

Implementing BatchNorm for KFAC

a-cowlagi opened this issue · comments

Hello,

I am trying to use batch normalization in my network trained on CIFAR. The network has about 50,000 parameters, and I want to use the KFAC representation to speed up computations. However, it looks like BatchNorm2d is unimplemented for KFAC. Would it be possible to add this implementation?

As a follow-up, here is the network I am using:

import torch.nn as nn

class View(nn.Module):
    """Reshape the conv output to (batch_size, n) for classification."""
    def __init__(self, n):
        super().__init__()
        self.n = n

    def forward(self, x):
        return x.view(-1, self.n)

class allcnn_t(nn.Module):
    def __init__(self, c1=16, c2=32):
        super().__init__()
        d = 0  # dropout probability for the inner Dropout layers

        def convbn(ci, co, ksz, s=1, pz=0):
            # conv -> ReLU -> batch norm block
            return nn.Sequential(
                nn.Conv2d(ci, co, ksz, stride=s, padding=pz),
                nn.ReLU(),
                nn.BatchNorm2d(co))

        self.m = nn.Sequential(
            nn.Dropout(0.2),
            convbn(3, c1, 3, 1, 1),
            convbn(c1, c1, 3, 1, 1),
            convbn(c1, c1, 3, 2, 1),
            nn.Dropout(d),
            convbn(c1, c2, 3, 1, 1),
            convbn(c2, c2, 3, 1, 1),
            convbn(c2, c2, 3, 2, 1),
            nn.Dropout(d),
            convbn(c2, c2, 3, 1, 1),
            convbn(c2, c2, 3, 1, 1),
            convbn(c2, 10, 1, 1),
            nn.AvgPool2d(8),
            View(10))

        print('Num parameters: ', sum(p.numel() for p in self.m.parameters()))

    def forward(self, x):
        return self.m(x)

Here is what I am using for the Fisher:

from nngeometry.metrics import FIM_MonteCarlo
from nngeometry.object import PMatBlockDiag

fisher = FIM_MonteCarlo(model=model.cpu(),
                        loader=train_loader,
                        representation=PMatBlockDiag,
                        device='cpu')

Is there any immediately obvious way to speed up the Fisher computation, besides putting it on the GPU?

Hello,

Unfortunately, it is not really clear what to do with BatchNorm layers when trying to apply KFAC:

  1. factorize the batch norm parameters in some way, or
  2. use the full Fisher for the block corresponding to the parameters of each batch norm layer.

But instead of hard-coding one of these 2 options, I prefer to leave the choice to the user.

I personally prefer 2., which can be implemented using the example in https://github.com/tfjgeorge/nngeometry/blob/master/examples/FIM%20for%20EWC.ipynb (scroll down to "KFAC and Batch norm layers"). In essence, it consists of using 2 separate block-diagonal FIMs: one for the batch norm parameters using PMatBlockDiag, and another one for the "standard" parameters using PMatKFAC, along the lines of the sketch below.
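
A rough sketch of that two-FIM setup, assuming nngeometry's LayerCollection can be used to select the two layer subsets (the split loop and the names lc_bn / lc_rest below are illustrative; the linked notebook shows the canonical version):

from nngeometry.metrics import FIM_MonteCarlo
from nngeometry.object import PMatKFAC, PMatBlockDiag
from nngeometry.layercollection import LayerCollection

# Split the model's layers into batch norm layers and everything else.
# (Illustrative split; adapt as in the linked notebook.)
lc_all = LayerCollection.from_model(model)
lc_bn, lc_rest = LayerCollection(), LayerCollection()
for layer_id, layer in lc_all.layers.items():
    if 'BatchNorm' in layer.__class__.__name__:
        lc_bn.add_layer(layer_id, layer)
    else:
        lc_rest.add_layer(layer_id, layer)

# KFAC for the conv/linear layers, exact block-diagonal for batch norm.
F_kfac = FIM_MonteCarlo(model=model, loader=train_loader,
                        representation=PMatKFAC,
                        layer_collection=lc_rest, device='cpu')
F_bn = FIM_MonteCarlo(model=model, loader=train_loader,
                      representation=PMatBlockDiag,
                      layer_collection=lc_bn, device='cpu')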

For your 2nd question, to speed up computation you can:

  • reduce the dataset size
  • use a GPU
  • change to a more efficient representation (e.g. PMatKFAC)
  • reduce the size of the layer with the most parameters, which will likely be the bottleneck

Thanks, this is helpful -- I will adopt the second approach! A follow-up related to eigendecompositions using the KFAC representation: I understand that the eigenvalues in this representation can be found by taking the product of every pair of eigenvalues of the two Kronecker factors, then concatenating across layers to get the full spectrum. I also understand that the eigenvectors are given by the Kronecker product between eigenvectors of the Kronecker factors, but I am struggling to implement this efficiently -- something better than simply looping over the eigenvectors. Do you have any recommendation as to how to get the eigenvectors of a given block? I would like to construct a matrix where each row/column is an eigenvector, with rows sorted by eigenvalue. I appreciate any help!
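
For reference, a minimal sketch of that construction for one layer's block, assuming the two symmetric Kronecker factors A (m x m) and B (n x n) are available as dense tensors; note that it materializes the full mn x mn eigenvector matrix, so it is only feasible for moderately sized layers:

import torch

def kfac_block_eigh(A, B):
    """Eigendecomposition of A kron B from the eigendecompositions
    of its symmetric factors, sorted by descending eigenvalue."""
    evals_a, evecs_a = torch.linalg.eigh(A)
    evals_b, evecs_b = torch.linalg.eigh(B)
    # Eigenvalues of A kron B are all pairwise products; index k*n + l
    # holds evals_a[k] * evals_b[l], matching the column order below.
    evals = torch.outer(evals_a, evals_b).reshape(-1)
    # Eigenvectors are Kronecker products of factor eigenvectors;
    # torch.kron builds all of them at once: column k*n + l equals
    # kron(evecs_a[:, k], evecs_b[:, l]).
    evecs = torch.kron(evecs_a, evecs_b)
    order = torch.argsort(evals, descending=True)
    # Columns are eigenvectors; transpose if you want one per row.
    return evals[order], evecs[:, order]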

The link you shared seems to be broken (https://github.com/tfjgeorge/nngeometry/blob/master/examples/FIM%20for%20EWC.ipynb). Could you please update it?
Also, is there any way of using batch norm layers without implementing or modifying the code myself?