Sequential backprop impl sketch
vadimkantorov opened this issue · comments
Vadim Kantorov commented
Should something like the code below work for wrapping ResNet's last layer (the neck)? (https://gist.github.com/vadimkantorov/67fe785ed0bf31727af29a3584b87be1)
import torch
import torch.nn as nn
class SequentialBackprop(nn.Module):
    """Back-propagate through ``module`` in chunks of ``batch_size`` samples.

    The forward pass runs ``module`` on the whole batch under
    ``torch.no_grad()``, so none of its intermediate activations are kept.
    The backward pass re-runs ``module`` on consecutive chunks of the saved
    input (split along dim 0) and back-propagates each chunk on its own:
    input gradients are concatenated, parameter gradients are summed and
    returned to autograd.

    NOTE: the result matches ordinary backprop only if ``module`` treats the
    samples along dim 0 independently and deterministically. Anything whose
    output depends on the rest of the batch or on random state (e.g.
    BatchNorm in training mode, Dropout) behaves differently when recomputed
    chunk by chunk.

    Args:
        module: the wrapped module. Its output must have the same size as its
            input along dim 0, because the input and the upstream gradient
            are split into matching chunks.
        batch_size: number of samples per chunk in the backward pass.
    """

    def __init__(self, module, batch_size=1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        """Return ``module(x)`` without keeping ``module``'s autograd graph."""
        # Only the output is kept; activations are recomputed per chunk in
        # backward, which is the whole point of this wrapper.
        with torch.no_grad():
            y = self.module(x)
        # Trainable parameters are passed as explicit inputs so autograd knows
        # the output depends on them: backward then runs (and returns their
        # gradients) even when ``x`` itself does not require grad.
        params = tuple(p for p in self.module.parameters() if p.requires_grad)
        return self.Function.apply(x, y, self.batch_size, self.module, *params)

    class Function(torch.autograd.Function):
        """Return the precomputed ``y``; recompute ``module`` in chunks on backward."""

        @staticmethod
        def forward(ctx, x, y, batch_size, module, *params):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            # Stored as-is (not via save_for_backward) so backward
            # differentiates w.r.t. the very tensors ``module`` uses.
            ctx.params = params
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            params = ctx.params
            need_x = ctx.needs_input_grad[0]
            # Inputs are (x, y, batch_size, module, *params); y and the two
            # non-tensor arguments never receive a gradient.
            if not need_x and not params:
                return (None, None, None, None)
            x_grads = []
            param_grads = [None] * len(params)
            chunks = zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size))
            for x_chunk, g_chunk in chunks:
                # Backward runs with grad mode off; re-enable it to rebuild
                # the graph for this chunk only.
                with torch.enable_grad():
                    x_chunk = x_chunk.detach().requires_grad_(need_x)
                    y_chunk = ctx.module(x_chunk)
                    wrt = ((x_chunk,) if need_x else ()) + params
                    grads = torch.autograd.grad(
                        y_chunk, wrt, grad_outputs=g_chunk, allow_unused=True
                    )
                if need_x:
                    g_x, grads = grads[0], grads[1:]
                    x_grads.append(torch.zeros_like(x_chunk) if g_x is None else g_x)
                # Sum the per-chunk parameter gradients.
                for i, g in enumerate(grads):
                    if g is not None:
                        param_grads[i] = g if param_grads[i] is None else param_grads[i] + g
            grad_x = torch.cat(x_grads) if need_x else None
            return (grad_x, None, None, None, *param_grads)
if __name__ == '__main__':
    # Demo: an MLP (3 -> 6 -> 12 -> 1) whose middle "neck" layer is
    # back-propagated in chunks of 16 samples. The neck's weight gradient is
    # printed before and after one backward pass over 512 random samples.
    backbone, neck, head = nn.Linear(3, 6), nn.Linear(6, 12), nn.Linear(12, 1)
    chunked_neck = SequentialBackprop(neck, batch_size=16)
    model = nn.Sequential(backbone, chunked_neck, head)

    print('before', neck.weight.grad)
    inputs = torch.rand(512, 3)
    loss = model(inputs).sum()
    loss.backward()
    print('after', neck.weight.grad)
ck6698000 commented
Hello vadimkantorov! I've recently been trying to implement this module and am wondering whether your SBP code works as is, or whether it needs further modification.
I would be grateful for any help!