Gram matrix calculation
AlexTS1980 opened this issue · comments
I have a strange problem with the Gram matrix and style loss. I define the Gram matrix for the style and generated images as:
```python
import torch

def gram(tensor):
    # tensor: feature maps of shape (B, C, H, W); here B = 1 and C is the number of feature maps
    B, C, H, W = tensor.shape
    x = tensor.view(B * C, H * W)
    # normalized Gram matrix of shape (B*C, B*C)
    return torch.mm(x, x.t()) / (2 * B * C * H * W)
```
With this, the style loss either converges very quickly and the output is essentially white noise, or the style loss is very small to begin with and the result is the same. Only when I remove C from the denominator does the style loss keep changing throughout the stylization training process. What did I do wrong? I use style weight = 10000, content weight = 5, and Adam learning rate = 1. I tried slightly different parameters, but the problem remains.
Generally, the resulting Gram matrix (the feature matrix times its own transpose) should be of shape C by C; your `tensor.view(B*C, H*W)` gives `(B*C, B*C)` instead. Nonetheless, since B = 1 here, you can ignore that issue.
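As a quick shape check (a minimal sketch using the `gram` from the question; the layer size 64×32×32 is just an illustrative choice):

```python
import torch

def gram(tensor):
    # Gram matrix with the 1/(2*B*C*H*W) normalization from the question
    B, C, H, W = tensor.shape
    x = tensor.view(B * C, H * W)
    return torch.mm(x, x.t()) / (2 * B * C * H * W)

features = torch.randn(1, 64, 32, 32)  # B = 1, C = 64 feature maps
G = gram(features)
print(G.shape)  # (B*C, B*C), i.e. C x C when B = 1
```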
On to the problem: with that normalization, I think you'll need much bigger style weights (~100 million).
But then again, the easiest solution would be to drop the `1/(2*B*C*H*W)` factor and use smaller values of STYLE_WEIGHT.
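A minimal sketch of that suggestion (STYLE_WEIGHT comes from the question; the value 1e-4, the MSE style loss, and the random features are illustrative assumptions, not the questioner's actual setup):

```python
import torch
import torch.nn.functional as F

def gram_unnormalized(tensor):
    # Same Gram matrix, but without the 1/(2*B*C*H*W) factor;
    # that scale is absorbed into STYLE_WEIGHT instead.
    B, C, H, W = tensor.shape
    x = tensor.view(B * C, H * W)
    return torch.mm(x, x.t())

STYLE_WEIGHT = 1e-4  # assumed: much smaller than with the normalized Gram

# style loss at one layer: MSE between Gram matrices of generated and style features
gen_features = torch.randn(1, 64, 32, 32)
style_features = torch.randn(1, 64, 32, 32)
style_loss = STYLE_WEIGHT * F.mse_loss(gram_unnormalized(gen_features),
                                       gram_unnormalized(style_features))
```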
No, it turned out I had loaded the weights incorrectly.