zhilin007 / FFA-Net

FFA-Net: Feature Fusion Attention Network for Single Image Dehazing

Incompatible tensor shape in FFA.py?

charles92 opened this issue.

Hi,

Thanks for open-sourcing your work. I was trying to write a TensorFlow implementation of your model, and ran into the following issue.

In FFA.py, you have the code below at L96:

w=w.view(-1,self.gps,self.dim)[:,:,:,None,None]
out=w[:,0,::]*res1+w[:,1,::]*res2+w[:,2,::]*res3
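One way to check the shape questions below without a PyTorch install is to trace the shape changes in these two lines by hand. The following is a minimal, shape-only sketch under stated assumptions. The helper names `view_shape`, `append_singletons` and `select_shape` are hypothetical and come from neither FFA.py nor PyTorch. The incoming w is assumed to hold gps*dim values per sample (for example shape [B, gps*dim, 1, 1]), and the sizes B=2, gps=3, dim=64 are arbitrary.

```python
from math import prod

# Shape-only sketch (no tensor data) of the two FFA.py lines above.
# Helper names are hypothetical; they mimic PyTorch semantics only for
# the specific operations used in those lines.

def view_shape(shape, *new_shape):
    """Shape after tensor.view(*new_shape); at most one -1 is inferred."""
    total = prod(shape)
    known = prod(s for s in new_shape if s != -1)
    result = tuple(total // known if s == -1 else s for s in new_shape)
    if new_shape.count(-1) > 1 or prod(result) != total:
        raise ValueError(f"cannot view {shape} as {new_shape}")
    return result

def append_singletons(shape, n):
    """Shape after w[:, ..., :, None, ..., None]: one full slice per existing
    dimension, then n `None`s, each inserting a size-1 dimension."""
    return tuple(shape) + (1,) * n

def select_shape(shape, dim, index):
    """Shape after integer indexing along `dim`; that dimension is dropped.
    A trailing `::` (same as `:`) and omitted dimensions keep their sizes."""
    if not 0 <= index < shape[dim]:
        raise IndexError(f"index {index} out of range for size {shape[dim]}")
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])

B, gps, dim = 2, 3, 64              # example sizes (arbitrary)
w_in = (B, gps * dim, 1, 1)         # assumed incoming shape of w

w = view_shape(w_in, -1, gps, dim)  # w.view(-1, gps, dim)
w = append_singletons(w, 2)         # [:, :, :, None, None]
w0 = select_shape(w, 1, 0)          # w[:, 0, ::]
print(w, w0)  # (2, 3, 64, 1, 1) (2, 64, 1, 1)
```

Under these assumptions the trace gives [B, gps, dim, 1, 1] after the first line and [B, dim, 1, 1] for w[:, 0, ::].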

Questions:

  1. After the first line, tensor w will have shape [B, gps, dim, 1, 1], where B is the batch size. Is that correct?
  2. Subsequently, in the second line, w[:, 0, ::] has shape [B, dim, 1, 1]. Correct?
  3. But res1 has shape [B, H, W, dim], where H and W are the spatial dimensions. How can you multiply w[:, 0, ::] and res1 when their shapes are incompatible?

Sorry if this is trivial; I normally write TensorFlow and am quite unfamiliar with the specifics of PyTorch. Thanks!

I asked a few PyTorch experts, and they pointed out that PyTorch's Conv2d outputs use the channels-first [B, C, H, W] layout, not TensorFlow's default channels-last [B, H, W, C]. Hence res1 has shape [B, dim, H, W], and w[:, 0, ::], with shape [B, dim, 1, 1], broadcasts against it along H and W. There's no issue with the PyTorch code.
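The resolution comes down to broadcasting, which PyTorch and TensorFlow both apply in the NumPy style. Shapes are aligned from the last dimension, and each aligned pair of sizes must be equal or include a 1. The pure-Python sketch below checks the weight shape against an NCHW feature map, against an NHWC one, and against a [B, 1, 1, dim] weight, which is one option for a channels-last TensorFlow port. `broadcast_shape` is a hypothetical helper, and the sizes are arbitrary examples.

```python
def broadcast_shape(a, b):
    """Result shape of an elementwise op under NumPy-style broadcasting.

    Shapes are right-aligned; missing leading dimensions count as 1, and
    each aligned pair of sizes must be equal or include a 1.
    """
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"cannot broadcast {a} with {b}: sizes {x} and {y}")
    return tuple(out)

B, dim, H, W = 2, 64, 32, 32     # example sizes (arbitrary)
weight = (B, dim, 1, 1)          # shape of w[:, 0, ::]

# PyTorch Conv2d output is channels-first: [B, C, H, W].
print(broadcast_shape(weight, (B, dim, H, W)))          # (2, 64, 32, 32)

# The same weight against a channels-last [B, H, W, C] map fails.
try:
    broadcast_shape(weight, (B, H, W, dim))
except ValueError as err:
    print("NHWC:", err)

# A channels-last port could instead use a [B, 1, 1, dim] weight.
print(broadcast_shape((B, 1, 1, dim), (B, H, W, dim)))  # (2, 32, 32, 64)
```

Shape compatibility alone does not prove the axes line up. If H happened to equal dim, the [B, dim, 1, 1] weight would broadcast against an NHWC map without error, but it would scale along H instead of along the channels.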