n-gao / pytorch-kfac

PyTorch implementation of KFAC - a port of https://github.com/tensorflow/kfac/


Precondition w.r.t. one loss, gradient w.r.t. another

ConstantinPuiu opened this issue

Can your code be used to compute the KFAC (Fisher) matrix using a different loss than the loss we take the gradient of?

If so, how?

thanks

Hi, sorry for the late reply.
Yes, you may use one loss for KFAC and a different one for the gradients. The KFAC loss should be the negative log-likelihood of your model's output distribution, while the loss you backpropagate for the parameter update can be any objective.

In the training loop from the MNIST example, it could look like this:

import torch
import torch.nn as nn
import tqdm

# Classifier, train_loader and KFAC are defined in the repository's MNIST example.
model = Classifier().cuda()
optim = KFAC(model, 9e-3, 1e-3, momentum_type='regular', momentum=0.95, adapt_damping=True, update_cov_manually=True)
# Negative log-likelihood of the model's output distribution, used only for the Fisher estimate.
model_logprob = nn.CrossEntropyLoss(reduction='mean')
# The loss you actually want to optimize.
loss_fn = <your loss here>

kfac_losses = []
with tqdm.tqdm(train_loader) as progress:
    for inp, labels in progress:
        inp, labels = inp.cuda(), labels.cuda()
        model.zero_grad()
        # Estimate the Fisher: sample labels from the model's own predictive distribution
        with optim.track_forward():
            out = model(inp)
            out_samples = torch.multinomial(torch.softmax(out.detach(), 1), 1).reshape(out.shape[0])
            loss = model_logprob(out, out_samples)
        with optim.track_backward():
            loss.backward()
        optim.update_cov()
        # Compute the loss whose gradient is preconditioned and applied
        model.zero_grad()
        out = model(inp)
        loss = loss_fn(out, labels)
        loss.backward()
        optim.step(loss=loss)
        progress.set_postfix({
            'loss': loss.item(),
            'damping': optim.damping.item()
        })
        kfac_losses.append(loss.item())
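
The sampled labels (`out_samples`) are what make this an estimate of the Fisher rather than the empirical Fisher: the curvature is computed under the model's own predictive distribution, independently of the training targets. For concreteness, here is a minimal sketch of what `loss_fn` could be; the helper name `smoothed_ce` and the label-smoothing choice are purely illustrative and not part of this repository:

import torch.nn.functional as F

def smoothed_ce(out, labels):
    # Hypothetical training loss: label-smoothed cross entropy on the true labels.
    # KFAC still builds its curvature estimate from the sampled-label loss above.
    return F.cross_entropy(out, labels, label_smoothing=0.1)

loss_fn = smoothed_ce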