model.init
yuan243212790 opened this issue · comments
yuan243212790 commented
I cannot find your model init code. Can anybody tell me where it is? Thanks.
cfzd commented
Lines 65 to 117 in aa5fb63
class MultiSpectralDCTLayer(nn.Module):
    """
    Generate DCT filters.

    Fixed 2-D DCT filter bank that turns a feature map into per-channel
    frequency responses.

    The ``channel`` channels are split into ``len(mapper_x)`` equal groups.
    Group ``i`` is assigned the 2-D DCT basis function with frequency indices
    ``(mapper_x[i], mapper_y[i])``. The resulting
    ``(channel, height, width)`` filter tensor is stored in the buffer
    ``weight``, so it is not a trainable parameter.

    Args:
        height: spatial height of the filters.
        width: spatial width of the filters.
        mapper_x: frequency index along the height axis, one per group.
        mapper_y: frequency index along the width axis, one per group;
            must have the same length as ``mapper_x``.
        channel: total number of channels; must be divisible by
            ``len(mapper_x)``.
    """

    def __init__(self, height, width, mapper_x, mapper_y, channel):
        super(MultiSpectralDCTLayer, self).__init__()

        assert len(mapper_x) == len(mapper_y)
        assert channel % len(mapper_x) == 0

        self.num_freq = len(mapper_x)

        # Fixed DCT init: a buffer follows .to()/.cuda() and is saved in the
        # state_dict, but receives no gradient updates.
        #
        # Alternative initializations (swap in instead of the line below):
        #   fixed random:    self.register_buffer('weight', torch.rand(channel, height, width))
        #   learnable DCT:   self.register_parameter('weight', nn.Parameter(
        #                        self.get_dct_filter(height, width, mapper_x, mapper_y, channel)))
        #   learnable random: self.register_parameter('weight', nn.Parameter(
        #                        torch.rand(channel, height, width)))
        # register_parameter only accepts an nn.Parameter (or None); passing a
        # plain tensor raises TypeError.
        #
        # weight shape: (channel, height, width)
        self.register_buffer('weight', self.get_dct_filter(height, width, mapper_x, mapper_y, channel))

    def forward(self, x):
        """Weight ``x`` by the DCT filters and sum over the two spatial axes.

        Args:
            x: 4-D tensor, presumably ``(N, C, H, W)`` with ``C == channel`` and
                ``(H, W) == (height, width)``. Only the rank is checked here;
                any other shape mismatch is subject to ordinary broadcasting
                against ``weight``.

        Returns:
            ``torch.sum(x * weight, dim=[2, 3])``, i.e. shape ``(N, C)`` for the
            input layout above.
        """
        assert len(x.shape) == 4, 'x must been 4 dimensions, but got ' + str(len(x.shape))
        # n, c, h, w = x.shape
        x = x * self.weight
        result = torch.sum(x, dim=[2, 3])
        return result

    def build_filter(self, pos, freq, POS):
        """Return one sample of the orthonormal 1-D DCT-II basis.

        Computes ``cos(pi * freq * (pos + 0.5) / POS) / sqrt(POS)``. The result
        is additionally scaled by ``sqrt(2)`` for every non-zero frequency.

        Args:
            pos: sample position, expected in ``range(POS)``.
            freq: frequency index.
            POS: number of samples along the axis.

        Returns:
            The basis value as a Python float.
        """
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
        if freq == 0:
            return result
        else:
            return result * math.sqrt(2)

    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
        """Build the ``(channel, tile_size_x, tile_size_y)`` DCT filter bank.

        Channels are split into ``len(mapper_x)`` consecutive groups of
        ``channel // len(mapper_x)``. Every channel in group ``i`` gets the
        2-D basis for frequency ``(mapper_x[i], mapper_y[i])``.

        The 2-D DCT basis is separable, so each group's filter is the outer
        product of two 1-D basis vectors. This needs only
        ``tile_size_x + tile_size_y`` calls to ``build_filter`` per frequency,
        instead of one Python-level tensor write per pixel. The values are
        unchanged: each product is formed in float64, exactly like a Python
        float multiply, and then rounded once to the dtype of ``dct_filter``.
        """
        dct_filter = torch.zeros(channel, tile_size_x, tile_size_y)
        c_part = channel // len(mapper_x)

        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
            basis_x = torch.tensor(
                [self.build_filter(t_x, u_x, tile_size_x) for t_x in range(tile_size_x)],
                dtype=torch.float64,
            )
            basis_y = torch.tensor(
                [self.build_filter(t_y, v_y, tile_size_y) for t_y in range(tile_size_y)],
                dtype=torch.float64,
            )
            # (tile_size_x, 1) * (1, tile_size_y) -> (tile_size_x, tile_size_y),
            # broadcast into all c_part channels of this group.
            dct_filter[i * c_part: (i + 1) * c_part] = basis_x[:, None] * basis_y[None, :]

        return dct_filter