chr5tphr / zennit

Zennit is a high-level Python framework, built on PyTorch, for explaining and exploring neural networks with attribution methods like LRP.

Switch over to Module.register_full_backward_hook

chr5tphr opened this issue

The most recent version of PyTorch has introduced a proper module backward hook behavior via Module.register_full_backward_hook.
The old backward hook behavior will be deprecated in favor of this new one.
At some point, we should use the PyTorch full backward hook instead of our work-around of emulating one.
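
For reference, a minimal sketch of what the new API looks like (the layer and hook names here are purely illustrative):

import torch

# the full backward hook receives the gradients w.r.t. all module inputs
# and outputs, which the deprecated register_backward_hook only did
# reliably for simple modules
linear = torch.nn.Linear(4, 2)

def backward_hook(module, grad_input, grad_output):
    # grad_input and grad_output are tuples of gradients;
    # returning None leaves them unchanged
    print('gradient w.r.t. the module output:', grad_output[0].shape)

handle = linear.register_full_backward_hook(backward_hook)
linear(torch.randn(3, 4, requires_grad=True)).sum().backward()
handle.remove()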

Right now, I am using the hook's attribute
stored_tensors['grad_output']
to get the intermediate attributions.

Do you think it is possible to retain a similar behavior in the new version?
Or is it completely independent of the PyTorch bug?
Thanks

It might not work, but I'm probably not going to do the switch very soon.

But the better approach is actually to do something like the following:

# model, data, composite, n_outputs, device and target are assumed to be defined
import torch

def hook(module, input, output):
    # keep a reference to the module's output and retain its gradient,
    # which will hold the intermediate attribution after the backward pass
    module.output = output
    output.retain_grad()

with composite.context(model):
    # register the forward hooks only after entering the context, i.e. after
    # the composite's own hooks have been registered
    handles = [module.register_forward_hook(hook) for module in model.modules()]
    output = model(data)
    # backward with a one-hot gradient selecting the target class
    torch.autograd.backward((output,), (torch.eye(n_outputs, device=device)[target],))
    for handle in handles:
        handle.remove()

for name, module in model.named_modules():
    print(f'{name}: {module.output.grad}')

where you must register the hooks after the composite's hooks have been registered, i.e. inside the context.
The intermediate attributions can then be accessed as module.output.grad (retain_grad is needed because non-leaf tensors do not keep their .grad by default).

Hi,

Nice, thanks! This is much better.

Best

Just as a clarification: the update now uses neither register_backward_hook nor register_full_backward_hook, but instead registers the backward hooks explicitly on the grad_fn of the input and output tensors. This way, the change is minimal, the misleading deprecation warning is not triggered, and the hook's stored_tensors['grad_output'] is used exactly as before.
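
For illustration, a minimal sketch of this kind of hook, not Zennit's actual implementation: in recent PyTorch versions, a hook can be attached directly to a tensor's grad_fn, i.e. the autograd graph node that produced the tensor.

import torch

# illustrative only: attach a hook to a tensor's grad_fn
linear = torch.nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)
out = linear(x)

def node_hook(grad_inputs, grad_outputs):
    # grad_outputs holds the gradients flowing into this node from above;
    # returning None leaves the computed grad_inputs unchanged
    print('gradient w.r.t. the layer output:', grad_outputs[0].shape)

out.grad_fn.register_hook(node_hook)
out.sum().backward()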