tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io


[Feature request] - multiclass semantic segmentation variant for FIM_MonteCarlo

godaup opened this issue

Suppose my model outputs a Batch x Class x Height x Width tensor for a multiclass image segmentation task (as in here), and I want to compute the Monte Carlo FIM. Shouldn't it be possible to interpret this as a (Height x Width)-fold classification problem and easily adapt the fim_function?

This is how I would go about it:

def fim_function(*d):
    # `function` and `trials` are supplied by the enclosing FIM_MonteCarlo scope
    # per-pixel class log-probabilities, shape (mb, c, h, w)
    log_softmax = torch.log_softmax(function(*d), dim=1)
    s_mb, s_c, s_h, s_w = log_softmax.size()
    # treat each pixel as an independent classification problem:
    # (mb, c, h, w) -> (mb * h * w, c)
    log_softmax = log_softmax.permute(0, 2, 3, 1).contiguous().view(s_mb * s_h * s_w, s_c)
    # detach so that no gradient flows through the sampling weights
    probabilities = torch.exp(log_softmax).detach()
    # draw `trials` labels per pixel from the model's own predictive distribution
    sampled_indices = torch.multinomial(probabilities, trials,
                                        replacement=True)
    sampled_targets = torch.gather(log_softmax, 1,
                                   sampled_indices)
    # sum the per-pixel log-likelihoods within each mini-batch element
    sampled_targets = sampled_targets.view(s_mb, s_h * s_w, trials).sum(dim=1)
    return trials ** -.5 * sampled_targets
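
As a quick standalone sanity check (just a sketch: `function` below is a stand-in returning random logits and `trials` is set by hand, neither comes from nngeometry), the returned tensor should contain one summed log-likelihood sample per (example, trial) pair:

import torch

trials = 5

def function(x):
    # stand-in for the segmentation model: random (batch, class, h, w) logits
    return torch.randn(x.size(0), 10, 28, 28)

out = fim_function(torch.zeros(4, 1, 28, 28))
assert out.shape == (4, trials)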

But I don't have a simple setup to test it here. If you can come up with the simplest model to test it, like a 2-layer MLP on MNIST, I will be able to actually test it and add it to NNGeometry.

Below you will find a minimal setup. Hope this fits, and thanks for the fast reply! :)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
from tqdm import tqdm
data_path = 'path/to/mnist/data'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model 
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(in_features=28*28, out_features=28*28*3)
        self.final = nn.Linear(in_features=28*28*3, out_features=28*28*10)
    
    def forward(self, x):
        x = x.reshape(-1, 28*28)
        x = F.relu(self.hidden(x))
        # no activation on the final layer: BCEWithLogitsLoss (and log_softmax
        # for the FIM) expect raw logits
        x = self.final(x)
        x = x.reshape(-1, 10, 28, 28)
        return x
    
model = MLP()
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# data
class MNIST_Seg(torch.utils.data.Dataset):
    def __init__(self, train=True):
        super(MNIST_Seg, self).__init__()
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(0.5, 0.5)
        ])
        self.mnist = torchvision.datasets.MNIST(root=data_path, train=train, transform=transforms, download=True)
        
    def __len__(self):
        return len(self.mnist)
    
    def __getitem__(self, item):
        image, label = self.mnist[item]
        # segmentation target: the digit's foreground pixels in channel `label`
        mask = torch.zeros(10, 28, 28)
        mask[label] = image.squeeze(0).gt(0).to(torch.float32)
        return image, mask
        
train_loader = torch.utils.data.DataLoader(MNIST_Seg(True), batch_size=20, shuffle=True)
val_loader = torch.utils.data.DataLoader(MNIST_Seg(False), batch_size=20, shuffle=False)

# training 
model.to(device)
model.train()
for ep in range(10):
    with tqdm(total=len(train_loader), desc=f'epoch {ep + 1}') as t:
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), masks)
            loss.backward()
            optimizer.step()
            t.set_postfix(loss=loss.item())
            t.update()
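
A quick shape check that the trained model produces the Batch x Class x Height x Width output that the fim_function above expects:

images, _ = next(iter(val_loader))
print(model(images.to(device)).shape)  # torch.Size([20, 10, 28, 28])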

It is now part of the FIM_MonteCarlo helper.
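
For reference, a minimal sketch of calling the helper on the setup above; the variant name 'segmentation_logits' and the exact FIM_MonteCarlo signature are assumptions here, so check the nngeometry docs for the current API:

from nngeometry.metrics import FIM_MonteCarlo
from nngeometry.object import PMatDiag

# Monte Carlo FIM of the segmentation model; 'segmentation_logits' is the
# assumed name of the variant added by this issue (see the docs)
F_diag = FIM_MonteCarlo(model=model,
                        loader=val_loader,
                        representation=PMatDiag,
                        variant='segmentation_logits',
                        trials=1,
                        device=device)
print(F_diag.trace())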

I am closing the issue.