enhancer12 / TSPNN

Two-stage progressive neural network for acoustic echo cancellation

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

请问这个模型会开源吗?

KollyYang opened this issue · comments

我实现了一个 参数总数和作者的有些不一样,我的只有1.26M 参数,@enhancer12


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from  scipy.signal import get_window

def init_kernels(win_len, fft_len, win_type=None, invers=False):
    """Build (i)STFT convolution kernels from the real-FFT basis.

    Args:
        win_len: analysis window length in samples.
        fft_len: FFT size (>= win_len).
        win_type: scipy window name, or None/'None' for a rectangular window.
        invers: when True return the pseudo-inverse basis (for the iSTFT).

    Returns:
        (kernel, window): a (2 * (fft_len // 2 + 1), 1, win_len) float32
        tensor of windowed real/imaginary Fourier rows, and the
        (1, win_len, 1) float32 window itself.
    """
    if win_type is None or win_type == 'None':
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)  # **0.5

    # Each row of eye(fft_len) picks one sample, so rfft(eye) is the DFT basis.
    fourier_basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    kernel = np.concatenate(
        [np.real(fourier_basis), np.imag(fourier_basis)], 1).T

    if invers:
        # Synthesis (iSTFT) uses the pseudo-inverse of the analysis basis.
        kernel = np.linalg.pinv(kernel).T

    kernel = (kernel * window)[:, None, :]
    return (torch.from_numpy(kernel.astype(np.float32)),
            torch.from_numpy(window[None, :, None].astype(np.float32)))

class ConvSTFT(nn.Module):
    """STFT implemented as a strided 1-D convolution.

    With ``feature_type='complex'`` the forward pass returns a
    (B, 2*(fft_len//2 + 1), T) tensor with [real; imag] stacked along the
    channel axis; otherwise it returns a (magnitude, phase) pair.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConvSTFT, self).__init__()
        if fft_len is None:
            # BUG FIX: np.int was removed in NumPy >= 1.24; the builtin
            # int() is the documented replacement.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, _ = init_kernels(win_len, self.fft_len, win_type)
        # Buffer, not Parameter: the DFT basis is fixed, never trained.
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        # Accept (B, samples) or (B, 1, samples).
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        # Symmetric padding so the first/last frames are fully covered;
        # ConviSTFT trims the same amount back off.
        inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real ** 2 + imag ** 2)
            phase = torch.atan2(imag, real)
            return mags, phase

class ConviSTFT(nn.Module):
    """Inverse STFT implemented as a transposed 1-D convolution.

    Overlap-add is performed by conv_transpose1d with the pseudo-inverse
    Fourier basis; the overlap-added squared window normalises the
    result, and the symmetric padding added by ConvSTFT is trimmed off.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            # BUG FIX: np.int was removed in NumPy >= 1.24; use builtin int().
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, self.fft_len, win_type, invers=True)
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        # Identity "framing" kernel used to overlap-add the squared window.
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        inputs: (B, fft_len + 2, T) stacked [real; imag] spectrum, or
            magnitudes when ``phase`` is given.
        phase: optional (B, fft_len//2 + 1, T) phase to recombine with
            the magnitudes before synthesis.
        """
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
        # Overlap-added squared window, for COLA normalisation.
        t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        # Remove the padding that the forward STFT added.
        outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]
        return outputs

class CausalConv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(CausalConv, self).__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.left_pad = kernel_size[1] - 1
        padding = (kernel_size[0] // 2, self.left_pad)
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size, stride=stride,
                              padding=padding)
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()                       

    def forward(self, x):
        B, C, F, T = x.size()
        out = self.conv(x)[..., :T]
        out = self.norm(out)
        out = self.activation(out)
        return  out

class CausalTransConv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride, output_padding = (0,0)):
        super(CausalTransConv, self).__init__()
        padding  = (kernel_size[0] // 2, 0)
        self.trans_conv = nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size,
                                             stride=stride, padding=padding, output_padding=output_padding)
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()                                     

    def forward(self, x):
        B, C, F, T = x.size()
        out = self.trans_conv(x)[..., :T]
        out = self.norm(out)
        out = self.activation(out)   
        return out

class DeepFilter(nn.Module):
    """Complex deep filter over time-frequency neighbourhoods.

    For each (f, t) bin the output is the sum, over a
    (2*F_neighbors + 1) x (T_pastframes + T_lookahead + 1) neighbourhood,
    of the complex product filter * input — i.e. a learned multi-tap
    complex filter instead of a single complex mask.
    """

    def __init__(self, F_neighbors, T_pastframes, T_lookahead):
        '''
        inputs would be [B,F,T]
        '''
        super(DeepFilter, self).__init__()
        self.F_neighbors  = F_neighbors
        self.T_pastframes = T_pastframes
        self.T_lookahead = T_lookahead
        assert(self.T_pastframes>=self.T_lookahead)
        t_width = self.T_pastframes + self.T_lookahead +1
        f_width = self.F_neighbors*2+1
        # Identity kernel: conv2d with it unfolds each (f_width, t_width)
        # neighbourhood into the channel dimension.
        kernel = torch.eye(t_width*f_width)

        self.register_buffer('kernel', torch.reshape(kernel, [t_width*f_width, 1, f_width, t_width]))

    def forward(self, inputs, filters):
        '''
            inputs is [real, imag]: [ [B,F,T], [B,F,T] ]
            filters is [real, imag]: [ [B,F,T], [B,F,T] ]
        '''
        # Stack real/imag along the batch axis so one conv unfolds both.
        chunked_inputs = torch.cat(inputs,0)[:,None]
        # Left-pad time so look-ahead stays within T_pastframes of causality.
        pad = (self.T_pastframes - self.T_lookahead, 0, 0, 0)
        conv_pad = nn.ConstantPad2d(pad, 0.0)
        chunked_inputs = conv_pad(chunked_inputs)
        chunked_inputs = F.conv2d(
                                    chunked_inputs,
                                    self.kernel,
                                    padding= [self.F_neighbors, self.T_lookahead],
            )
        inputs_r, inputs_i = torch.chunk(chunked_inputs, 2, 0)
        chunked_filters = torch.cat(filters,0)[:,None]
        chunked_filters = conv_pad(chunked_filters)
        chunked_filters = F.conv2d(
                                    chunked_filters,
                                    self.kernel,
                                    padding= [self.F_neighbors, self.T_lookahead],
            )
        filters_r, filters_i = torch.chunk(chunked_filters, 2, 0)
        # Complex multiply per tap, then sum the taps (channel dim).
        outputs_r = inputs_r*filters_r - inputs_i*filters_i
        # BUG FIX: imaginary part of a complex product is
        # real*imag' + imag*real'; the original computed
        # inputs_r*filters_i + inputs_r*filters_i (term duplicated).
        outputs_i = inputs_r*filters_i + inputs_i*filters_r
        outputs_r = torch.sum(outputs_r, 1)
        outputs_i = torch.sum(outputs_i, 1)
        return outputs_r, outputs_i

class DPRnn(torch.nn.Module):
    """Dual-path RNN block: an intra path along frequency and an inter
    path along time, each with a linear projection back to ``input_ch``
    channels, LayerNorm, and a residual connection.

    The intra GRU is bidirectional over the frequency axis of each
    frame; the inter GRU is unidirectional over time (keeps the block
    causal frame-to-frame).
    """

    def __init__(self, hidden_ch_F, hidden_ch_T, F_dim, input_ch):
        super(DPRnn, self).__init__()
        self.F_dim = F_dim
        self.input_size = input_ch
        self.hidden_F = hidden_ch_F
        self.hidden_T = hidden_ch_T
        # Bidirectional over frequency: hidden_F // 2 per direction so the
        # concatenated output is hidden_F wide.
        self.intra_rnn = torch.nn.GRU(input_size=self.input_size, hidden_size= self.hidden_F // 2, bidirectional=True,
                                       batch_first=True)
        self.intra_fc = torch.nn.Linear(in_features=self.hidden_F, out_features=self.input_size)
        self.intra_ln = torch.nn.LayerNorm([F_dim, self.input_size])

        self.inter_rnn = torch.nn.GRU(input_size=self.input_size, hidden_size=self.hidden_T, batch_first=True)
        self.inter_fc = torch.nn.Linear(in_features=self.hidden_T, out_features=self.input_size)
        self.inter_ln = torch.nn.LayerNorm([F_dim, self.input_size])

    def forward(self, x):
        """
        :param x: B,C,F,T
        :return: (B,C,F,T) tensor with both residual paths applied
        """
        B, C, F, T = x.size()
        x = x.permute(0, 3, 2, 1)  # B,T,F,C
        # Intra path: treat each frame's frequency axis as a sequence.
        intra_in = torch.reshape(x, [B * T, F, C])
        intra_rnn_out, _ = self.intra_rnn(intra_in)
        intra_out = self.intra_ln(torch.reshape(self.intra_fc(intra_rnn_out), [B, T, F, C]))  # B,T,F,C
        intra_out = x + intra_out  # B,T,F,C (residual)
        # Inter path: treat each frequency bin's time axis as a sequence.
        inter_in = intra_out.permute(0, 2, 1, 3)  # B,F,T,C
        inter_in = torch.reshape(inter_in, [B * F, T, C])
        inter_rnn_out, _ = self.inter_rnn(inter_in)
        inter_out = self.inter_ln(
            torch.reshape(self.inter_fc(inter_rnn_out), [B, F, T, C]).permute(0, 2, 1, 3))  # B,T,F,C
        out = (intra_out + inter_out).permute(0, 3, 2, 1)
        return out

class FTGRU(torch.nn.Module):
    """Stacked frequency-time GRU bottleneck.

    Two rounds of (bidirectional GRU along frequency -> unidirectional
    GRU along time), each followed by BatchNorm2d + PReLU.  The first
    round expands to ``hidden_ch_T`` channels; the second projects back
    to ``input_ch``, so the output shape matches the (B, C, F, T) input.
    """

    def __init__(self, hidden_ch_F, hidden_ch_T, input_ch):
        super(FTGRU, self).__init__()
        self.input_size = input_ch
        self.hidden_F = hidden_ch_F
        self.hidden_T = hidden_ch_T
        # Bidirectional over frequency: hidden_F // 2 per direction.
        self.f_gru0 = torch.nn.GRU(input_size=self.input_size, hidden_size= self.hidden_F // 2, bidirectional=True,
                                       batch_first=True)
        # Unidirectional over time keeps the bottleneck causal.
        self.t_gru0 = torch.nn.GRU(input_size=self.hidden_F, hidden_size=self.hidden_T, batch_first=True)

        self.norm0 = nn.BatchNorm2d(self.hidden_T)
        self.activation0 = nn.PReLU()

        self.f_gru1 = torch.nn.GRU(input_size=self.hidden_T, hidden_size= self.hidden_F // 2, bidirectional=True,
                                       batch_first=True)
        self.t_gru1 = torch.nn.GRU(input_size=self.hidden_F, hidden_size=input_ch, batch_first=True)

        self.norm1 = nn.BatchNorm2d(input_ch)
        self.activation1 = nn.PReLU()

    def forward(self, x):
        """
        :param x: B,C,F,T
        :return: (B, C, F, T) tensor, same shape as the input
        """
        B, C, F, T = x.size()
        x = x.permute(0, 3, 2, 1)  # B,T,F,C
        # Round 0: GRU over frequency, then GRU over time.
        input0 = torch.reshape(x, [B * T, F, C])
        f_gru0_out, _ = self.f_gru0(input0)
        f_gru0_out    =  torch.reshape(f_gru0_out, [B, T, F, -1])
        t_gru0_in  = f_gru0_out.permute(0, 2, 1, 3) 
        t_gru0_in  = torch.reshape(t_gru0_in, [B * F, T, -1])
        t_gru0_out, _ = self.t_gru0(t_gru0_in)
        t_gru0_out    =  torch.reshape(t_gru0_out, [B, F, T, -1])
        norm0_in  = t_gru0_out.permute(0, 3, 1, 2)
        norm0_out = self.norm0(norm0_in) 
        activation0_out = self.activation0(norm0_out) #B,C,F,T

        # Round 1: same pattern, projecting back to the input channel count.
        x = activation0_out.permute(0, 3, 2, 1)  # B,T,F,C
        input1 = torch.reshape(x, [B * T, F, -1])
        f_gru1_out, _ = self.f_gru1(input1)
        f_gru1_out    =  torch.reshape(f_gru1_out, [B, T, F, -1])
        t_gru1_in  = f_gru1_out.permute(0, 2, 1, 3) 
        t_gru1_in  = torch.reshape(t_gru1_in, [B * F, T, -1])
        t_gru1_out, _ = self.t_gru1(t_gru1_in)
        t_gru1_out    =  torch.reshape(t_gru1_out, [B, F, T, C])
        norm1_in = t_gru1_out.permute(0, 3, 1, 2)
        norm1_out = self.norm1(norm1_in) 
        activation1_out = self.activation1(norm1_out)
        return activation1_out

class VAD(nn.Module):
    """Voice-activity-detection head on the coarse bottleneck features.

    Squeezes the (B, 32, F, T) FTGRU output to 16 channels with a 1x1
    CausalConv, summarises the frequency axis with the final hidden
    state of a bidirectional GRU, then maps each frame to 2 logits via
    1-D convolutions.  Output: (B, T, 2).
    """

    def __init__(self):
        super(VAD, self).__init__()
        self.conv0     = CausalConv(32,16, (1,1),(1,1))
        self.f_gru     = torch.nn.GRU(input_size=16, hidden_size= 8, bidirectional=True, batch_first=True)
        self.conv1d_0 = nn.Conv1d(16,16,kernel_size=1,stride=1, bias=False)
        self.norm_0 = nn.BatchNorm1d(16)
        self.activation_0 = nn.PReLU()      
        self.conv1d_1 = nn.Conv1d(16,2,kernel_size=1,stride=1, bias=False)

    def forward(self, x):
        # x: (B, 32, F, T) bottleneck features.
        conv0_out = self.conv0(x)
        B, C, F, T = conv0_out.size()
        conv0_out = conv0_out.permute(0, 3, 2, 1)  # B,T,F,C
        # Each frame's frequency axis becomes one GRU sequence.
        intra_in = torch.reshape(conv0_out, [B * T, F, C])   
        # Only the final hidden state is kept: (2 directions, B*T, 8).
        _, f_gru_h = self.f_gru(intra_in)
        h_0, _, h_2 = f_gru_h.size()
        f_gru_h =  torch.reshape(f_gru_h.permute(1,0,2),[B,T,h_0,h_2])
        # Concatenate both directions -> (B, 16, T) for the 1-D convs.
        f_gru_h = torch.reshape(f_gru_h,[B,T,-1]).permute(0,2,1)
        conv1d_0_out = self.conv1d_0(f_gru_h) 
        norm_out = self.norm_0(conv1d_0_out)
        activation_out = self.activation_0(norm_out)
        conv1d_1_out = self.conv1d_1(activation_out).permute(0,2,1) #B,T,2 
        return conv1d_1_out

class TSPNN(nn.Module):
    """Two-Stage Progressive Neural Network for acoustic echo cancellation.

    Stage 1 (coarse): a gated encoder/decoder U-Net with an FTGRU
    bottleneck predicts a complex mask that is applied to the mic
    spectrum; a VAD head shares the bottleneck features.
    Stage 2 (fine): a second U-Net, fed with the compressed magnitudes
    of ref / coarse output / mic, predicts a complex mask that is used
    as deep-filter coefficients to refine the coarse complex spectrum.

    STFT: 320-sample window, 160-sample hop, 320-point FFT -> 161 bins.
    """

    def __init__(self):
        super(TSPNN, self).__init__()
        self.stft  = ConvSTFT(320, 160, 320,'hann', 'complex')
        self.istft = ConviSTFT(320, 160, 320,'hann', 'complex')

        # ---- coarse-stage encoder (161 -> 6 frequency bins) ----
        self.coarse_encoder_conv0 = CausalConv(2,16, (5,1),(1,1))
        self.coarse_encoder_conv1 = CausalConv(16,16,(1,5),(1,1))
        self.coarse_encoder_conv2 = CausalConv(16,16,(6,5),(2,1))
        self.coarse_encoder_conv3 = CausalConv(16,32,(4,3),(2,1))
        self.coarse_encoder_conv4 = CausalConv(32,32,(6,5),(2,1))
        self.coarse_encoder_conv5 = CausalConv(32,32,(5,3),(2,1))
        self.coarse_encoder_conv6 = CausalConv(32,32,(3,5),(2,1))
        self.coarse_encoder_conv7 = CausalConv(32,32,(3,3),(1,1))
        self.ftgru_coarse = FTGRU(64,64,32)
        self.vad = VAD()
        # ---- coarse-stage decoder (mirrors the encoder) ----
        self.coarse_decoder_conv0 = CausalTransConv(32,32, (3,3),(1,1))
        self.coarse_decoder_conv1 = CausalTransConv(32,32, (3,5),(2,1))
        self.coarse_decoder_conv2 = CausalTransConv(32,32, (5,3),(2,1))
        self.coarse_decoder_conv3 = CausalTransConv(32,32, (6,5),(2,1),(1,0))
        self.coarse_decoder_conv4 = CausalTransConv(32,16, (4,3),(2,1),(1,0))
        self.coarse_decoder_conv5 = CausalTransConv(16,16, (6,5),(2,1),(1,0))
        self.coarse_decoder_conv6 = CausalTransConv(16,16, (1,5),(1,1))
        self.coarse_decoder_conv7 = CausalTransConv(16,2,  (5,1),(1,1))

        # 1x1 gate convs over [decoder, skip] concatenations.
        self.coarse_decoder_gate_conv0 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv1 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv2 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv3 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv4 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv5 = nn.Conv2d(32,16, (1,1),(1,1))
        self.coarse_decoder_gate_conv6 = nn.Conv2d(32,16, (1,1),(1,1))
        self.coarse_decoder_gate_conv7 = nn.Conv2d(32,16, (1,1),(1,1))
        # Final projection to a (real, imag) complex mask per frame.
        self.coarse_dense = nn.Linear(in_features=161*2, out_features=161*2)

        # ---- fine-stage encoder (3 inputs: ref / coarse / mic magnitudes) ----
        self.fine_encoder_conv0   = CausalConv(3,16, (5,1),(1,1))
        self.fine_encoder_conv1   = CausalConv(16,16,(1,5),(1,1))
        self.fine_encoder_conv2   = CausalConv(16,32,(6,5),(2,1))
        self.fine_encoder_conv3   = CausalConv(32,32,(4,3),(2,1))
        self.fine_encoder_conv4   = CausalConv(32,64,(6,5),(2,1))
        self.fine_encoder_conv5   = CausalConv(64,64,(5,3),(2,1))
        self.fine_encoder_conv6   = CausalConv(64,64,(3,5),(2,1))
        self.fine_encoder_conv7   = CausalConv(64,64,(3,3),(1,1))

        self.ftgru_fine = FTGRU(128,128,64)

        self.fine_decoder_conv0 = CausalTransConv(64,64, (3,3),(1,1))
        self.fine_decoder_conv1 = CausalTransConv(64,64, (3,5),(2,1))
        self.fine_decoder_conv2 = CausalTransConv(64,64, (5,3),(2,1))
        self.fine_decoder_conv3 = CausalTransConv(64,32, (6,5),(2,1),(1,0))
        self.fine_decoder_conv4 = CausalTransConv(32,32, (4,3),(2,1),(1,0))
        self.fine_decoder_conv5 = CausalTransConv(32,16, (6,5),(2,1),(1,0))
        self.fine_decoder_conv6 = CausalTransConv(16,16, (1,5),(1,1))
        self.fine_decoder_conv7 = CausalTransConv(16,2,  (5,1),(1,1))

        self.fine_decoder_gate_conv0 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv1 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv2 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv3 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv4 = nn.Conv2d(64,32,  (1,1),(1,1))
        self.fine_decoder_gate_conv5 = nn.Conv2d(64,32,  (1,1),(1,1))
        self.fine_decoder_gate_conv6 = nn.Conv2d(32,16,  (1,1),(1,1))
        self.fine_decoder_gate_conv7 = nn.Conv2d(32,16,  (1,1),(1,1))
        self.fine_dense = nn.Linear(in_features=161*2, out_features=161*2)
        # 3 frequency neighbours, 3 past frames, 1 look-ahead frame.
        self.df = DeepFilter(3,3,1)

    def forward(self, mic, ref):
        """mic, ref: (B, samples) time-domain microphone / far-end signals."""
        return self.input_forward(mic, ref)

    def input_forward(self, mic, ref):
        # STFT and split into real/imag halves (161 = 320 // 2 + 1 bins).
        stft_mic = self.stft(mic)
        real_mic = stft_mic[:, :161]
        imag_mic = stft_mic[:, 161:]
        spec_mags_mic = torch.sqrt(real_mic ** 2 + imag_mic ** 2 + 1e-8)
        # Power-law compression (|X|^0.3) flattens the dynamic range.
        compressed_mags_mic = torch.pow(spec_mags_mic, 0.3)

        stft_ref = self.stft(ref)
        real_ref = stft_ref[:, :161]
        imag_ref = stft_ref[:, 161:]
        compressed_mags_ref = torch.pow(real_ref**2 + imag_ref**2 + 1e-8, 0.3*0.5)


        # ---- coarse stage: gated U-Net on [ref, mic] magnitudes ----
        coarse_spec_mags = torch.stack([compressed_mags_ref, compressed_mags_mic], dim=1)  #(B,2,161,T)
        coarse_encoder_conv_0  = self.coarse_encoder_conv0(coarse_spec_mags)               #(B,16,161,T)
        coarse_encoder_conv_1  = self.coarse_encoder_conv1(coarse_encoder_conv_0)   #(B,16,161,T)
        coarse_encoder_conv_2  = self.coarse_encoder_conv2(coarse_encoder_conv_1)   #(B,16,81,T)
        coarse_encoder_conv_3  = self.coarse_encoder_conv3(coarse_encoder_conv_2)   #(B,32,41,T)
        coarse_encoder_conv_4  = self.coarse_encoder_conv4(coarse_encoder_conv_3)   #(B,32,21,T)
        coarse_encoder_conv_5  = self.coarse_encoder_conv5(coarse_encoder_conv_4)   #(B,32,11,T)
        coarse_encoder_conv_6  = self.coarse_encoder_conv6(coarse_encoder_conv_5)   #(B,32,6,T)
        coarse_encoder_conv_7  = self.coarse_encoder_conv7(coarse_encoder_conv_6)   #(B,32,6,T)

        ftgru_coarse_out = self.ftgru_coarse(coarse_encoder_conv_7)

        # VAD shares the coarse bottleneck features.
        vad_out = self.vad(ftgru_coarse_out)

        # Each decoder step: tanh gate over [decoder state, encoder skip],
        # applied multiplicatively before the transposed conv.
        coarse_decoder_gate_conv_0 = torch.cat([ftgru_coarse_out, coarse_encoder_conv_7] , 1)
        coarse_decoder_gate_conv_0 = torch.tanh(self.coarse_decoder_gate_conv0(coarse_decoder_gate_conv_0))
        coarse_decoder_conv_0 = coarse_decoder_gate_conv_0 * ftgru_coarse_out
        coarse_decoder_conv_0 = self.coarse_decoder_conv0(coarse_decoder_conv_0)

        coarse_decoder_gate_conv_1 = torch.cat([coarse_decoder_conv_0, coarse_encoder_conv_6] , 1)
        coarse_decoder_gate_conv_1 = torch.tanh(self.coarse_decoder_gate_conv1(coarse_decoder_gate_conv_1))
        coarse_decoder_conv_1 = coarse_decoder_gate_conv_1 * coarse_decoder_conv_0
        coarse_decoder_conv_1 = self.coarse_decoder_conv1(coarse_decoder_conv_1)

        coarse_decoder_gate_conv_2 = torch.cat([coarse_decoder_conv_1, coarse_encoder_conv_5] , 1)
        coarse_decoder_gate_conv_2 = torch.tanh(self.coarse_decoder_gate_conv2(coarse_decoder_gate_conv_2))
        coarse_decoder_conv_2 = coarse_decoder_gate_conv_2 * coarse_decoder_conv_1
        coarse_decoder_conv_2 = self.coarse_decoder_conv2(coarse_decoder_conv_2)

        coarse_decoder_gate_conv_3 = torch.cat([coarse_decoder_conv_2, coarse_encoder_conv_4] , 1)
        coarse_decoder_gate_conv_3 = torch.tanh(self.coarse_decoder_gate_conv3(coarse_decoder_gate_conv_3))
        coarse_decoder_conv_3 = coarse_decoder_gate_conv_3 * coarse_decoder_conv_2
        coarse_decoder_conv_3 = self.coarse_decoder_conv3(coarse_decoder_conv_3)

        coarse_decoder_gate_conv_4 = torch.cat([coarse_decoder_conv_3, coarse_encoder_conv_3] , 1)
        coarse_decoder_gate_conv_4 = torch.tanh(self.coarse_decoder_gate_conv4(coarse_decoder_gate_conv_4))
        coarse_decoder_conv_4 = coarse_decoder_gate_conv_4 * coarse_decoder_conv_3
        coarse_decoder_conv_4 = self.coarse_decoder_conv4(coarse_decoder_conv_4)

        coarse_decoder_gate_conv_5 = torch.cat([coarse_decoder_conv_4, coarse_encoder_conv_2] , 1)
        coarse_decoder_gate_conv_5 = torch.tanh(self.coarse_decoder_gate_conv5(coarse_decoder_gate_conv_5))
        coarse_decoder_conv_5 = coarse_decoder_gate_conv_5 * coarse_decoder_conv_4
        coarse_decoder_conv_5 = self.coarse_decoder_conv5(coarse_decoder_conv_5)


        coarse_decoder_gate_conv_6 = torch.cat([coarse_decoder_conv_5, coarse_encoder_conv_1] , 1)
        coarse_decoder_gate_conv_6 = torch.tanh(self.coarse_decoder_gate_conv6(coarse_decoder_gate_conv_6))
        coarse_decoder_conv_6 = coarse_decoder_gate_conv_6 * coarse_decoder_conv_5
        coarse_decoder_conv_6 = self.coarse_decoder_conv6(coarse_decoder_conv_6)

        coarse_decoder_gate_conv_7 = torch.cat([coarse_decoder_conv_6, coarse_encoder_conv_0] , 1)
        coarse_decoder_gate_conv_7 = torch.tanh(self.coarse_decoder_gate_conv7(coarse_decoder_gate_conv_7))
        coarse_decoder_conv_7 = coarse_decoder_gate_conv_7 * coarse_decoder_conv_6
        coarse_decoder_conv_7 = self.coarse_decoder_conv7(coarse_decoder_conv_7)

        # Project per-frame features to a 2*161 complex mask in (0, 1).
        coarse_mask_out = coarse_decoder_conv_7.permute(0,3,1,2)
        B,T,C,D = coarse_mask_out.size()
        coarse_mask_out = torch.reshape(coarse_mask_out,[B,T, -1])
        coarse_mask_out = torch.sigmoid(self.coarse_dense(coarse_mask_out))
        coarse_mask_out = coarse_mask_out.permute(0,2,1)
        real_coarse_mask_out = coarse_mask_out[:, :161]
        imag_coarse_mask_out = coarse_mask_out[:, 161:]

        # Complex mask applied to the mic spectrum (complex multiply).
        coarse_enhanced_real = real_coarse_mask_out * real_mic - imag_coarse_mask_out * imag_mic
        coarse_enhanced_imag = real_coarse_mask_out * imag_mic + imag_coarse_mask_out * real_mic
        spec_mags_coarse_out   = torch.sqrt(coarse_enhanced_real ** 2 + coarse_enhanced_imag ** 2 + 1e-8)
        compressed_coarse_mags = torch.pow(spec_mags_coarse_out, 0.3)
        # ---- fine stage: U-Net on [ref, coarse, mic] magnitudes ----
        fine_spec_mags = torch.stack([compressed_mags_ref, compressed_coarse_mags, compressed_mags_mic], dim=1)

        fine_encoder_conv_0  = self.fine_encoder_conv0(fine_spec_mags)
        fine_encoder_conv_1  = self.fine_encoder_conv1(fine_encoder_conv_0)
        fine_encoder_conv_2  = self.fine_encoder_conv2(fine_encoder_conv_1)
        fine_encoder_conv_3  = self.fine_encoder_conv3(fine_encoder_conv_2)
        fine_encoder_conv_4  = self.fine_encoder_conv4(fine_encoder_conv_3)
        fine_encoder_conv_5  = self.fine_encoder_conv5(fine_encoder_conv_4)
        fine_encoder_conv_6  = self.fine_encoder_conv6(fine_encoder_conv_5)
        fine_encoder_conv_7  = self.fine_encoder_conv7(fine_encoder_conv_6)

        ftgru_fine_out = self.ftgru_fine(fine_encoder_conv_7)

        fine_decoder_gate_conv_0 = torch.cat([ftgru_fine_out, fine_encoder_conv_7] , 1)
        fine_decoder_gate_conv_0 = torch.tanh(self.fine_decoder_gate_conv0(fine_decoder_gate_conv_0))
        fine_decoder_conv_0 = fine_decoder_gate_conv_0 * ftgru_fine_out
        fine_decoder_conv_0 = self.fine_decoder_conv0(fine_decoder_conv_0)

        fine_decoder_gate_conv_1 = torch.cat([fine_decoder_conv_0, fine_encoder_conv_6] , 1)
        fine_decoder_gate_conv_1 = torch.tanh(self.fine_decoder_gate_conv1(fine_decoder_gate_conv_1))
        fine_decoder_conv_1 = fine_decoder_gate_conv_1 * fine_decoder_conv_0
        fine_decoder_conv_1 = self.fine_decoder_conv1(fine_decoder_conv_1)

        fine_decoder_gate_conv_2 = torch.cat([fine_decoder_conv_1, fine_encoder_conv_5] , 1)
        fine_decoder_gate_conv_2 = torch.tanh(self.fine_decoder_gate_conv2(fine_decoder_gate_conv_2))
        fine_decoder_conv_2 = fine_decoder_gate_conv_2 * fine_decoder_conv_1
        fine_decoder_conv_2 = self.fine_decoder_conv2(fine_decoder_conv_2)

        fine_decoder_gate_conv_3 = torch.cat([fine_decoder_conv_2, fine_encoder_conv_4] , 1)
        fine_decoder_gate_conv_3 = torch.tanh(self.fine_decoder_gate_conv3(fine_decoder_gate_conv_3))
        fine_decoder_conv_3 = fine_decoder_gate_conv_3 * fine_decoder_conv_2
        fine_decoder_conv_3 = self.fine_decoder_conv3(fine_decoder_conv_3)

        fine_decoder_gate_conv_4 = torch.cat([fine_decoder_conv_3, fine_encoder_conv_3] , 1)
        fine_decoder_gate_conv_4 = torch.tanh(self.fine_decoder_gate_conv4(fine_decoder_gate_conv_4))
        fine_decoder_conv_4 = fine_decoder_gate_conv_4 * fine_decoder_conv_3
        fine_decoder_conv_4 = self.fine_decoder_conv4(fine_decoder_conv_4)

        fine_decoder_gate_conv_5 = torch.cat([fine_decoder_conv_4, fine_encoder_conv_2] , 1)
        fine_decoder_gate_conv_5 = torch.tanh(self.fine_decoder_gate_conv5(fine_decoder_gate_conv_5))
        fine_decoder_conv_5 = fine_decoder_gate_conv_5 * fine_decoder_conv_4
        fine_decoder_conv_5 = self.fine_decoder_conv5(fine_decoder_conv_5)

        fine_decoder_gate_conv_6 = torch.cat([fine_decoder_conv_5, fine_encoder_conv_1] , 1)
        fine_decoder_gate_conv_6 = torch.tanh(self.fine_decoder_gate_conv6(fine_decoder_gate_conv_6))
        fine_decoder_conv_6 = fine_decoder_gate_conv_6 * fine_decoder_conv_5
        fine_decoder_conv_6 = self.fine_decoder_conv6(fine_decoder_conv_6)

        fine_decoder_gate_conv_7 = torch.cat([fine_decoder_conv_6, fine_encoder_conv_0] , 1)
        fine_decoder_gate_conv_7 = torch.tanh(self.fine_decoder_gate_conv7(fine_decoder_gate_conv_7))
        fine_decoder_conv_7 = fine_decoder_gate_conv_7 * fine_decoder_conv_6
        fine_decoder_conv_7 = self.fine_decoder_conv7(fine_decoder_conv_7)

        fine_mask_out = fine_decoder_conv_7.permute(0,3,1,2)
        B,T,C,D = fine_mask_out.size()
        fine_mask_out = torch.reshape(fine_mask_out,[B,T, -1])
        fine_mask_out = torch.sigmoid(self.fine_dense(fine_mask_out))
        fine_mask_out = fine_mask_out.permute(0,2,1)
        real_fine_mask_out = fine_mask_out[:, :161]
        imag_fine_mask_out = fine_mask_out[:, 161:]
        # Deep filter refines the coarse complex spectrum with the fine mask.
        # NOTE(review): per the paper the fine-stage filter should have
        # t_width*f_width taps per bin, not a single complex value that is
        # shift-replicated by DeepFilter — confirm against the reference.
        df_inputs = [coarse_enhanced_real,coarse_enhanced_imag]
        df_mask   = [real_fine_mask_out,   imag_fine_mask_out]
        fine_enhanced_real, fine_enhanced_imag = self.df(df_inputs,df_mask)
        return coarse_enhanced_real, coarse_enhanced_imag, fine_enhanced_real, fine_enhanced_imag, vad_out

    def compute_loss(self, coarse_enhanced_real, coarse_enhanced_imag, fine_enhanced_real, fine_enhanced_imag, pref_vad, clean_real, clean_imag, vad_label):
        """Weighted sum of per-stage MAE losses and the VAD cross-entropy.

        pref_vad: (B*T, 2) VAD logits; vad_label: (B*T,) 0/1 targets.
        Each stage contributes real + imag + magnitude L1 terms.
        """
        clean_spec_mags   = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-8)
        vad_ce = F.cross_entropy(pref_vad, vad_label.long())
        coarse_real_mae  = F.l1_loss(coarse_enhanced_real, clean_real)
        coarse_imag_mae  = F.l1_loss(coarse_enhanced_imag, clean_imag)
        coarse_spec_mags = torch.sqrt(coarse_enhanced_real ** 2 + coarse_enhanced_imag ** 2 + 1e-8)
        coarse_spec_mae  = F.l1_loss(coarse_spec_mags, clean_spec_mags)
        coarse_loss = coarse_real_mae + coarse_imag_mae + coarse_spec_mae
        fine_real_mae  = F.l1_loss(fine_enhanced_real, clean_real)
        fine_imag_mae  = F.l1_loss(fine_enhanced_imag, clean_imag)
        fine_spec_mags = torch.sqrt(fine_enhanced_real ** 2 + fine_enhanced_imag ** 2 + 1e-8)
        fine_spec_mae  = F.l1_loss(fine_spec_mags, clean_spec_mags)
        fine_loss = fine_real_mae + fine_imag_mae + fine_spec_mae
        # BUG FIX: was "loss = 0,3*coarse_loss + ..." — the comma made this
        # a tuple (0, 3*coarse_loss + ...) instead of the intended
        # 0.3/0.7/0.06 weighted scalar loss.
        loss = 0.3*coarse_loss + 0.7*fine_loss + 0.06*vad_ce
        return loss

def get_parameter_number(model):
    """Count the parameters of ``model``.

    Returns:
        {'Total': total element count, 'Trainable': count with requires_grad}
    """
    total = trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}

def print_parameter(model):
    """Print one "name<TAB>shape<TAB>numel" line per state-dict tensor."""
    for name, tensor in model.state_dict().items():
        print('{}\t{}\t{}'.format(name, tensor.shape, tensor.numel()))

if __name__ == '__main__':
    # Smoke test: 2 s of 16 kHz audio, batch of 10, random signals.
    time_len = 16000*2
    batch    = 10
    clean    = torch.randn(batch, time_len)
    mic      = torch.randn(batch, time_len)
    ref      = torch.randn(batch, time_len)
    model    = TSPNN()
    result = get_parameter_number(model)
    # print_parameter(model)
    print('Number of parameter: \n\t total: {:.2f} M, '
          'trainable: {:.2f} M'.format(result['Total'] / 1e6, result['Trainable'] / 1e6))
    coarse_real, coarse_imag, fine_real, fine_imag, pred_vad = model(mic,ref)
    # Build the training target spectrum from the clean signal
    # (161 = 320 // 2 + 1 frequency bins).
    stft  = ConvSTFT(320, 160, 320,'hann', 'complex')
    stft_clean = stft(clean)
    real_clean = stft_clean[:, :161]
    imag_clean = stft_clean[:, 161:]
    B, _, T = real_clean.size()
    # Flatten VAD logits and random 0/1 labels to (B*T, ...) for cross-entropy.
    pred_vad = torch.reshape(pred_vad,[B*T,2])
    vad_label   = torch.reshape(torch.randint(low=0, high=2, size=(B, T)),[B*T,1]).squeeze(1)
    loss = model.compute_loss(coarse_real, coarse_imag, fine_real, fine_imag, pred_vad,real_clean,imag_clean,vad_label)
    print('Hello world!')


commented

楼上实现基本正确,fine stage的mask大小和df部分有点问题

确实,deepfilter有点问题,跑出来的输出不对,就像改变采样率了一样

@c8x1 @shenbuguanni
非常感谢指点,我跑出来的效果也不太好,能帮忙修改一下吗?对这个模型确实非常感兴趣。

commented

@c8x1 @shenbuguanni 非常感谢指点,我跑出来的效果也不太好,能帮忙修改一下吗?对这个模型确实非常感兴趣。

@zuowanbushiwo 你的问题在于filter部分是不用做重复平移操作的。文章里说了,fine-stage的mask输出维度是2(complex) * F * T * (t_width*f_width)。 所以filter size应该是 B × (t_width*f_width) × F × T,这个改完结果就能正确了。

但这个模型的最大问题是coarse-stage并没有那么coarse,因为训练过程没啥干预操作。结果上看stage1 3成的算力干了99%的活,stage2做的事很少,只是让谐波稍微清晰,频点间回声稍微小了一丢丢罢了,我自己训出来的结果,包括看作者提供的结果都是类似的现象。改进点可能是把loss和训练目标做复杂点吧。其次,就实际使用来说抛弃linear filter也有利有弊。

@c8x1 @shenbuguanni 麻烦帮忙看一下 这个对吗?参数对上了,有1.27M , 主要修改了 self.fine_decoder_conv7 的outchannel, 和 DeepFilter 部分 同时在DeepFilter 前加了 chanel_dense 和 freq_dense , 如果不对请 帮忙修改一下。谢谢!


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from  scipy.signal import get_window

def init_kernels(win_len, fft_len, win_type=None, invers=False):
    """Construct windowed (i)STFT convolution kernels.

    Returns a (2 * (fft_len // 2 + 1), 1, win_len) float32 kernel of
    real/imaginary Fourier basis rows (pseudo-inverted when ``invers``)
    and the (1, win_len, 1) float32 window.
    """
    if win_type is None or win_type == 'None':
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)  # **0.5

    # rfft of the identity gives the real-DFT basis, one row per sample.
    basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    kernel = np.concatenate([np.real(basis), np.imag(basis)], 1).T
    if invers:
        # Synthesis path: pseudo-inverse of the analysis basis.
        kernel = np.linalg.pinv(kernel).T
    kernel = (kernel * window)[:, None, :]
    return (torch.from_numpy(kernel.astype(np.float32)),
            torch.from_numpy(window[None, :, None].astype(np.float32)))

class ConvSTFT(nn.Module):
    """STFT as a strided 1-D convolution.

    ``feature_type='complex'`` returns (B, 2*(fft_len//2 + 1), T) with
    [real; imag] stacked on the channel axis; otherwise (mags, phase).
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConvSTFT, self).__init__()
        if fft_len is None:
            # BUG FIX: np.int was removed in NumPy >= 1.24; use builtin int().
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, _ = init_kernels(win_len, self.fft_len, win_type)
        # Fixed DFT basis — registered as a buffer, not a Parameter.
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        # Accept (B, samples) or (B, 1, samples).
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        # Symmetric padding, trimmed again by the matching ConviSTFT.
        inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real ** 2 + imag ** 2)
            phase = torch.atan2(imag, real)
            return mags, phase

class ConviSTFT(nn.Module):
    """Inverse STFT as a transposed 1-D convolution with overlap-add.

    Normalises by the overlap-added squared window and trims the
    symmetric padding that ConvSTFT added.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            # BUG FIX: np.int was removed in NumPy >= 1.24; use builtin int().
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, self.fft_len, win_type, invers=True)
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        # Identity "framing" kernel for overlap-adding the squared window.
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        inputs: (B, fft_len + 2, T) stacked [real; imag] spectrum, or
            magnitudes when ``phase`` is supplied.
        """
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
        # COLA normalisation by the overlap-added squared window.
        t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        # Trim the padding added by the forward STFT.
        outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]
        return outputs

class CausalConv(nn.Module):
    """2-D convolution that is causal along the time (last) axis.

    The frequency axis gets symmetric padding (kernel // 2). The time
    axis is padded by (kernel_T - 1) on both sides inside nn.Conv2d, and
    forward() crops the trailing frames so no future frame leaks in.
    Each conv is followed by BatchNorm2d + PReLU.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(CausalConv, self).__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.left_pad = kernel_size[1] - 1
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                              kernel_size=kernel_size, stride=stride,
                              padding=(kernel_size[0] // 2, self.left_pad))
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()

    def forward(self, x):
        n_frames = x.size(-1)
        # Crop to the original frame count -> causal in time.
        y = self.conv(x)[..., :n_frames]
        y = self.norm(y)
        return self.activation(y)

class CausalTransConv(nn.Module):
    """Transposed 2-D convolution kept causal along the time axis.

    Symmetric (kernel // 2) padding on frequency, none on time; the
    forward pass crops the extra trailing frames produced by the
    transposed conv. Followed by BatchNorm2d + PReLU.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, output_padding=(0, 0)):
        super(CausalTransConv, self).__init__()
        self.trans_conv = nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch,
                                             kernel_size=kernel_size, stride=stride,
                                             padding=(kernel_size[0] // 2, 0),
                                             output_padding=output_padding)
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()

    def forward(self, x):
        n_frames = x.size(-1)
        # Keep only the first T frames -> causal in time.
        y = self.trans_conv(x)[..., :n_frames]
        return self.activation(self.norm(y))

class DeepFilter(nn.Module):
    """Complex-valued deep filtering over a local (freq, time) neighborhood.

    For each TF bin, the output is the complex sum of predicted filter
    taps applied to a (2*F_neighbors+1) x (T_pastframes+T_lookahead+1)
    neighborhood of the input spectrum, gathered with a fixed identity
    conv kernel (one output channel per tap).
    """

    def __init__(self, F_neighbors, T_pastframes, T_lookahead):
        super(DeepFilter, self).__init__()
        self.F_neighbors  = F_neighbors
        self.T_pastframes = T_pastframes
        self.T_lookahead = T_lookahead
        assert(self.T_pastframes >= self.T_lookahead)
        t_width = self.T_pastframes + self.T_lookahead + 1
        f_width = self.F_neighbors * 2 + 1
        # Identity kernel: channel k extracts the k-th neighborhood tap.
        kernel = torch.eye(t_width * f_width)
        self.register_buffer('kernel', torch.reshape(kernel, [t_width * f_width, 1, f_width, t_width]))

    def forward(self, inputs, filters_r, filters_i):
        '''
            inputs is [real, imag]: [ [B,F,T], [B,F,T] ]
            filters_r / filters_i: per-tap filter maps, broadcastable to
            (B, f_width*t_width, F, T).
            Returns (real, imag) of the filtered spectrum, each (B, F, T).
        '''
        # Process real and imag through the same gather conv in one batch.
        chunked_inputs = torch.cat(inputs, 0)[:, None]
        # Extra left padding supplies past context beyond the symmetric
        # conv padding, keeping lookahead limited to T_lookahead frames.
        pad = (self.T_pastframes - self.T_lookahead, 0, 0, 0)
        conv_pad = nn.ConstantPad2d(pad, 0.0)
        chunked_inputs = conv_pad(chunked_inputs)
        chunked_inputs = F.conv2d(
                                    chunked_inputs,
                                    self.kernel,
                                    padding=[self.F_neighbors, self.T_lookahead],
            )
        inputs_r, inputs_i = torch.chunk(chunked_inputs, 2, 0)
        # Complex multiply-accumulate: (r + j*i) * (fr + j*fi).
        outputs_r = inputs_r * filters_r - inputs_i * filters_i
        # BUG FIX: the imaginary part was inputs_r*filters_i + inputs_r*filters_i
        # (the second term repeated the first); correct complex product below.
        outputs_i = inputs_r * filters_i + inputs_i * filters_r
        outputs_r = torch.sum(outputs_r, 1)
        outputs_i = torch.sum(outputs_i, 1)
        return outputs_r, outputs_i


class FTGRU(torch.nn.Module):
    """Two stacked frequency-then-time GRU passes on (B, C, F, T) maps.

    Each pass runs a bidirectional GRU along frequency, then a
    unidirectional GRU along time, then BatchNorm2d + PReLU. The second
    pass maps the channel count back to ``input_ch``, so the block is
    shape-preserving.
    """

    def __init__(self, hidden_ch_F, hidden_ch_T, input_ch):
        super(FTGRU, self).__init__()
        self.input_size = input_ch
        self.hidden_F = hidden_ch_F
        self.hidden_T = hidden_ch_T
        # Pass 0: freq GRU (bidirectional, hidden_F total) + time GRU.
        self.f_gru0 = torch.nn.GRU(input_size=self.input_size,
                                   hidden_size=self.hidden_F // 2,
                                   bidirectional=True, batch_first=True)
        self.t_gru0 = torch.nn.GRU(input_size=self.hidden_F,
                                   hidden_size=self.hidden_T, batch_first=True)
        self.norm0 = nn.BatchNorm2d(self.hidden_T)
        self.activation0 = nn.PReLU()
        # Pass 1: mirrors pass 0 but restores input_ch channels.
        self.f_gru1 = torch.nn.GRU(input_size=self.hidden_T,
                                   hidden_size=self.hidden_F // 2,
                                   bidirectional=True, batch_first=True)
        self.t_gru1 = torch.nn.GRU(input_size=self.hidden_F,
                                   hidden_size=input_ch, batch_first=True)
        self.norm1 = nn.BatchNorm2d(input_ch)
        self.activation1 = nn.PReLU()

    def forward(self, x):
        """x: (B, C, F, T) -> (B, C, F, T)."""
        batch, chans, freqs, frames = x.size()
        # ---- pass 0: GRU along frequency ----
        seq = torch.reshape(x.permute(0, 3, 2, 1), [batch * frames, freqs, chans])
        seq, _ = self.f_gru0(seq)
        seq = torch.reshape(seq, [batch, frames, freqs, -1])
        # ---- pass 0: GRU along time ----
        seq = torch.reshape(seq.permute(0, 2, 1, 3), [batch * freqs, frames, -1])
        seq, _ = self.t_gru0(seq)
        seq = torch.reshape(seq, [batch, freqs, frames, -1])
        seq = self.activation0(self.norm0(seq.permute(0, 3, 1, 2)))  # B,C,F,T
        # ---- pass 1: GRU along frequency ----
        seq = torch.reshape(seq.permute(0, 3, 2, 1), [batch * frames, freqs, -1])
        seq, _ = self.f_gru1(seq)
        seq = torch.reshape(seq, [batch, frames, freqs, -1])
        # ---- pass 1: GRU along time, back to the input channel count ----
        seq = torch.reshape(seq.permute(0, 2, 1, 3), [batch * freqs, frames, -1])
        seq, _ = self.t_gru1(seq)
        seq = torch.reshape(seq, [batch, freqs, frames, chans])
        return self.activation1(self.norm1(seq.permute(0, 3, 1, 2)))

class VAD(nn.Module):
    """Frame-level voice-activity head: (B, 32, F, T) -> (B, T, 2) logits.

    A 1x1 CausalConv reduces channels, a bidirectional GRU summarizes
    each frame's frequency axis via its final hidden states, and two
    pointwise Conv1d layers produce the 2-class logits per frame.
    """

    def __init__(self):
        super(VAD, self).__init__()
        self.conv0 = CausalConv(32, 16, (1, 1), (1, 1))
        self.f_gru = torch.nn.GRU(input_size=16, hidden_size=8,
                                  bidirectional=True, batch_first=True)
        self.conv1d_0 = nn.Conv1d(16, 16, kernel_size=1, stride=1, bias=False)
        self.norm_0 = nn.BatchNorm1d(16)
        self.activation_0 = nn.PReLU()
        self.conv1d_1 = nn.Conv1d(16, 2, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        feats = self.conv0(x)
        batch, chans, n_freq, n_frames = feats.size()
        # Treat every (batch, frame) pair as one sequence over frequency
        # and keep only the GRU's final hidden states as the summary.
        feats = torch.reshape(feats.permute(0, 3, 2, 1), [batch * n_frames, n_freq, chans])
        _, hidden = self.f_gru(feats)                  # (num_dirs, B*T, 8)
        n_dirs, _, n_hid = hidden.size()
        hidden = torch.reshape(hidden.permute(1, 0, 2), [batch, n_frames, n_dirs, n_hid])
        hidden = torch.reshape(hidden, [batch, n_frames, -1]).permute(0, 2, 1)  # B,16,T
        out = self.activation_0(self.norm_0(self.conv1d_0(hidden)))
        return self.conv1d_1(out).permute(0, 2, 1)     # B,T,2

class TSPNN(nn.Module):
    """Two-stage progressive network (TSPNN) for acoustic echo cancellation.

    Stage 1 ("coarse") consumes compressed magnitude spectra of the
    far-end reference and the microphone signal and predicts a complex
    mask applied to the mic spectrum. Stage 2 ("fine") re-encodes the
    coarse output together with the inputs and predicts deep-filter taps
    that refine the coarse spectrum. A small VAD head branches off the
    coarse bottleneck.
    """

    def __init__(self):
        super(TSPNN, self).__init__()
        # 320-sample window / 160-sample hop STFT -> 161 frequency bins.
        self.stft  = ConvSTFT(320, 160, 320,'hann', 'complex')
        self.istft = ConviSTFT(320, 160, 320,'hann', 'complex')

        # ---- coarse-stage encoder (frequency 161 -> 6) ----
        self.coarse_encoder_conv0 = CausalConv(2,16, (5,1),(1,1))
        self.coarse_encoder_conv1 = CausalConv(16,16,(1,5),(1,1))
        self.coarse_encoder_conv2 = CausalConv(16,16,(6,5),(2,1))
        self.coarse_encoder_conv3 = CausalConv(16,32,(4,3),(2,1))
        self.coarse_encoder_conv4 = CausalConv(32,32,(6,5),(2,1))
        self.coarse_encoder_conv5 = CausalConv(32,32,(5,3),(2,1))
        self.coarse_encoder_conv6 = CausalConv(32,32,(3,5),(2,1))
        self.coarse_encoder_conv7 = CausalConv(32,32,(3,3),(1,1))
        self.ftgru_coarse = FTGRU(64,64,32)
        self.vad = VAD()
        # ---- coarse-stage decoder (mirrors the encoder) ----
        self.coarse_decoder_conv0 = CausalTransConv(32,32, (3,3),(1,1))
        self.coarse_decoder_conv1 = CausalTransConv(32,32, (3,5),(2,1))
        self.coarse_decoder_conv2 = CausalTransConv(32,32, (5,3),(2,1))
        self.coarse_decoder_conv3 = CausalTransConv(32,32, (6,5),(2,1),(1,0))
        self.coarse_decoder_conv4 = CausalTransConv(32,16, (4,3),(2,1),(1,0))
        self.coarse_decoder_conv5 = CausalTransConv(16,16, (6,5),(2,1),(1,0))
        self.coarse_decoder_conv6 = CausalTransConv(16,16, (1,5),(1,1))
        self.coarse_decoder_conv7 = CausalTransConv(16,2,  (5,1),(1,1))

        # 1x1 gating convs for the coarse skip connections.
        self.coarse_decoder_gate_conv0 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv1 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv2 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv3 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv4 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv5 = nn.Conv2d(32,16, (1,1),(1,1))
        self.coarse_decoder_gate_conv6 = nn.Conv2d(32,16, (1,1),(1,1))
        self.coarse_decoder_gate_conv7 = nn.Conv2d(32,16, (1,1),(1,1))
        # Final projection to a complex (real+imag) mask over 161 bins.
        self.coarse_dense = nn.Linear(in_features=161*2, out_features=161*2)

        # ---- fine-stage encoder (3 inputs: ref, coarse output, mic) ----
        self.fine_encoder_conv0   = CausalConv(3,16, (5,1),(1,1))
        self.fine_encoder_conv1   = CausalConv(16,16,(1,5),(1,1))
        self.fine_encoder_conv2   = CausalConv(16,32,(6,5),(2,1))
        self.fine_encoder_conv3   = CausalConv(32,32,(4,3),(2,1))
        self.fine_encoder_conv4   = CausalConv(32,64,(6,5),(2,1))
        self.fine_encoder_conv5   = CausalConv(64,64,(5,3),(2,1))
        self.fine_encoder_conv6   = CausalConv(64,64,(3,5),(2,1))
        self.fine_encoder_conv7   = CausalConv(64,64,(3,3),(1,1))

        self.ftgru_fine = FTGRU(128,128,64)

        # ---- fine-stage decoder; final layer emits 70 deep-filter maps ----
        self.fine_decoder_conv0 = CausalTransConv(64,64, (3,3),(1,1))
        self.fine_decoder_conv1 = CausalTransConv(64,64, (3,5),(2,1))
        self.fine_decoder_conv2 = CausalTransConv(64,64, (5,3),(2,1))
        self.fine_decoder_conv3 = CausalTransConv(64,32, (6,5),(2,1),(1,0))
        self.fine_decoder_conv4 = CausalTransConv(32,32, (4,3),(2,1),(1,0))
        self.fine_decoder_conv5 = CausalTransConv(32,16, (6,5),(2,1),(1,0))
        self.fine_decoder_conv6 = CausalTransConv(16,16, (1,5),(1,1))
        self.fine_decoder_conv7 = CausalTransConv(16,70,  (5,1),(1,1))

        self.fine_decoder_gate_conv0 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv1 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv2 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv3 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv4 = nn.Conv2d(64,32,  (1,1),(1,1))
        self.fine_decoder_gate_conv5 = nn.Conv2d(64,32,  (1,1),(1,1))
        self.fine_decoder_gate_conv6 = nn.Conv2d(32,16,  (1,1),(1,1))
        self.fine_decoder_gate_conv7 = nn.Conv2d(32,16,  (1,1),(1,1))
        self.chanel_dense = nn.Linear(in_features=70, out_features=70)
        self.freq_dense = nn.Linear(in_features=161*2, out_features=161*2)
        # Deep filter: 3 frequency neighbors, 3 past frames, 1 lookahead.
        self.df = DeepFilter(3,3,1)

    def forward(self, mic, ref):
        """mic, ref: (B, samples) waveforms. See input_forward."""
        return self.input_forward(mic, ref)

    def input_forward(self, mic, ref):
        """Run both stages.

        Returns (coarse_real, coarse_imag, fine_real, fine_imag, vad_out)
        where the spectra are (B, 161, T) and vad_out is (B, T, 2) logits.
        """
        stft_mic = self.stft(mic)
        real_mic = stft_mic[:, :161]
        imag_mic = stft_mic[:, 161:]
        spec_mags_mic = torch.sqrt(real_mic ** 2 + imag_mic ** 2 + 1e-8)
        # Power-law compression (exponent 0.3) of the magnitude spectra.
        compressed_mags_mic = torch.pow(spec_mags_mic, 0.3)

        stft_ref = self.stft(ref)
        real_ref = stft_ref[:, :161]
        imag_ref = stft_ref[:, 161:]
        compressed_mags_ref = torch.pow(real_ref**2 + imag_ref**2 + 1e-8, 0.3*0.5)

        # ---- coarse encoder ----
        coarse_spec_mags = torch.stack([compressed_mags_ref, compressed_mags_mic], dim=1)  #(B,2,161,T)
        coarse_encoder_conv_0  = self.coarse_encoder_conv0(coarse_spec_mags)               #(B,16,161,T)
        coarse_encoder_conv_1  = self.coarse_encoder_conv1(coarse_encoder_conv_0)          #(B,16,161,T)
        coarse_encoder_conv_2  = self.coarse_encoder_conv2(coarse_encoder_conv_1)          #(B,16,81,T)
        coarse_encoder_conv_3  = self.coarse_encoder_conv3(coarse_encoder_conv_2)          #(B,32,41,T)
        coarse_encoder_conv_4  = self.coarse_encoder_conv4(coarse_encoder_conv_3)          #(B,32,21,T)
        coarse_encoder_conv_5  = self.coarse_encoder_conv5(coarse_encoder_conv_4)          #(B,32,11,T)
        coarse_encoder_conv_6  = self.coarse_encoder_conv6(coarse_encoder_conv_5)          #(B,32,6,T)
        coarse_encoder_conv_7  = self.coarse_encoder_conv7(coarse_encoder_conv_6)          #(B,32,6,T)

        ftgru_coarse_out = self.ftgru_coarse(coarse_encoder_conv_7)

        # VAD branches off the coarse bottleneck features.
        vad_out = self.vad(ftgru_coarse_out)

        # ---- coarse decoder: tanh-gated skip connections at each level ----
        coarse_decoder_gate_conv_0 = torch.cat([ftgru_coarse_out, coarse_encoder_conv_7] , 1)
        coarse_decoder_gate_conv_0 = torch.tanh(self.coarse_decoder_gate_conv0(coarse_decoder_gate_conv_0))
        coarse_decoder_conv_0 = coarse_decoder_gate_conv_0 * ftgru_coarse_out
        coarse_decoder_conv_0 = self.coarse_decoder_conv0(coarse_decoder_conv_0)

        coarse_decoder_gate_conv_1 = torch.cat([coarse_decoder_conv_0, coarse_encoder_conv_6] , 1)
        coarse_decoder_gate_conv_1 = torch.tanh(self.coarse_decoder_gate_conv1(coarse_decoder_gate_conv_1))
        coarse_decoder_conv_1 = coarse_decoder_gate_conv_1 * coarse_decoder_conv_0
        coarse_decoder_conv_1 = self.coarse_decoder_conv1(coarse_decoder_conv_1)

        coarse_decoder_gate_conv_2 = torch.cat([coarse_decoder_conv_1, coarse_encoder_conv_5] , 1)
        coarse_decoder_gate_conv_2 = torch.tanh(self.coarse_decoder_gate_conv2(coarse_decoder_gate_conv_2))
        coarse_decoder_conv_2 = coarse_decoder_gate_conv_2 * coarse_decoder_conv_1
        coarse_decoder_conv_2 = self.coarse_decoder_conv2(coarse_decoder_conv_2)

        coarse_decoder_gate_conv_3 = torch.cat([coarse_decoder_conv_2, coarse_encoder_conv_4] , 1)
        coarse_decoder_gate_conv_3 = torch.tanh(self.coarse_decoder_gate_conv3(coarse_decoder_gate_conv_3))
        coarse_decoder_conv_3 = coarse_decoder_gate_conv_3 * coarse_decoder_conv_2
        coarse_decoder_conv_3 = self.coarse_decoder_conv3(coarse_decoder_conv_3)

        coarse_decoder_gate_conv_4 = torch.cat([coarse_decoder_conv_3, coarse_encoder_conv_3] , 1)
        coarse_decoder_gate_conv_4 = torch.tanh(self.coarse_decoder_gate_conv4(coarse_decoder_gate_conv_4))
        coarse_decoder_conv_4 = coarse_decoder_gate_conv_4 * coarse_decoder_conv_3
        coarse_decoder_conv_4 = self.coarse_decoder_conv4(coarse_decoder_conv_4)

        coarse_decoder_gate_conv_5 = torch.cat([coarse_decoder_conv_4, coarse_encoder_conv_2] , 1)
        coarse_decoder_gate_conv_5 = torch.tanh(self.coarse_decoder_gate_conv5(coarse_decoder_gate_conv_5))
        coarse_decoder_conv_5 = coarse_decoder_gate_conv_5 * coarse_decoder_conv_4
        coarse_decoder_conv_5 = self.coarse_decoder_conv5(coarse_decoder_conv_5)

        coarse_decoder_gate_conv_6 = torch.cat([coarse_decoder_conv_5, coarse_encoder_conv_1] , 1)
        coarse_decoder_gate_conv_6 = torch.tanh(self.coarse_decoder_gate_conv6(coarse_decoder_gate_conv_6))
        coarse_decoder_conv_6 = coarse_decoder_gate_conv_6 * coarse_decoder_conv_5
        coarse_decoder_conv_6 = self.coarse_decoder_conv6(coarse_decoder_conv_6)

        coarse_decoder_gate_conv_7 = torch.cat([coarse_decoder_conv_6, coarse_encoder_conv_0] , 1)
        coarse_decoder_gate_conv_7 = torch.tanh(self.coarse_decoder_gate_conv7(coarse_decoder_gate_conv_7))
        coarse_decoder_conv_7 = coarse_decoder_gate_conv_7 * coarse_decoder_conv_6
        coarse_decoder_conv_7 = self.coarse_decoder_conv7(coarse_decoder_conv_7)

        # Project to a per-frame complex mask and apply it to the mic STFT.
        coarse_mask_out = coarse_decoder_conv_7.permute(0,3,1,2)
        B,T,C,D = coarse_mask_out.size()
        coarse_mask_out = torch.reshape(coarse_mask_out,[B,T, -1])
        coarse_mask_out = torch.sigmoid(self.coarse_dense(coarse_mask_out))
        coarse_mask_out = coarse_mask_out.permute(0,2,1)
        real_coarse_mask_out = coarse_mask_out[:, :161]
        imag_coarse_mask_out = coarse_mask_out[:, 161:]

        # Complex mask multiplication with the mic spectrum.
        coarse_enhanced_real = real_coarse_mask_out * real_mic - imag_coarse_mask_out * imag_mic
        coarse_enhanced_imag = real_coarse_mask_out * imag_mic + imag_coarse_mask_out * real_mic
        spec_mags_coarse_out   = torch.sqrt(coarse_enhanced_real ** 2 + coarse_enhanced_imag ** 2 + 1e-8)
        compressed_coarse_mags = torch.pow(spec_mags_coarse_out, 0.3)
        # ---- fine stage: ref + coarse output + mic as a 3-channel input ----
        fine_spec_mags = torch.stack([compressed_mags_ref, compressed_coarse_mags, compressed_mags_mic], dim=1)

        fine_encoder_conv_0  = self.fine_encoder_conv0(fine_spec_mags)
        fine_encoder_conv_1  = self.fine_encoder_conv1(fine_encoder_conv_0)
        fine_encoder_conv_2  = self.fine_encoder_conv2(fine_encoder_conv_1)
        fine_encoder_conv_3  = self.fine_encoder_conv3(fine_encoder_conv_2)
        fine_encoder_conv_4  = self.fine_encoder_conv4(fine_encoder_conv_3)
        fine_encoder_conv_5  = self.fine_encoder_conv5(fine_encoder_conv_4)
        fine_encoder_conv_6  = self.fine_encoder_conv6(fine_encoder_conv_5)
        fine_encoder_conv_7  = self.fine_encoder_conv7(fine_encoder_conv_6)

        ftgru_fine_out = self.ftgru_fine(fine_encoder_conv_7)

        fine_decoder_gate_conv_0 = torch.cat([ftgru_fine_out, fine_encoder_conv_7] , 1)
        fine_decoder_gate_conv_0 = torch.tanh(self.fine_decoder_gate_conv0(fine_decoder_gate_conv_0))
        fine_decoder_conv_0 = fine_decoder_gate_conv_0 * ftgru_fine_out
        fine_decoder_conv_0 = self.fine_decoder_conv0(fine_decoder_conv_0)

        fine_decoder_gate_conv_1 = torch.cat([fine_decoder_conv_0, fine_encoder_conv_6] , 1)
        fine_decoder_gate_conv_1 = torch.tanh(self.fine_decoder_gate_conv1(fine_decoder_gate_conv_1))
        fine_decoder_conv_1 = fine_decoder_gate_conv_1 * fine_decoder_conv_0
        fine_decoder_conv_1 = self.fine_decoder_conv1(fine_decoder_conv_1)

        fine_decoder_gate_conv_2 = torch.cat([fine_decoder_conv_1, fine_encoder_conv_5] , 1)
        fine_decoder_gate_conv_2 = torch.tanh(self.fine_decoder_gate_conv2(fine_decoder_gate_conv_2))
        fine_decoder_conv_2 = fine_decoder_gate_conv_2 * fine_decoder_conv_1
        fine_decoder_conv_2 = self.fine_decoder_conv2(fine_decoder_conv_2)

        fine_decoder_gate_conv_3 = torch.cat([fine_decoder_conv_2, fine_encoder_conv_4] , 1)
        fine_decoder_gate_conv_3 = torch.tanh(self.fine_decoder_gate_conv3(fine_decoder_gate_conv_3))
        fine_decoder_conv_3 = fine_decoder_gate_conv_3 * fine_decoder_conv_2
        fine_decoder_conv_3 = self.fine_decoder_conv3(fine_decoder_conv_3)

        fine_decoder_gate_conv_4 = torch.cat([fine_decoder_conv_3, fine_encoder_conv_3] , 1)
        fine_decoder_gate_conv_4 = torch.tanh(self.fine_decoder_gate_conv4(fine_decoder_gate_conv_4))
        fine_decoder_conv_4 = fine_decoder_gate_conv_4 * fine_decoder_conv_3
        fine_decoder_conv_4 = self.fine_decoder_conv4(fine_decoder_conv_4)

        fine_decoder_gate_conv_5 = torch.cat([fine_decoder_conv_4, fine_encoder_conv_2] , 1)
        fine_decoder_gate_conv_5 = torch.tanh(self.fine_decoder_gate_conv5(fine_decoder_gate_conv_5))
        fine_decoder_conv_5 = fine_decoder_gate_conv_5 * fine_decoder_conv_4
        fine_decoder_conv_5 = self.fine_decoder_conv5(fine_decoder_conv_5)

        fine_decoder_gate_conv_6 = torch.cat([fine_decoder_conv_5, fine_encoder_conv_1] , 1)
        fine_decoder_gate_conv_6 = torch.tanh(self.fine_decoder_gate_conv6(fine_decoder_gate_conv_6))
        fine_decoder_conv_6 = fine_decoder_gate_conv_6 * fine_decoder_conv_5
        fine_decoder_conv_6 = self.fine_decoder_conv6(fine_decoder_conv_6)

        fine_decoder_gate_conv_7 = torch.cat([fine_decoder_conv_6, fine_encoder_conv_0] , 1)
        fine_decoder_gate_conv_7 = torch.tanh(self.fine_decoder_gate_conv7(fine_decoder_gate_conv_7))
        fine_decoder_conv_7 = fine_decoder_gate_conv_7 * fine_decoder_conv_6
        fine_decoder_conv_7 = self.fine_decoder_conv7(fine_decoder_conv_7)
        # Mix the 70 decoder channels, then regroup them into 35 deep-filter
        # taps x 2 (real/imag) and project per-tap over the 2*161 freq bins.
        chanel_dense_in = fine_decoder_conv_7.permute(0,3,2,1)
        chanel_dense_in = self.chanel_dense(chanel_dense_in)
        B,T,D,C  = chanel_dense_in.size()
        freq_dense_in   =  chanel_dense_in.reshape(B,T,D,35,2)
        freq_dense_in   =  freq_dense_in.permute(0,3,1,2,4)
        freq_dense_in = freq_dense_in.reshape(B,35,T,-1)
        fine_mask_out = torch.sigmoid(self.freq_dense(freq_dense_in))
        real_fine_mask_out = fine_mask_out[..., :161].permute(0,1,3,2)
        imag_fine_mask_out = fine_mask_out[..., 161:].permute(0,1,3,2)

        # Deep filtering refines the coarse complex spectrum.
        df_inputs = [coarse_enhanced_real,coarse_enhanced_imag]
        fine_enhanced_real, fine_enhanced_imag = self.df(df_inputs,real_fine_mask_out, imag_fine_mask_out)
        return coarse_enhanced_real, coarse_enhanced_imag, fine_enhanced_real, fine_enhanced_imag, vad_out

    def compute_loss(self, coarse_enhanced_real, coarse_enhanced_imag, fine_enhanced_real, fine_enhanced_imag, pref_vad, clean_real, clean_imag, vad_label):
        """Weighted sum of coarse/fine spectral L1 losses and VAD CE.

        pref_vad: (N, 2) VAD logits; vad_label: (N,) integer labels.
        The spectral terms combine real, imag, and magnitude MAE against
        the clean STFT.
        """
        clean_spec_mags   = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-8)
        vad_ce = F.cross_entropy(pref_vad, vad_label.long())
        coarse_real_mae  = F.l1_loss(coarse_enhanced_real, clean_real)
        coarse_imag_mae  = F.l1_loss(coarse_enhanced_imag, clean_imag)
        coarse_spec_mags = torch.sqrt(coarse_enhanced_real ** 2 + coarse_enhanced_imag ** 2 + 1e-8)
        coarse_spec_mae  = F.l1_loss(coarse_spec_mags, clean_spec_mags)
        coarse_loss = coarse_real_mae + coarse_imag_mae + coarse_spec_mae
        fine_real_mae  = F.l1_loss(fine_enhanced_real, clean_real)
        fine_imag_mae  = F.l1_loss(fine_enhanced_imag, clean_imag)
        fine_spec_mags = torch.sqrt(fine_enhanced_real ** 2 + fine_enhanced_imag ** 2 + 1e-8)
        fine_spec_mae  = F.l1_loss(fine_spec_mags, clean_spec_mags)
        fine_loss = fine_real_mae + fine_imag_mae + fine_spec_mae
        # BUG FIX: was "loss = 0,3*coarse_loss + ..." — the comma built a
        # tuple (0, 3*coarse_loss + ...) instead of a weighted scalar sum.
        loss = 0.3*coarse_loss + 0.7*fine_loss + 0.06*vad_ce
        return loss


def get_parameter_number(model):
    """Return total and trainable parameter counts of *model* as a dict."""
    counts = {'Total': 0, 'Trainable': 0}
    for param in model.parameters():
        n = param.numel()
        counts['Total'] += n
        if param.requires_grad:
            counts['Trainable'] += n
    return counts

def print_parameter(model):
    """Print each state-dict entry as: name, shape, element count."""
    for name, tensor in model.state_dict().items():
        print('{}\t{}\t{}'.format(name, tensor.shape, tensor.numel()))


if __name__ == '__main__':
    # Smoke test: run TSPNN on 2 s of random "audio" (batch of 10),
    # report the parameter count, and evaluate the training loss once.
    time_len = 16000*2
    batch    = 10
    clean    = torch.randn(batch, time_len)
    mic      = torch.randn(batch, time_len)
    ref      = torch.randn(batch, time_len)
    model    = TSPNN()
    result = get_parameter_number(model)
    # print_parameter(model)
    print('Number of parameter: \n\t total: {:.2f} M, '
          'trainable: {:.2f} M'.format(result['Total'] / 1e6, result['Trainable'] / 1e6))
    coarse_real, coarse_imag, fine_real, fine_imag, pred_vad = model(mic,ref)
    # Build the clean-signal STFT targets for the loss (161 bins each
    # for real and imaginary parts).
    stft  = ConvSTFT(320, 160, 320,'hann', 'complex')
    stft_clean = stft(clean)
    real_clean = stft_clean[:, :161]
    imag_clean = stft_clean[:, 161:]
    B, _, T = real_clean.size()
    # Flatten VAD logits and random 0/1 labels to (B*T, ...) for
    # cross-entropy inside compute_loss.
    pred_vad = torch.reshape(pred_vad,[B*T,2])
    vad_label   = torch.reshape(torch.randint(low=0, high=2, size=(B, T)),[B*T,1]).squeeze(1)
    loss = model.compute_loss(coarse_real, coarse_imag, fine_real, fine_imag, pred_vad,real_clean,imag_clean,vad_label)
    print('Hello world!')

commented

厉害 👍 不过如果想用作实时推理的话,看起来padding似乎有点问题,也不是因果系统。

有个问题请教大家,预测的complex ideal ratio mask的理论真值范围是(-inf, inf),而模型中使用sigmoid做激活只能建模(0,1),不会引起很大的偏差吗?

有个问题请教大家,预测的complex ideal ratio mask的理论真值范围是(-inf, inf),而模型中使用sigmoid做激活只能建模(0,1),不会引起很大的偏差吗?

@nicriverhoo hi,最近还在关注这篇work吗~

有个问题请教大家,预测的complex ideal ratio mask的理论真值范围是(-inf, inf),而模型中使用sigmoid做激活只能建模(0,1),不会引起很大的偏差吗?

@c8x1 hi,请问你之前复现的结果怎么样,也是使用的sigmoid做激活吗~