cfzd / FcaNet

FcaNet: Frequency Channel Attention Networks

How are the frequency components determined?

yueshuheng opened this issue · comments

I'd like to ask: how are these frequency components determined?

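The authors give a two-step criterion for selecting the frequency components used in the MCA module. The main idea is:

Step 1: compute the result of each frequency component in channel attention separately;
Step 2: based on those results, keep the Top-k best-performing frequency components.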
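
The second step can be sketched in a few lines of Python. This is only an illustration under assumptions: `select_top_k` is a hypothetical helper, not part of the FcaNet repository, and the scores are made-up placeholders standing in for whatever per-frequency results step 1 produced. The selected indices are split into `mapper_x` and `mapper_y`, the argument names used by `MultiSpectralDCTLayer`.

```python
from typing import Dict, List, Tuple

Freq = Tuple[int, int]  # (u, v) index of a 2D DCT frequency component


def select_top_k(scores: Dict[Freq, float], k: int) -> List[Freq]:
    """Return the k frequency indices with the highest score.

    Ties are broken by the lower (u, v) index so the result is deterministic.
    """
    if not 0 < k <= len(scores):
        raise ValueError("k must be between 1 and the number of candidates")
    ranked = sorted(scores.items(), key=lambda item: (-item[1], item[0]))
    return [freq for freq, _ in ranked[:k]]


# Step 1 (assumed done elsewhere): one score per single-frequency model.
# These numbers are made-up placeholders, NOT results from the paper.
placeholder_scores = {(0, 0): 0.90, (0, 1): 0.80, (1, 0): 0.85, (1, 1): 0.70}

# Step 2: keep the k best-performing components.
top2 = select_top_k(placeholder_scores, k=2)
mapper_x = [u for u, _ in top2]  # first index of each selected frequency
mapper_y = [v for _, v in top2]  # second index of each selected frequency
print(top2, mapper_x, mapper_y)
```

How the per-frequency score is measured (which dataset, which metric) is left open here; see the linked article for the authors' description.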

You can refer to this article:
https://zhuanlan.zhihu.com/p/339215696

Hi, this work is amazing! However, while debugging I found that the MultiSpectralDCTLayer module appears to be empty. May I ask what causes this?

commented

@Zhongrocky
It is not empty. It just doesn't have any learnable parameters: the DCT weights are registered as a fixed buffer. The implementation of MultiSpectralDCTLayer is here:

FcaNet/model/layer.py

Lines 65 to 117 in aa5fb63

class MultiSpectralDCTLayer(nn.Module):
    """
    Generate dct filters
    """
    def __init__(self, height, width, mapper_x, mapper_y, channel):
        super(MultiSpectralDCTLayer, self).__init__()

        assert len(mapper_x) == len(mapper_y)
        assert channel % len(mapper_x) == 0

        self.num_freq = len(mapper_x)

        # fixed DCT init
        self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))

        # fixed random init
        # self.register_buffer('weight', torch.rand(channel, height, width))

        # learnable DCT init
        # self.register_parameter('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))

        # learnable random init
        # self.register_parameter('weight', torch.rand(channel, height, width))

        # num_freq, h, w

    def forward(self, x):
        assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape))
        # n, c, h, w = x.shape

        x = x * self.weight

        result = torch.sum(x, dim=[2,3])
        return result

    def build_filter(self, pos, freq, POS):
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
        if freq == 0:
            return result
        else:
            return result * math.sqrt(2)

    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
        dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)

        c_part = channel // len(mapper_x)

        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
            for t_x in range(tile_size_x):
                for t_y in range(tile_size_y):
                    dct_filter[i * c_part: (i+1)*c_part, t_x, t_y] = self.build_filter(t_x, u_x, tile_size_x) * self.build_filter(t_y, v_y, tile_size_y)

        return dct_filter
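
To see that a layer with no learnable parameters still does real work, here is a minimal pure-Python sketch of the same fixed weights, written without PyTorch so it runs anywhere. It is not the repository's code: `dct_filter_2d` and `apply_filter` are hypothetical names, and the 4x4 input is a toy example. `build_filter` uses the same formula as the method quoted above. The sketch checks that frequency (0, 0) reduces to the spatial mean scaled by sqrt(height * width), that is, global average pooling up to a constant.

```python
import math


def build_filter(pos: int, freq: int, size: int) -> float:
    """1D DCT basis value, same formula as build_filter in the quoted code."""
    result = math.cos(math.pi * freq * (pos + 0.5) / size) / math.sqrt(size)
    return result if freq == 0 else result * math.sqrt(2)


def dct_filter_2d(height: int, width: int, u: int, v: int):
    """Fixed height x width weight map for frequency (u, v) (hypothetical helper)."""
    return [[build_filter(i, u, height) * build_filter(j, v, width)
             for j in range(width)] for i in range(height)]


def apply_filter(feature, weight) -> float:
    """Weighted sum over all spatial positions, like `x * weight` then sum in forward."""
    return sum(f * w
               for f_row, w_row in zip(feature, weight)
               for f, w in zip(f_row, w_row))


height, width = 4, 4
# Toy single-channel feature map holding the values 0.0 .. 15.0.
feature = [[float(i * width + j) for j in range(width)] for i in range(height)]

# Frequency (0, 0): every weight equals 1 / sqrt(height * width).
lowest = apply_filter(feature, dct_filter_2d(height, width, 0, 0))
mean = sum(map(sum, feature)) / (height * width)

print(lowest, mean * math.sqrt(height * width))
```

Both printed values should agree up to floating-point rounding. The weights are computed once from the frequency indices and never trained, which is consistent with the module holding a buffer but no learnable parameters.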

@cfzd Thank you a lot. Got it.