breizhn / DTLN-aec

This Repostory contains the pretrained DTLN-aec model for real-time acoustic echo cancellation.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Is this a mistake in the paper? or Am i misunderstood it ?

fhahaha opened this issue · comments

The input of the second core should be the original segment not the FFT result.

Screen Shot 2020-11-23 at 5 48 44 PM

Hi,

yes this is a mistake. It was already brought to my attention. The updated paper will be released on arXiv tomorrow.

commented

first of all, thanks for your project!!!
i write a dtln-aec model depend on your code, i train this model with the AEC-challenge dataset, train.py depend on your other project: https://github.com/breizhn/DTLN .
but when i convert my trained model to tflite and inference with run_aec.py, the output audio is silence.
where are the mistakes, waiting for your help!!!
`
def stftLayer(self, x):
frames = tf.signal.frame(x, self.block_len, self.block_shift)
stft_dat = tf.signal.rfft(frames)
mag = tf.abs(stft_dat)
phase = tf.math.angle(stft_dat)
return [frames, stft_dat, mag, phase]

def create_dtln_aec(self, batch=None, norm_stft=True, stateful=False):
    input_farend = tf.keras.Input(batch_shape=(batch, self.wav_dims), dtype=tf.float32, name="input_farend")
    input_nearend = tf.keras.Input(batch_shape=(batch, self.wav_dims), dtype=tf.float32, name="input_nearend")

    frame_farend, _, mag_farend, angle_farend = nn.Lambda(self.stftLayer)(input_farend)
    _, data_nearend, mag_nearend, angle_nearend = nn.Lambda(self.stftLayer)(input_nearend)
    mag_farend_sq = tf.square(mag_farend)
    mag_nearend_sq = tf.square(mag_nearend)
    if norm_stft:
        mag_farend_norm = InstantLayerNormalization(name='magfarendnorm')(tf.math.log(mag_farend_sq + 1e-7))
        mag_nearend_norm = InstantLayerNormalization(name='magnearendnorm')(tf.math.log(mag_nearend_sq + 1e-7))
    else:
        mag_farend_norm = mag_farend
        mag_nearend_norm = mag_nearend
    mag_norm = nn.Concatenate(axis=-1, name="cat1")([mag_farend_norm, mag_nearend_norm])
    x = nn.LSTM(self.lstm_units, return_sequences=True, stateful=stateful, name="lstm1")(mag_norm)
    x = nn.Dropout(self.dropout)(x)
    x = nn.LSTM(self.lstm_units, return_sequences=True, stateful=stateful, name="lstm2")(x)
    x = nn.Dense(self.fdims, activation='sigmoid', name="dense1")(x)
    stft_nearend_1 = nn.Multiply()([tf.cast(data_nearend, dtype=tf.float32), x])

    nearend_1 = nn.Lambda(self.ifftLayer)([stft_nearend_1, angle_nearend])

    conv_layer = nn.Conv1D(512, kernel_size=1, strides=1, use_bias=False, name="conv1")
    nearend_2 = conv_layer(nearend_1)
    farend_2 = conv_layer(frame_farend)
    if norm_stft:
        nearend_2_norm = InstantLayerNormalization(name="nearend2norm")(nearend_2)
        farend_2_norm = InstantLayerNormalization(name='farend2norm')(farend_2)
    else:
        nearend_2_norm = nearend_2
        farend_2_norm = farend_2
    near_far = nn.Concatenate()([nearend_2_norm, farend_2_norm])
    x = nn.LSTM(self.lstm_units, return_sequences=True, stateful=stateful, name="lstm3")(near_far)
    x = nn.Dropout(self.dropout)(x)
    x = nn.LSTM(self.lstm_units, return_sequences=True, stateful=stateful, name="lstm4")(x)
    x = nn.Dense(512, activation='sigmoid', name='dense2')(x)
    x = nn.Multiply()([nearend_2, x])
    x = nn.Conv1D(512, kernel_size=1, use_bias=False, name='conv2')(x)
    output = nn.Lambda(self.overlapAddLayer)(x)

    model = tf.keras.Model(inputs=[input_nearend, input_farend], outputs=output, name='dtln-aec')
    print(model.summary())
    return model

def create_dtln_aec_tflite(self, weights, norm_stft, target_name, use_dynamic_range_quant=False):
    if norm_stft:
        num_elements_first_core = 2 + 2 + 2*3 + 2
    else:
        num_elements_first_core = 2*3 + 2
    model = self.create_dtln_aec(batch=1, stateful=True)
    model.load_weights(weights)

    # model1
    mag_farend = nn.Input(batch_shape=(1, 1, self.block_len//2 + 1), name='input_mag_farend')
    mag_nearend = nn.Input(batch_shape=(1, 1, self.block_len//2 + 1), name='input_mag_nearend')
    states_in_1 = nn.Input(batch_shape=(1, 2, self.lstm_units, 2), name='input_states')

    mag_farend_sq = tf.square(mag_farend)
    mag_nearend_sq = tf.square(mag_nearend)
    if norm_stft:
        mag_farend_norm = InstantLayerNormalization(name='magfarendnorm')(tf.math.log(mag_farend_sq + 1e-7))
        mag_nearend_norm = InstantLayerNormalization(name='magnearendnorm')(tf.math.log(mag_nearend_sq + 1e-7))
    else:
        mag_farend_norm = mag_farend
        mag_nearend_norm = mag_nearend
    mag_norm = nn.Concatenate(axis=-1, name="cat1")([mag_farend_norm, mag_nearend_norm])

    states_h, states_c = [], []
    in_state_1 = [states_in_1[:, 0, :, 0], states_in_1[:, 0, :, 1]]
    x, h_state, c_state = nn.LSTM(self.lstm_units, return_sequences=True, unroll=True,
                                  return_state=True, name="lstm1")(mag_norm, in_state_1)
    states_h.append(h_state)
    states_c.append(c_state)
    x = nn.Dropout(self.dropout)(x)
    in_state_2 = [states_in_1[:, 1, :, 0], states_in_1[:, 1, :, 1]]
    x, h_state, c_state = nn.LSTM(self.lstm_units, return_sequences=True, unroll=True,
                                  return_state=True, name="lstm2")(x, in_state_2)
    states_h.append(h_state)
    states_c.append(c_state)
    x = nn.Dense(self.fdims, activation='sigmoid', name="dense1")(x)

    out_states_h = tf.stack(states_h, axis=1)
    out_states_c = tf.stack(states_c, axis=1)
    states_out = tf.stack([out_states_h, out_states_c], axis=-1)

    model_1 = tf.keras.Model(inputs=[mag_farend, states_in_1, mag_nearend], outputs=[x, states_out])

    # model2
    farend_in = tf.keras.Input(batch_shape=(1, 1, self.block_len), name='input_farend')
    nearend_in = nn.Input(batch_shape=(1, 1, self.block_len), name='input_nearend')
    states_in_2 = nn.Input(batch_shape=(1, 2, self.lstm_units, 2), name='input_states')

    conv_layer = nn.Conv1D(512, kernel_size=1, strides=1, use_bias=False, name="conv1")
    nearend_2 = conv_layer(nearend_in)
    farend_2 = conv_layer(farend_in)
    if norm_stft:
        nearend_2_norm = InstantLayerNormalization(name="nearend2norm")(nearend_2)
        farend_2_norm = InstantLayerNormalization(name='farend2norm')(farend_2)
    else:
        nearend_2_norm = nearend_2
        farend_2_norm = farend_2
    near_far = nn.Concatenate()([nearend_2_norm, farend_2_norm])
    states_h, states_c = [], []
    in_state_1 = [states_in_2[:, 0, :, 0], states_in_2[:, 0, :, 1]]
    x, h_state, c_state = nn.LSTM(self.lstm_units, return_sequences=True, unroll=True,
                                  return_state=True, name="lstm3")(near_far, in_state_1)
    states_h.append(h_state)
    states_c.append(c_state)
    x = nn.Dropout(self.dropout)(x)
    in_state_2 = [states_in_2[:, 1, :, 0], states_in_2[:, 1, :, 1]]
    x, h_state, c_state = nn.LSTM(self.lstm_units, return_sequences=True, unroll=True,
                                  return_state=True, name="lstm4")(x, in_state_2)
    states_h.append(h_state)
    states_c.append(c_state)
    x = nn.Dense(512, activation='sigmoid', name='dense2')(x)
    x = nn.Multiply()([nearend_2, x])
    x = nn.Conv1D(512, kernel_size=1, use_bias=False, name='conv2')(x)

    out_states_h = tf.stack(states_h, axis=1)
    out_states_c = tf.stack(states_c, axis=1)
    states_out = tf.stack([out_states_h, out_states_c], axis=-1)

    model_2 = tf.keras.Model(inputs=[nearend_in, states_in_2, farend_in], outputs=[x, states_out])

    # copy weights
    weight_data = model.get_weights()
    model_1.set_weights(weight_data[: num_elements_first_core])
    model_2.set_weights(weight_data[num_elements_first_core:])

    # convert to tflite
    converter = tf.lite.TFLiteConverter.from_keras_model(model_1)
    if use_dynamic_range_quant:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    else:
        converter.target_spec.supported_types = [tf.float32]
    tflite_model = converter.convert()
    with tf.io.gfile.GFile(target_name + "_1.tflite", "wb") as f:
        f.write(tflite_model)

    converter = tf.lite.TFLiteConverter.from_keras_model(model_2)
    if use_dynamic_range_quant:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    else:
        converter.target_spec.supported_types = [tf.float32]
    tflite_model = converter.convert()
    with tf.io.gfile.GFile(target_name + "_2.tflite", "wb") as f:
        f.write(tflite_model)
    print("TF lite conversion done...")

`