How to calculate SSIM between channels in a tensor?
freesunshine opened this issue · comments
The output of our CNN is a non-negative tensor D of shape [B, 4, H, W], where B is the batch size. For every sample, the output is a [4, H, W] tensor Di. We want to minimize the structural similarity between the channels of Di, so we define a custom loss function using SSIM: we compute the SSIM of each channel against each of the others and take the sum as the final loss.
At first, we did not consider the differences in value distribution between channels, and the code was:
criterionSSIM = ssim.SSIM(data_range=1, channel=4)  # Construct the SSIM criterion
T1 = D.clone().detach()
l1 = T1[:, 0, :, :]
l2 = T1[:, 1, :, :]
l3 = T1[:, 2, :, :]
l4 = T1[:, 3, :, :]
tmp1 = torch.stack([l2, l3, l4, l1], 1)
loss1 = criterionSSIM(fusion_out, tmp1)
tmp2 = torch.stack([l3, l4, l1, l2], 1)
loss2 = criterionSSIM(fusion_out, tmp2)
tmp3 = torch.stack([l4, l1, l2, l3], 1)
loss3 = criterionSSIM(fusion_out, tmp3)
lossSSIM = (loss1+loss2+loss3)
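As an aside, the three stacked tensors above are cyclic shifts of T1 along the channel dimension, so they could also be built with `torch.roll`. The following is only a sketch under the assumption that D has shape [B, 4, H, W]; the random tensor is a stand-in for the real network output, and the `tmp1_stacked` name is introduced here purely for comparison.

```python
import torch

# Stand-in for the network output D of shape [B, 4, H, W] (random data, illustration only).
D = torch.rand(2, 4, 8, 8)
T1 = D.clone().detach()

# The stacking used in the post, kept here to compare against.
l1, l2, l3, l4 = T1.unbind(dim=1)
tmp1_stacked = torch.stack([l2, l3, l4, l1], 1)

# Cyclic shifts along the channel dimension give the same tensors.
tmp1 = torch.roll(T1, shifts=-1, dims=1)  # channels [l2, l3, l4, l1]
tmp2 = torch.roll(T1, shifts=-2, dims=1)  # channels [l3, l4, l1, l2]
tmp3 = torch.roll(T1, shifts=-3, dims=1)  # channels [l4, l1, l2, l3]
```

This does not change the loss; it only removes the manual slicing and stacking.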
However, the SSIM loss quickly dropped below zero. To avoid negative SSIM values, we normalized every channel of Di to [0, 1], and the code became:
criterionSSIM = ssim.SSIM(data_range=1, channel=4)  # Construct the SSIM criterion
B, C, H, W = D.shape
for b in range(0, B):
    for c in range(0, C):
        D[b][c] = D[b][c] / torch.max(D[b][c])  # normalize every channel to [0, 1]
T1 = D.clone().detach()
l1 = T1[:, 0, :, :]
l2 = T1[:, 1, :, :]
l3 = T1[:, 2, :, :]
l4 = T1[:, 3, :, :]
tmp1 = torch.stack([l2, l3, l4, l1], 1)
loss1 = criterionSSIM(fusion_out, tmp1)
tmp2 = torch.stack([l3, l4, l1, l2], 1)
loss2 = criterionSSIM(fusion_out, tmp2)
tmp3 = torch.stack([l4, l1, l2, l3], 1)
loss3 = criterionSSIM(fusion_out, tmp3)
lossSSIM = (loss1+loss2+loss3)
PyTorch then raises:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [224, 224]], which is output 0 of SelectBackward, is at version 128; expected version 127 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
We think this error is caused by the normalization step:
for b in range(0, B):
    for c in range(0, C):
        D[b][c] = D[b][c] / torch.max(D[b][c])  # normalize every channel to [0, 1]
But as beginners, we don’t know how to fix it. I checked out #6934 but got no clue. If anybody here can help us, it would be very much appreciated.
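One possible way around the in-place write, offered as a sketch rather than a fix confirmed in this thread: compute the per-channel maxima for the whole batch at once and divide out of place, so that D itself is never modified. The `normalize_channels` helper and its `eps` guard against an all-zero channel are assumptions added for illustration, and the random tensor stands in for the real network output.

```python
import torch

def normalize_channels(D: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale each [H, W] channel of a non-negative [B, C, H, W] tensor to roughly [0, 1]."""
    # Per-sample, per-channel maximum, kept as shape [B, C, 1, 1] for broadcasting.
    channel_max = D.amax(dim=(2, 3), keepdim=True)
    # Out-of-place division returns a new tensor and leaves D untouched,
    # so values saved by autograd are not overwritten.
    # eps is a hypothetical guard against division by zero, not part of the original code.
    return D / (channel_max + eps)

# Stand-in for a network output: a non-leaf tensor that requires gradients.
raw = torch.rand(2, 4, 8, 8, requires_grad=True)
D = raw * 3.0
D_norm = normalize_channels(D)

# Backward runs without the in-place modification error.
D_norm.sum().backward()
```

The result would then be used in place of D when building T1 and the shifted copies. Whether the normalized loss behaves as intended is a separate question from the autograd error.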
Hi @freesunshine, are you trying to make the difference between channels as large as possible?
I'm not sure about that, but I think it is somewhat difficult to use SSIM to maximize the difference between channels, because even random noise will lead to a very small SSIM response.
You are right. Other criteria would be needed to prevent the network from becoming a noise generator.
However, it is just an idea, and I'm not sure whether it will work. We will see.
Thank you for your SSIM approach and your suggestion. : )