ltrottier / deep-collaboration-network

Deep Collaboration Network in PyTorch


The cross-stitch implementation

ToTheBeginning opened this issue.

Hi, can you share the code for Cross-Stitch (XS)?

Sure

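```python
import torch
import torch.nn as nn


class CrossStitchBlock(nn.Module):
    """Cross-stitch unit that mixes the feature maps of K task networks.

    Output i is a learned linear combination of all K inputs:
    x_out[i] = sum_j alpha[i, j] * x[j], with one scalar alpha per (i, j).
    """

    def __init__(self, K):
        super(CrossStitchBlock, self).__init__()
        self.K = K
        # K * K scalar mixing weights, stored row-major:
        # self.alphas[K * i + j] holds alpha[i, j].
        # Shape (1, 1, 1, 1) lets each weight broadcast over (N, C, H, W).
        alphas = []
        for i in range(self.K):
            for j in range(self.K):
                v = nn.Parameter(torch.randn(1, 1, 1, 1), requires_grad=True)
                alphas.append(v)
        self.alphas = nn.ParameterList(alphas)

    def forward(self, x):
        # x: sequence of K tensors with identical shapes, one per task network.
        x_out = []
        for i in range(self.K):
            z = 0
            for j in range(self.K):
                z = z + x[j] * self.alphas[self.K * i + j]
            x_out.append(z)

        return x_out
```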
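The same mixing can also be sketched in vectorized form, assuming all K inputs share one shape so they can be stacked. `CrossStitchMatrix` and its `self_weight` argument are hypothetical names introduced here; the optional near-identity initialization is an assumption as well (`CrossStitchBlock` draws every weight from `torch.randn`). The K x K weights live in a single `nn.Parameter`, and `torch.einsum` computes every output in one call.

```python
import torch
import torch.nn as nn


class CrossStitchMatrix(nn.Module):
    """Hypothetical vectorized variant of CrossStitchBlock.

    Holds the K x K mixing weights in one matrix parameter, so that
    output i = sum_j alpha[i, j] * x[j], computed with a single einsum.
    """

    def __init__(self, K, self_weight=None):
        super().__init__()
        self.K = K
        if self_weight is None:
            # Same random-normal initialization as CrossStitchBlock.
            init = torch.randn(K, K)
        else:
            # Assumed alternative: each output keeps `self_weight` of its own
            # input and splits the remainder evenly across the other inputs.
            off = (1.0 - self_weight) / max(K - 1, 1)
            init = torch.full((K, K), off)
            init.fill_diagonal_(self_weight)
        self.alpha = nn.Parameter(init)

    def forward(self, x):
        # x: sequence of K tensors with identical shape, e.g. (N, C, H, W).
        stacked = torch.stack(list(x), dim=0)  # (K, N, C, H, W)
        mixed = torch.einsum("ij,j...->i...", self.alpha, stacked)
        return list(mixed.unbind(dim=0))  # back to K separate tensors


# Usage: two task networks exchanging feature maps of shape (4, 8, 16, 16).
torch.manual_seed(0)
block = CrossStitchMatrix(K=2, self_weight=0.9)
a = torch.randn(4, 8, 16, 16)
b = torch.randn(4, 8, 16, 16)
out_a, out_b = block([a, b])  # out_a = 0.9 * a + 0.1 * b at initialization
```

Keeping the weights in one matrix makes it easy to inspect or regularize the learned mixing pattern as a whole. The trade-off is that `torch.stack` requires identical input shapes, whereas the per-scalar loop only needs the shapes to be broadcast-compatible.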

Thanks.