The cross-stitch implementation
ToTheBeginning opened this issue · comments
Hi, could you share the code for Cross-Stitch (XS)?
Sure, here it is:
class CrossStitchBlock(nn.Module):
    """Cross-stitch unit: learnable linear mixing of K task activations.

    Given K input tensors ``x[0..K-1]``, output ``i`` is
    ``sum_j alpha[i][j] * x[j]``, where every ``alpha[i][j]`` is a separate
    learnable scalar.
    """

    def __init__(self, K):
        """Create the K x K mixing weights.

        Args:
            K: number of task branches (inputs and outputs).
        """
        super().__init__()
        self.K = K
        # One learnable scalar per (output i, input j) pair, stored row-major
        # so that alphas[K * i + j] weighs input j in output i. The
        # (1, 1, 1, 1) storage shape is kept so existing state_dicts load.
        # NOTE(review): weights are drawn from N(0, 1), so they may be
        # negative and are not normalised; consider a near-identity
        # initialisation if training is unstable.
        self.alphas = nn.ParameterList(
            nn.Parameter(torch.randn(1, 1, 1, 1)) for _ in range(K * K)
        )

    def forward(self, x):
        """Mix the K inputs.

        Args:
            x: sequence of exactly K tensors with mutually broadcastable
                shapes (one per task branch).

        Returns:
            list of K tensors; element i is sum_j alphas[K*i + j] * x[j].

        Raises:
            ValueError: if ``len(x) != self.K``.
        """
        if len(x) != self.K:
            raise ValueError(f"expected {self.K} inputs, got {len(x)}")
        # View each stored (1, 1, 1, 1) weight as a 0-d scalar. Without this,
        # lower-rank inputs such as (N, D) would be broadcast up to rank 4,
        # e.g. to (1, 1, N, D). Gradients still flow to the parameter.
        weights = [alpha.reshape(()) for alpha in self.alphas]
        return [
            sum(w * xj for w, xj in zip(weights[self.K * i : self.K * (i + 1)], x))
            for i in range(self.K)
        ]
Thanks.