Is this a mistake in the paper? or Am i misunderstood it ?
fhahaha opened this issue · comments
Hi,
yes this is a mistake. It was already brought to my attention. The updated paper will be released on arXiv tomorrow.
first of all, thanks for your project!!!
i write a dtln-aec model depend on your code, i train this model with the AEC-challenge dataset, train.py depend on your other project: https://github.com/breizhn/DTLN .
but when i convert my trained model to tflite and inference with run_aec.py, the output audio is silence.
where are the mistakes, waiting for your help!!!
`
def stftLayer(self, x):
frames = tf.signal.frame(x, self.block_len, self.block_shift)
stft_dat = tf.signal.rfft(frames)
mag = tf.abs(stft_dat)
phase = tf.math.angle(stft_dat)
return [frames, stft_dat, mag, phase]
def create_dtln_aec(self, batch=None, norm_stft=True, stateful=False):
input_farend = tf.keras.Input(batch_shape=(batch, self.wav_dims), dtype=tf.float32, name="input_farend")
input_nearend = tf.keras.Input(batch_shape=(batch, self.wav_dims), dtype=tf.float32, name="input_nearend")
frame_farend, _, mag_farend, angle_farend = nn.Lambda(self.stftLayer)(input_farend)
_, data_nearend, mag_nearend, angle_nearend = nn.Lambda(self.stftLayer)(input_nearend)
mag_farend_sq = tf.square(mag_farend)
mag_nearend_sq = tf.square(mag_nearend)
if norm_stft:
mag_farend_norm = InstantLayerNormalization(name='magfarendnorm')(tf.math.log(mag_farend_sq + 1e-7))
mag_nearend_norm = InstantLayerNormalization(name='magnearendnorm')(tf.math.log(mag_nearend_sq + 1e-7))
else:
mag_farend_norm = mag_farend
mag_nearend_norm = mag_nearend
mag_norm = nn.Concatenate(axis=-1, name="cat1")([mag_farend_norm, mag_nearend_norm])
x = nn.LSTM(self.lstm_units, return_sequences=True, stateful=stateful, name="lstm1")(mag_norm)
x = nn.Dropout(self.dropout)(x)
x = nn.LSTM(self.lstm_units, return_sequences=True, stateful=stateful, name="lstm2")(x)
x = nn.Dense(self.fdims, activation='sigmoid', name="dense1")(x)
stft_nearend_1 = nn.Multiply()([tf.cast(data_nearend, dtype=tf.float32), x])
nearend_1 = nn.Lambda(self.ifftLayer)([stft_nearend_1, angle_nearend])
conv_layer = nn.Conv1D(512, kernel_size=1, strides=1, use_bias=False, name="conv1")
nearend_2 = conv_layer(nearend_1)
farend_2 = conv_layer(frame_farend)
if norm_stft:
nearend_2_norm = InstantLayerNormalization(name="nearend2norm")(nearend_2)
farend_2_norm = InstantLayerNormalization(name='farend2norm')(farend_2)
else:
nearend_2_norm = nearend_2
farend_2_norm = farend_2
near_far = nn.Concatenate()([nearend_2_norm, farend_2_norm])
x = nn.LSTM(self.lstm_units, return_sequences=True, stateful=stateful, name="lstm3")(near_far)
x = nn.Dropout(self.dropout)(x)
x = nn.LSTM(self.lstm_units, return_sequences=True, stateful=stateful, name="lstm4")(x)
x = nn.Dense(512, activation='sigmoid', name='dense2')(x)
x = nn.Multiply()([nearend_2, x])
x = nn.Conv1D(512, kernel_size=1, use_bias=False, name='conv2')(x)
output = nn.Lambda(self.overlapAddLayer)(x)
model = tf.keras.Model(inputs=[input_nearend, input_farend], outputs=output, name='dtln-aec')
print(model.summary())
return model
def create_dtln_aec_tflite(self, weights, norm_stft, target_name, use_dynamic_range_quant=False):
if norm_stft:
num_elements_first_core = 2 + 2 + 2*3 + 2
else:
num_elements_first_core = 2*3 + 2
model = self.create_dtln_aec(batch=1, stateful=True)
model.load_weights(weights)
# model1
mag_farend = nn.Input(batch_shape=(1, 1, self.block_len//2 + 1), name='input_mag_farend')
mag_nearend = nn.Input(batch_shape=(1, 1, self.block_len//2 + 1), name='input_mag_nearend')
states_in_1 = nn.Input(batch_shape=(1, 2, self.lstm_units, 2), name='input_states')
mag_farend_sq = tf.square(mag_farend)
mag_nearend_sq = tf.square(mag_nearend)
if norm_stft:
mag_farend_norm = InstantLayerNormalization(name='magfarendnorm')(tf.math.log(mag_farend_sq + 1e-7))
mag_nearend_norm = InstantLayerNormalization(name='magnearendnorm')(tf.math.log(mag_nearend_sq + 1e-7))
else:
mag_farend_norm = mag_farend
mag_nearend_norm = mag_nearend
mag_norm = nn.Concatenate(axis=-1, name="cat1")([mag_farend_norm, mag_nearend_norm])
states_h, states_c = [], []
in_state_1 = [states_in_1[:, 0, :, 0], states_in_1[:, 0, :, 1]]
x, h_state, c_state = nn.LSTM(self.lstm_units, return_sequences=True, unroll=True,
return_state=True, name="lstm1")(mag_norm, in_state_1)
states_h.append(h_state)
states_c.append(c_state)
x = nn.Dropout(self.dropout)(x)
in_state_2 = [states_in_1[:, 1, :, 0], states_in_1[:, 1, :, 1]]
x, h_state, c_state = nn.LSTM(self.lstm_units, return_sequences=True, unroll=True,
return_state=True, name="lstm2")(x, in_state_2)
states_h.append(h_state)
states_c.append(c_state)
x = nn.Dense(self.fdims, activation='sigmoid', name="dense1")(x)
out_states_h = tf.stack(states_h, axis=1)
out_states_c = tf.stack(states_c, axis=1)
states_out = tf.stack([out_states_h, out_states_c], axis=-1)
model_1 = tf.keras.Model(inputs=[mag_farend, states_in_1, mag_nearend], outputs=[x, states_out])
# model2
farend_in = tf.keras.Input(batch_shape=(1, 1, self.block_len), name='input_farend')
nearend_in = nn.Input(batch_shape=(1, 1, self.block_len), name='input_nearend')
states_in_2 = nn.Input(batch_shape=(1, 2, self.lstm_units, 2), name='input_states')
conv_layer = nn.Conv1D(512, kernel_size=1, strides=1, use_bias=False, name="conv1")
nearend_2 = conv_layer(nearend_in)
farend_2 = conv_layer(farend_in)
if norm_stft:
nearend_2_norm = InstantLayerNormalization(name="nearend2norm")(nearend_2)
farend_2_norm = InstantLayerNormalization(name='farend2norm')(farend_2)
else:
nearend_2_norm = nearend_2
farend_2_norm = farend_2
near_far = nn.Concatenate()([nearend_2_norm, farend_2_norm])
states_h, states_c = [], []
in_state_1 = [states_in_2[:, 0, :, 0], states_in_2[:, 0, :, 1]]
x, h_state, c_state = nn.LSTM(self.lstm_units, return_sequences=True, unroll=True,
return_state=True, name="lstm3")(near_far, in_state_1)
states_h.append(h_state)
states_c.append(c_state)
x = nn.Dropout(self.dropout)(x)
in_state_2 = [states_in_2[:, 1, :, 0], states_in_2[:, 1, :, 1]]
x, h_state, c_state = nn.LSTM(self.lstm_units, return_sequences=True, unroll=True,
return_state=True, name="lstm4")(x, in_state_2)
states_h.append(h_state)
states_c.append(c_state)
x = nn.Dense(512, activation='sigmoid', name='dense2')(x)
x = nn.Multiply()([nearend_2, x])
x = nn.Conv1D(512, kernel_size=1, use_bias=False, name='conv2')(x)
out_states_h = tf.stack(states_h, axis=1)
out_states_c = tf.stack(states_c, axis=1)
states_out = tf.stack([out_states_h, out_states_c], axis=-1)
model_2 = tf.keras.Model(inputs=[nearend_in, states_in_2, farend_in], outputs=[x, states_out])
# copy weights
weight_data = model.get_weights()
model_1.set_weights(weight_data[: num_elements_first_core])
model_2.set_weights(weight_data[num_elements_first_core:])
# convert to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model_1)
if use_dynamic_range_quant:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
else:
converter.target_spec.supported_types = [tf.float32]
tflite_model = converter.convert()
with tf.io.gfile.GFile(target_name + "_1.tflite", "wb") as f:
f.write(tflite_model)
converter = tf.lite.TFLiteConverter.from_keras_model(model_2)
if use_dynamic_range_quant:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
else:
converter.target_spec.supported_types = [tf.float32]
tflite_model = converter.convert()
with tf.io.gfile.GFile(target_name + "_2.tflite", "wb") as f:
f.write(tflite_model)
print("TF lite conversion done...")
`