Add custom module to object detection pipeline and train jointly
ashnair1 opened this issue · comments
Ashwin Nair commented
I want to add a non-local module to the Mask R-CNN pipeline. I've already written it as a standalone module, but I'm not sure how to plug it into the Detectron codebase so that the module and the model are trained jointly. Could someone point me in the right direction? I only need to train four additional convolutional layers (theta, phi, g and conv), but I'm unclear on how to do it.
Here's the non-local module:
class NonLocalBlock(nn.Module):
    """Non-local (self-attention) block with a residual connection.

    Every spatial position of the input attends to the (2x2 max-pooled)
    positions of the *same* sample. Affinities are dot products of the
    ``theta`` and ``phi`` embeddings, normalised with a softmax over the
    keys. They are used to average the ``g`` embeddings. The result goes
    through a 1x1 conv + ReLU and is added back to the input, so the output
    has the same shape as the input.
    """

    def __init__(self, X):
        """Build the block's four 1x1 convolutions.

        Args:
            X: Either an example feature map of shape [N, C, H, W] (only its
                channel dimension ``X.shape[1]`` is used), or the channel
                count ``C`` itself as an int.
        """
        super(NonLocalBlock, self).__init__()
        # Accept an int for convenience; a tensor keeps the original interface.
        channels = X if isinstance(X, int) else X.shape[1]
        # Output projection applied to the attended features.
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))
        # Embeddings: theta (queries), phi (keys), g (values).
        self.theta = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))
        self.phi = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))
        self.phi_pool = nn.MaxPool2d(2, 2)
        self.g = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))
        self.g_pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Apply the non-local operation to ``x`` and add it back residually.

        Args:
            x: Feature map of shape [N, C, H, W], where C is the channel count
                the block was built with.

        Returns:
            Tensor with the same shape as ``x``.
        """
        n, c, h, w = x.shape

        theta = F.relu(self.theta(x))  # [N, C, H, W]
        phi = F.relu(self.phi(x))  # [N, C, H, W]
        g = F.relu(self.g(x))  # [N, C, H, W]
        # 2x2 pooling of keys/values shrinks the attention matrix 4x, but
        # MaxPool2d(2, 2) fails on maps smaller than 2x2; those attend to the
        # unpooled keys instead of crashing.
        if h >= 2 and w >= 2:
            phi = self.phi_pool(phi)  # [N, C, H/2, W/2]
            g = self.g_pool(g)  # [N, C, H/2, W/2]

        # Flatten spatial dims per sample: attention never mixes batch items,
        # and key positions are ordered identically in phi and g.
        theta = theta.flatten(2).transpose(1, 2)  # [N, HW, C]
        phi = phi.flatten(2)  # [N, C, K], K = number of key positions
        g = g.flatten(2).transpose(1, 2)  # [N, K, C]

        # Softmax over the keys so each query's weights sum to 1.
        attn = F.softmax(torch.bmm(theta, phi), dim=-1)  # [N, HW, K]
        y = torch.bmm(attn, g)  # [N, HW, C]
        y = y.transpose(1, 2).reshape(n, c, h, w)  # [N, C, H, W]

        out = F.relu(self.conv(y))
        return out + x
# Build the block for an existing feature map `layer` (defined elsewhere) and
# put it on the same device as that tensor, instead of assuming 'cuda'.
nn_local = NonLocalBlock(layer).to(layer.device)
# Call the module itself rather than .forward() so nn.Module hooks still run.
refined_layer = nn_local(layer)