tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io

What is the use of the classif_logits variant?

zengjie617789 opened this issue · comments

I read the code carefully, and I find the details of the implementation of the classif_logits variant confusing. From my understanding, log_softmax is used to compute the gradient faster, and the probabilities are recovered through exp. But what does it mean to return (log_probs * probs**.5), which looks like a derivative?
Here is the relevant code:

def function_fim(*d):
    log_probs = torch.log_softmax(function(*d), dim=1)
    probs = torch.exp(log_probs).detach()
    return (log_probs * probs**.5)
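As an aside on why `log_softmax` appears here rather than `log(softmax(...))`: beyond speed, the fused op is numerically stable, while the naive composition underflows to `-inf` for extreme logits. A small standalone sketch (not from nngeometry) illustrates this:

```python
import torch

# log_softmax is preferred over log(softmax(...)) for numerical
# stability: the naive composition underflows for extreme logits.
logits = torch.tensor([1000.0, 0.0, -1000.0])

naive = torch.log(torch.softmax(logits, dim=0))  # contains -inf entries
stable = torch.log_softmax(logits, dim=0)        # finite: [0., -1000., -2000.]

print(naive)
print(stable)
```

Exponentiating the stable result (as the snippet above does with `torch.exp(log_probs)`) recovers the softmax probabilities exactly.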

Furthermore, there are 'classif_logits' and 'regression' variants. What about an output that combines regression and classif_logits? As far as I understand, each output should be handled with its corresponding mode?
I would appreciate it if anyone could help me, thank you in advance.

Here is the formula for the FIM when the function f is used for classification and ends with a softmax:

$$F(x) = \mathbb{E}_{y \sim p(y|x)}\left[\nabla_w \log p(y|x)\,\nabla_w \log p(y|x)^\top\right]$$

Since here p(y|x) is a discrete distribution (a multinoulli), this simplifies to:

$$F(x) = \sum_y p(y)\,\nabla_w \log p(y)\,\nabla_w \log p(y)^\top = \sum_y \left(\sqrt{p(y)}\,\nabla_w \log p(y)\right)\left(\sqrt{p(y)}\,\nabla_w \log p(y)\right)^\top$$

where p(y) = softmax(f(x))_y, since f outputs logits. This is exactly why the code returns log_probs * probs**.5 with probs detached: the per-class gradients of that quantity are sqrt(p(y)) ∇ log p(y), and summing their outer products gives the FIM.
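The identity above can be checked numerically. Below is a quick sketch (a tiny linear model with illustrative names, not nngeometry's actual implementation): summing the outer products of the per-class gradients of log_probs * probs**.5 matches the FIM computed directly from p(y) and ∇ log p(y).

```python
import torch

torch.manual_seed(0)

# Tiny "network": logits = W @ x with 3 classes and 4 input features.
# W, x and the function names are illustrative assumptions.
W = torch.randn(3, 4, requires_grad=True)
x = torch.randn(4)

def fim_from_trick():
    # Per-class gradients of log_probs * probs**0.5 (probs detached)
    # are sqrt(p(y)) * grad log p(y); their outer products sum to the FIM.
    log_probs = torch.log_softmax(W @ x, dim=0)
    probs = log_probs.exp().detach()
    out = log_probs * probs**0.5
    F = torch.zeros(W.numel(), W.numel())
    for y in range(out.numel()):
        g = torch.autograd.grad(out[y], W, retain_graph=True)[0].reshape(-1)
        F += torch.outer(g, g)
    return F

def fim_direct():
    # F = sum_y p(y) * grad log p(y) grad log p(y)^T
    log_probs = torch.log_softmax(W @ x, dim=0)
    probs = log_probs.exp().detach()
    F = torch.zeros(W.numel(), W.numel())
    for y in range(log_probs.numel()):
        g = torch.autograd.grad(log_probs[y], W, retain_graph=True)[0].reshape(-1)
        F += probs[y] * torch.outer(g, g)
    return F

assert torch.allclose(fim_from_trick(), fim_direct(), atol=1e-5)
```

Detaching probs is essential: it makes the sqrt(p(y)) factor a constant weight, so autograd differentiates only the log p(y) term.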

For the other question, I am not sure how you would combine classification and regression in a single function.