yhhhli / BRECQ

Pytorch implementation of BRECQ, ICLR 2021

RuntimeError: `Trying to backward through the graph a second time` when setting opt_mode to fisher_diag

a1trl9 opened this issue

Hi Yuhang,

Thank you for open sourcing this project.

The paper notes that the diagonal Fisher information matrix is used in place of the pre-activation Hessian, so we tried setting opt_mode to fisher_diag instead of mse for reconstruction. However, a runtime error is thrown:

  File "xxxx/quant/data_utils.py", line 184, in __call__
    loss.backward()
  File "xxxx/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "xxxx/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

It seems to occur during the backward pass that saves the gradients:

handle = self.layer.register_backward_hook(self.data_saver)
with torch.enable_grad():
    try:
        self.model.zero_grad()
        inputs = model_input.to(self.device)
        self.model.set_quant_state(False, False)
        out_fp = self.model(inputs)
        quantize_model_till(self.model, self.layer, self.act_quant)
        out_q = self.model(inputs)
        loss = F.kl_div(F.log_softmax(out_q, dim=1), F.softmax(out_fp, dim=1), reduction='batchmean')
        # The RuntimeError is raised by this call.
        loss.backward()
    except StopForwardException:
        pass

As the error indicates, the first backward call succeeds but the second fails.
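For reference, here is a minimal sketch of one general way this error can appear even when backward is called only once per batch: a tensor computed during the first batch is cached while still attached to that batch's graph, and a later batch backpropagates through it. The CachedScale helper and run_two_batches function are hypothetical and are not BRECQ code, and we have not confirmed that this is what happens in save_grad_data.

```python
import torch


class CachedScale:
    """Hypothetical helper: computes a scale on the first call and reuses it."""

    def __init__(self, detach):
        self.detach = detach
        self.scale = None

    def __call__(self, x):
        if self.scale is None:
            scale = x.abs().max()
            # Without detach(), the cached scale keeps a reference to the
            # autograd graph of the first batch.
            self.scale = scale.detach() if self.detach else scale
        return x / self.scale


def run_two_batches(detach):
    """Run two batches with one backward call each; collect any RuntimeError."""
    torch.manual_seed(0)
    w = torch.randn(4, requires_grad=True)
    scaler = CachedScale(detach)
    failures = []
    for step in range(2):
        w.grad = None
        x = torch.randn(4)
        loss = scaler(x * w).pow(2).sum()
        try:
            loss.backward()  # called exactly once per batch
        except RuntimeError as exc:
            failures.append((step, str(exc)))
    return failures


if __name__ == "__main__":
    print(run_two_batches(detach=False))  # the second batch fails
    print(run_two_batches(detach=True))   # no failures
```

In this sketch the first batch passes and the second fails, because the second batch's graph reaches back into the already freed graph of the first batch through the cached scale. Detaching the cached value removes that link.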

We also built a very simple network to reproduce the problem, and the error still appears:

import torch.nn as nn
import torch.nn.functional as F

class DummyNet(nn.Module):
    def __init__(self):
        super(DummyNet, self).__init__()
        # Three 3x3 convolutions, each with stride 3
        self.conv1 = nn.Conv2d(3, 32, 3, 3)
        self.conv2 = nn.Conv2d(32, 32, 3, 3)
        self.conv3 = nn.Conv2d(32, 1, 3, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        output = F.log_softmax(x, dim=0)
        return output

The recon_model function is the same as the one in the main_imagenet file:

def recon_model(model: nn.Module):
    """
    Block reconstruction. For the first and last layers, we can only apply layer reconstruction.
    """
    for name, module in model.named_children():
        if isinstance(module, QuantModule):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of layer {}'.format(name))
                continue
            else:
                layer_reconstruction(qnn, module, **kwargs)
        elif isinstance(module, BaseQuantBlock):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of block {}'.format(name))
                continue
            else:
                print('Reconstruction for block {}'.format(name))
                block_reconstruction(qnn, module, **kwargs)
        else:
            recon_model(module)

We are not sure why PyTorch complains here, since backward is called only once per batch. We also noticed that after save_grad_data is called, the gradient is cached for the later loss calculation:

# in block_reconstruction
err = loss_func(out_quant, cur_out, cur_grad)

Is the intermediate gradient still available at this point, given that backward has already been called? In our case, even if we work around the first error inside save_grad_data, we get the same error here (i.e. backward called twice).
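To make the question concrete, here is a small standalone sketch of the pattern as we understand it: capture the gradient at an intermediate output with a tensor hook, store it detached, and later use it as a constant weight in a squared-gradient-weighted reconstruction loss. The two Linear layers, the perturbed input standing in for a quantized forward pass, and the exact loss form are our assumptions for illustration and are not taken from the BRECQ code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(8, 4)   # stand-in for the layer being reconstructed
head = nn.Linear(4, 3)    # stand-in for the rest of the network
x = torch.randn(5, 8)
target = torch.randint(0, 3, (5,))

# Pass 1: capture d(loss)/d(z) at the layer output with a tensor hook.
saved = {}
z = layer(x)
# The hook returns None, so the gradient itself is left unchanged.
z.register_hook(lambda g: saved.update(grad=g.detach().clone()))
task_loss = F.cross_entropy(head(z), target)
task_loss.backward()

cur_grad = saved["grad"]  # plain data, no autograd history
cur_out = z.detach()      # cached full-precision output, also detached

# Pass 2: a fresh forward pass builds a new graph; the cached tensors
# enter the loss only as constants.
layer.zero_grad()
out_quant = layer(x + 0.01 * torch.randn_like(x))  # assumed stand-in for quantized output
err = ((out_quant - cur_out).pow(2) * cur_grad.pow(2)).sum(1).mean()
err.backward()            # does not touch the graph from pass 1
```

In this sketch the second backward call succeeds because the cached gradient and output are detached, so nothing in the second loss leads back into the first graph. If either cached tensor were stored without detach(), the second call would be expected to fail with the same error.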

Environment

Ubuntu 16.04 / Python 3.6.8 / PyTorch 1.7.1 / CUDA 10.1

Any advice would be appreciated.