tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io

Extracting Eigenvalues of Fisher using KFAC Representation

a-cowlagi opened this issue

I am trying to get the eigenspectrum of the Fisher of my neural network using the compute_eigendecomposition() and get_eigendecomposition() methods in the KFAC implementation, but I am having trouble interpreting the returned dictionary.

If I just want the sorted eigenvalues of my Fisher as a flat tensor, what is the best way to go about this using NNGeometry? Would getting the eigenvalues of the dense_tensor be sufficient? Also, the torch.symeig function used in the eigendecomposition calculation seems to be deprecated; torch suggests using torch.linalg.eigh instead.
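
For reference, a minimal sketch of the replacement on a generic symmetric matrix (not tied to NNGeometry internals):

```python
import torch

M = torch.rand(5, 5)
M = M @ M.T  # symmetrize so eigh applies

# deprecated: evals, evecs = torch.symeig(M, eigenvectors=True)
# replacement; eigenvalues are returned in ascending order:
evals, evecs = torch.linalg.eigh(M)
```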

In general, if you can afford the compute needed to calculate the spectrum of the Fisher using a PMatDense, then go for it.

In practice, on actual neural networks, PMatDense objects cannot be computed since their size grows as O(d^2), where d is the number of parameters, and their eigendecomposition costs O(d^3). I am guessing that this is why you are using PMatKFAC instead?
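
For a small model where this is affordable, a minimal sketch (assuming the FIM helper and a model/loader you already have; n_output is the dimension of the model's output and is just an illustrative value here) could look like:

```python
import torch
from nngeometry.metrics import FIM
from nngeometry.object import PMatDense

# model and loader are assumed to exist; this is only viable for small d,
# since the dense representation materializes a d x d matrix
F_dense = FIM(model=model,
              loader=loader,
              representation=PMatDense,
              n_output=10)

# eigenvalues of the full matrix, in ascending order
evals = torch.linalg.eigvalsh(F_dense.get_dense_tensor())
evals = evals.flip(0)  # sort in decreasing order
```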

In PMatKFAC, the Fisher is approximated:
1/ using a block-diagonal matrix instead of a dense one (each block corresponds to the parameters of a single layer), and
2/ approximating each block by a Kronecker product: block = A kron B.

In that case, you get the (approximate) full spectrum of each block by multiplying every eigenvalue of A with every eigenvalue of B; this is just a property of Kronecker products (see e.g. the Matrix Cookbook). Note that the result is not sorted in decreasing order. The spectrum of the full matrix is then just the concatenation of the eigenvalues of each block, as in the sketch below.
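
A minimal sketch of that recipe, assuming get_eigendecomposition() returns for each layer a pair of eigenvalue tensors for the two Kronecker factors (check the exact return format in your version):

```python
import torch

# F_kfac is a PMatKFAC representation of the Fisher
F_kfac.compute_eigendecomposition()
evals, _ = F_kfac.get_eigendecomposition()

all_evals = []
for layer_id, (evals_a, evals_b) in evals.items():
    # eigenvalues of A kron B are all pairwise products
    # lambda_i(A) * mu_j(B)
    all_evals.append(torch.outer(evals_a, evals_b).flatten())

# concatenate the per-block spectra and sort in decreasing order
spectrum = torch.sort(torch.cat(all_evals), descending=True).values
```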

Thanks for the suggestion regarding torch.linalg.eigh; I will have a look at it.

Thomas