tfjgeorge / nngeometry

{KFAC,EKFAC,Diagonal,Implicit} Fisher Matrices and finite width NTKs in PyTorch

Home Page: https://nngeometry.readthedocs.io

How to compute FIM with nn.DataParallel(model)?

ZIYU-DEEP opened this issue

Hey Thomas –

Thank you for creating this terrific package! I am wondering what the correct way is to compute the FIM when using multiple GPUs, with the network wrapped in nn.DataParallel().

Specifically, I encountered a KeyError when I tried to run on 3 GPUs with my network wrapped in nn.DataParallel(). Below is a simplified sample of my code:

import torch
import torch.nn as nn

from nngeometry.layercollection import LayerCollection
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC

device = torch.device('cuda')

# Create model instance
class MNISTLeNet(nn.Module):
    def __init__(self):
        super(MNISTLeNet, self).__init__()
        self.cnn_model = nn.Sequential(
            nn.Conv2d(1, 6, 5), nn.ReLU(), nn.AvgPool2d(2, stride=2),
            nn.Conv2d(6, 16, 5), nn.ReLU(), nn.AvgPool2d(2, stride=2))

        self.fc_model = nn.Sequential(
            nn.Linear(256, 120), nn.ReLU(),
            nn.Linear(120, 84), nn.ReLU(),
            nn.Linear(84, 10))

    def forward(self, x):
        x = self.cnn_model(x)
        x = x.view(x.size(0), -1)
        x = self.fc_model(x)
        return x

# Parallelize the model
model = MNISTLeNet()
model = torch.nn.DataParallel(model).to(device)

# Collect only the Linear and Conv2d layers
layer_collection = LayerCollection()
for layer in model.modules():
    if type(layer) in (nn.Linear, nn.Conv2d):
        layer_collection.add_layer_from_model(model, layer)

# Get the Fisher Information Matrix
F_kfac = FIM(layer_collection=layer_collection,
             model=model,
             loader=test_loader,
             representation=PMatKFAC,
             n_output=10,
             variant='classif_logits',
             device='cuda')

And I got the following error message:

  File "/home/-/github/project/helper/utils.py", line 358, in get_fisher
    F_kfac = FIM(layer_collection=layer_collection,
  File "/home/-/.local/lib/python3.8/site-packages/nngeometry/metrics.py", line 169, in FIM
    return representation(generator=generator, examples=loader)
  File "/home/-/.local/lib/python3.8/site-packages/nngeometry/object/pspace.py", line 436, in __init__
    self.data = generator.get_kfac_blocks(examples)
  File "/home/-/.local/lib/python3.8/site-packages/nngeometry/generator/jacobian/__init__.py", line 249, in get_kfac_blocks
    torch.autograd.grad(output[self.i_output], [inputs],
  File "/home/-/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 275, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/-/.local/lib/python3.8/site-packages/nngeometry/generator/jacobian/__init__.py", line 629, in <lambda>
    o.register_hook(lambda g_o: hook_gy(mod, g_o))
  File "/home/-/.local/lib/python3.8/site-packages/nngeometry/generator/jacobian/__init__.py", line 682, in _hook_compute_kfac_blocks
    layer_id = self.m_to_l[mod]
KeyError: Linear(in_features=84, out_features=10, bias=True)

The error is not raised when using only one GPU.

Would you have any idea how to solve this issue efficiently and compute the FIM on multiple GPUs? Thank you so much! : )

Ah, I just realized the cause of the KeyError: torch.nn.DataParallel(model) wraps an extra module around the model. A simple fix to avoid this error is to add one line, model = model.module, before creating the layer_collection and computing the FIM.
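For anyone hitting the same error, here is a minimal sketch of that workaround, reusing the sample above (single_model is just a name I introduce for the unwrapped network, and test_loader is my evaluation DataLoader; as far as I can tell the FIM is then computed on a single GPU):

# Wrap with DataParallel for training as before
model = MNISTLeNet()
model = torch.nn.DataParallel(model).to(device)

# Unwrap before building the LayerCollection and the FIM, so the modules
# NNGeometry registers are the same ones used during its forward passes
single_model = model.module

layer_collection = LayerCollection()
for layer in single_model.modules():
    if type(layer) in (nn.Linear, nn.Conv2d):
        layer_collection.add_layer_from_model(single_model, layer)

F_kfac = FIM(layer_collection=layer_collection,
             model=single_model,
             loader=test_loader,
             representation=PMatKFAC,
             n_output=10,
             variant='classif_logits',
             device='cuda')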

I wonder if there are any other potential concerns when using torch.nn.DataParallel(model) on multiple GPUs with NNGeometry. Thank you!

I am not sure. If you found out, please share the answer :-)