NVlabs / wetectron

Weakly-supervised object detection.

Sequential backprop impl sketch

vadimkantorov opened this issue

Should something like the code below work for wrapping ResNet's last layer (the neck)? (https://gist.github.com/vadimkantorov/67fe785ed0bf31727af29a3584b87be1)

import torch
import torch.nn as nn

class SequentialBackprop(nn.Module):
    """Wraps a module so that its backward pass runs in mini-chunks, bounding peak memory."""
    def __init__(self, module, batch_size = 1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        # Run the full-batch forward without recording a graph;
        # the graph is rebuilt chunk by chunk in the custom backward below
        with torch.no_grad():
            y = self.module(x)
        return self.Function.apply(x, y, self.batch_size, self.module)

    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, batch_size, module):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            # Recompute the forward one chunk at a time, so only one chunk's
            # activations are alive at any point during the backward pass
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                with torch.enable_grad():
                    # A fresh leaf per chunk, so x_mini.grad is populated directly
                    x_mini = x_mini.detach().requires_grad_()
                    y_mini = ctx.module(x_mini)
                # Accumulates the module's parameter gradients chunk by chunk
                torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)
            # Gradient only for the tensor input x; None for y, batch_size, module
            return torch.cat(grads), None, None, None

if __name__ == '__main__':
    backbone = nn.Linear(3, 6)
    neck = nn.Linear(6, 12)
    head = nn.Linear(12, 1)

    # The neck's forward is recomputed in chunks of 16 rows during backprop
    model = nn.Sequential(backbone, SequentialBackprop(neck, batch_size = 16), head)

    print('before', neck.weight.grad)  # None: no backward pass has run yet

    x = torch.rand(512, 3)
    model(x).sum().backward()
    print('after', neck.weight.grad)  # accumulated chunk by chunk by the custom backward
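
A quick sanity check (a minimal sketch, assuming the SequentialBackprop class above is in scope): gradients through the wrapper should match plain backprop through an identical copy of the module.

import copy
import torch
import torch.nn as nn

neck_plain = nn.Linear(6, 12)
neck_chunked = copy.deepcopy(neck_plain)  # identical weights in both copies

x_plain = torch.rand(64, 6, requires_grad = True)
x_chunked = x_plain.detach().clone().requires_grad_()

neck_plain(x_plain).sum().backward()
SequentialBackprop(neck_chunked, batch_size = 8)(x_chunked).sum().backward()

# Both the input gradients and the accumulated weight gradients should agree
print(torch.allclose(x_plain.grad, x_chunked.grad))
print(torch.allclose(neck_plain.weight.grad, neck_chunked.weight.grad))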

Hello vadimkantorov! I've been trying to implement this module recently and was wondering whether your SBP code works as-is, or whether it needs further modification? I would be grateful for any help!
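
For reference, the wrapper trades extra compute for memory: the wrapped module's forward is recomputed once per chunk during backward, similar in spirit to gradient checkpointing. A rough way to observe the saving (a minimal sketch, assuming a CUDA device and the SequentialBackprop class from the first post; a deeper neck makes the effect visible):

import torch
import torch.nn as nn

device = 'cuda'
neck = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).to(device)
# The input must require grad so that the custom Function's backward actually fires
x = torch.rand(8192, 1024, device = device, requires_grad = True)

for label, module in [('plain', neck), ('chunked', SequentialBackprop(neck, batch_size = 256))]:
    torch.cuda.reset_peak_memory_stats(device)
    module(x).sum().backward()
    print(label, 'peak bytes:', torch.cuda.max_memory_allocated(device))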