Implementing MultiSpectralAttentionLayer in TensorFlow
mxmm2123 opened this issue · comments
Manh Nguyen commented
Hi author, thank you for your great work. I'm implementing MultiSpectralAttentionLayer (MSA) in TensorFlow, but I'm having some trouble: my MSA layer makes the training process quite slow, so I think I made a mistake when re-implementing it. I cannot find a TensorFlow alternative to `register_buffer`
for creating the fixed DCT initialization, which may be causing the problem. Could you review it?
def get_freq_indices(method):
    """Return the (x, y) frequency index lists selected by *method*.

    *method* is a three-letter family tag ('top', 'low' or 'bot') followed
    by the number of frequencies to keep, e.g. 'top16'. The first that many
    entries of the family's x and y tables are returned as two lists.

    Raises AssertionError when *method* is not one of the supported names.
    """
    assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32',
                      'bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32',
                      'low1', 'low2', 'low4', 'low8', 'low16', 'low32']

    # One (x-table, y-table) pair per family; each table holds 32 indices.
    index_tables = {
        'top': (
            [0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 0, 3, 2, 4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2, 6, 1],
            [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 3, 5, 2, 6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0, 5, 3],
        ),
        'low': (
            [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4],
            [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3],
        ),
        'bot': (
            [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5, 6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5, 3, 6],
            [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1, 4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3, 3, 3],
        ),
    }

    # Everything after the three-letter tag is the frequency count.
    keep = int(method[3:])

    # Families are probed in the fixed order top -> low -> bot.
    for family in ('top', 'low', 'bot'):
        if family in method:
            table_x, table_y = index_tables[family]
            return table_x[:keep], table_y[:keep]

    # Only reachable when the assert above is disabled (python -O).
    raise NotImplementedError
class MultiSpectralAttentionLayer(tf.keras.layers.Layer):
    """Channel attention driven by a multi-spectral (DCT) descriptor.

    The input feature map is brought to a fixed ``dct_h x dct_w`` grid when
    its static spatial size differs, passed through ``MultiSpectralDCTLayer``,
    and the resulting descriptor is fed to a two-layer bottleneck
    (Dense -> ReLU -> Dense -> sigmoid) whose output rescales the input.

    Args:
        channel: Number of units of the final Dense layer, i.e. the number of
            gating values produced per sample.
        dct_h: Height of the grid the DCT filters are defined on.
        dct_w: Width of the grid the DCT filters are defined on.
        reduction: Bottleneck ratio; the hidden Dense layer has
            ``channel // reduction`` units.
        freq_sel_method: Name understood by ``get_freq_indices`` (for example
            ``'top16'``) selecting which frequency indices are used.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to the base class.
    """

    def __init__(self, channel, dct_h, dct_w, reduction=16, freq_sel_method='top16', **kwargs):
        super(MultiSpectralAttentionLayer, self).__init__(**kwargs)
        # Constructor arguments are kept so get_config() can reproduce the layer.
        self.channel = channel
        self.reduction = reduction
        self.dct_h = dct_h
        self.dct_w = dct_w
        self.freq_sel_method = freq_sel_method

        mapper_x, mapper_y = get_freq_indices(freq_sel_method)
        self.num_split = len(mapper_x)
        # The index tables hold values in 0..6 (a 7x7 grid); scale them to the
        # actual filter size. Note the factor is 0 when dct_h or dct_w < 7.
        mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x]
        mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]

        self.dct_layer = MultiSpectralDCTLayer(dct_h, dct_w, mapper_x, mapper_y, channel)
        self.fc = tf.keras.Sequential([
            tf.keras.layers.Dense(channel // reduction, use_bias=False),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(channel, use_bias=False),
            tf.keras.layers.Activation('sigmoid')
        ])

    def call(self, x):
        """Return ``x`` rescaled by the learned gate.

        ``x`` must be rank 4; axes 1 and 2 are treated as height and width.
        """
        # Only the static spatial dims are needed. When they are unknown
        # (None) the comparison below is true and the resize always runs.
        h, w = x.shape[1], x.shape[2]
        x_pooled = x
        if h != self.dct_h or w != self.dct_w:
            # NOTE(review): tf.image.resize defaults to bilinear interpolation.
            # If this is meant to mirror an adaptive average pooling step,
            # method='area' may be the closer match -- confirm against the
            # reference implementation before changing.
            x_pooled = tf.image.resize(x, (self.dct_h, self.dct_w))
        y = self.dct_layer(x_pooled)
        y = self.fc(y)
        # (batch, channel) -> (batch, 1, 1, channel) so the gate broadcasts
        # over the two spatial axes of x.
        y = tf.expand_dims(tf.expand_dims(y, axis=1), axis=1)
        return x * y

    def get_config(self):
        """Return the constructor arguments so Keras can serialize the layer."""
        config = super(MultiSpectralAttentionLayer, self).get_config()
        config.update({
            'channel': self.channel,
            'dct_h': self.dct_h,
            'dct_w': self.dct_w,
            'reduction': self.reduction,
            'freq_sel_method': self.freq_sel_method,
        })
        return config
class MultiSpectralDCTLayer(tf.keras.layers.Layer):
    """Weight a feature map with fixed 2-D DCT filters and sum over space.

    The filter bank has shape ``(height, width, channel)``. The channels are
    split into ``len(mapper_x)`` equal groups and group ``i`` uses the 2-D DCT
    basis with frequency ``(mapper_x[i], mapper_y[i])``.

    Args:
        height: Height of the filters (and of the expected input).
        width: Width of the filters (and of the expected input).
        mapper_x: Frequency indices along the height axis, one per group.
        mapper_y: Frequency indices along the width axis, one per group.
        channel: Number of channels; must be divisible by ``len(mapper_x)``.
    """

    def __init__(self, height, width, mapper_x, mapper_y, channel):
        super(MultiSpectralDCTLayer, self).__init__()
        assert len(mapper_x) == len(mapper_y)
        assert channel % len(mapper_x) == 0
        self.num_freq = len(mapper_x)
        self.height = height
        self.width = width
        self.mapper_x = mapper_x
        self.mapper_y = mapper_y
        self.channel = channel
        # A non-trainable variable plays the role of a fixed buffer: it is
        # created once from the DCT filter and excluded from training.
        self.weight = tf.Variable(initial_value=self.get_dct_filter(), trainable=False, name='weight')

    def call(self, x):
        """Return the per-channel DCT response, shape ``(batch, channel)``.

        ``x`` is expected to be ``(batch, height, width, channel)`` so that
        the ``(height, width, channel)`` filter broadcasts over the batch.
        """
        x = x * self.weight
        # Sum over the spatial axes (1 = height, 2 = width). Summing over
        # [2, 3] would collapse width and channel for channels-last input.
        return tf.reduce_sum(x, axis=[1, 2])

    def build_filter(self, pos, freq, POS):
        """Return the 1-D DCT basis value at ``pos`` for ``freq`` over ``POS`` samples."""
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
        if freq == 0:
            return result
        # Non-zero frequencies carry an extra sqrt(2) normalisation factor.
        return result * math.sqrt(2)

    def get_dct_filter(self):
        """Build the ``(height, width, channel)`` float32 DCT filter bank."""
        dct_filter = np.zeros((self.height, self.width, self.channel))
        c_part = self.channel // self.num_freq
        for i, (u_x, v_y) in enumerate(zip(self.mapper_x, self.mapper_y)):
            # The 2-D basis is separable, so compute each 1-D basis once and
            # take their outer product instead of evaluating every pixel.
            basis_h = np.array([self.build_filter(t_x, u_x, self.height)
                                for t_x in range(self.height)])
            basis_w = np.array([self.build_filter(t_y, v_y, self.width)
                                for t_y in range(self.width)])
            # The trailing axis broadcasts the 2-D basis over the group's channels.
            dct_filter[:, :, i * c_part: (i + 1) * c_part] = \
                np.outer(basis_h, basis_w)[:, :, np.newaxis]
        return tf.constant(dct_filter, dtype=tf.float32)