请问这个模型会开源吗?
KollyYang opened this issue · comments
我实现了一个,但参数总数和作者的有些不一样,我的只有 1.26M 参数。@enhancer12
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.signal import get_window
def init_kernels(win_len, fft_len, win_type=None, invers=False):
    """Build fixed (i)STFT convolution kernels.

    The kernel stacks the real rows of the rFFT basis on top of the imaginary
    rows and multiplies each row by the analysis window, so a strided conv1d
    with it computes an STFT (or, with ``invers=True``, its pseudo-inverse
    synthesis basis for a transposed conv).

    Returns:
        (kernel, window): float32 tensors of shape
        [2*(fft_len//2+1), 1, win_len] and [1, win_len, 1].
    """
    if win_type is None or win_type == 'None':
        window = np.ones(win_len)
    else:
        # fftbins=True yields a periodic window, the right choice for STFT
        window = get_window(win_type, win_len, fftbins=True)  # **0.5
    fourier_basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    kernel = np.concatenate(
        [np.real(fourier_basis), np.imag(fourier_basis)], 1).T
    if invers:
        # pseudo-inverse turns the analysis basis into a synthesis basis
        kernel = np.linalg.pinv(kernel).T
    kernel = (kernel * window)[:, None, :]
    return (torch.from_numpy(kernel.astype(np.float32)),
            torch.from_numpy(window[None, :, None].astype(np.float32)))
class ConvSTFT(nn.Module):
    """STFT implemented as a strided 1-D convolution with a fixed DFT kernel.

    Args:
        win_len: analysis window length in samples.
        win_inc: hop size in samples.
        fft_len: FFT size; defaults to the next power of two >= win_len.
        win_type: scipy window name ('None'/None for rectangular).
        feature_type: 'complex' returns stacked [real; imag] channels,
            anything else returns a (magnitude, phase) pair.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConvSTFT, self).__init__()
        if fft_len is None:
            # np.int was removed in NumPy 1.24; the builtin int is equivalent
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, _ = init_kernels(win_len, self.fft_len, win_type)
        # buffer, not parameter: the DFT kernel is fixed and never trained
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        """inputs: [B, t] or [B, 1, t] waveform -> spectrum [B, fft_len+2, T]."""
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        # symmetric padding so every sample is covered by a full window
        inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real ** 2 + imag ** 2)
            phase = torch.atan2(imag, real)
            return mags, phase
class ConviSTFT(nn.Module):
    """Inverse STFT via transposed 1-D convolution with overlap-add
    window normalization; trims the padding added by ConvSTFT.

    Args mirror ConvSTFT. `forward` accepts either a stacked [real; imag]
    spectrum, or magnitudes plus a separate `phase` tensor.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            # np.int was removed in NumPy 1.24; the builtin int is equivalent
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, self.fft_len, win_type, invers=True)
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        # identity "enframe" kernel used to overlap-add the squared window
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """inputs: [B, fft_len+2, T] spectrum (or magnitudes when `phase` is
        given) -> reconstructed waveform [B, 1, t]."""
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
        # overlap-added squared window normalizes the synthesis amplitude
        t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        # remove the symmetric padding that ConvSTFT added
        outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]
        return outputs
class CausalConv(nn.Module):
    """2-D convolution made causal along time: Conv2d pads (k_t - 1) frames on
    both sides, then the trailing (look-ahead) frames are trimmed off, so the
    output at frame t depends only on frames <= t. Followed by BatchNorm2d
    and PReLU."""

    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(CausalConv, self).__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.left_pad = kernel_size[1] - 1
        # 'same'-style padding on frequency, full causal pad on time
        self.conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size[0] // 2, self.left_pad),
        )
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()

    def forward(self, x):
        n_frames = x.size(-1)
        # drop frames produced by the right padding -> causal in time
        y = self.conv(x)[..., :n_frames]
        return self.activation(self.norm(y))
class CausalTransConv(nn.Module):
    """Transposed 2-D convolution made causal in time by trimming the extra
    trailing frames it produces, followed by BatchNorm2d and PReLU."""

    def __init__(self, in_ch, out_ch, kernel_size, stride, output_padding=(0, 0)):
        super(CausalTransConv, self).__init__()
        self.trans_conv = nn.ConvTranspose2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size[0] // 2, 0),
            output_padding=output_padding,
        )
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()

    def forward(self, x):
        n_frames = x.size(-1)
        # keep only the first T frames so no future frames leak in
        y = self.trans_conv(x)[..., :n_frames]
        return self.activation(self.norm(y))
class DeepFilter(nn.Module):
    """Complex-valued deep filtering: each output bin is the sum of a
    (2*F_neighbors+1) x (T_pastframes+T_lookahead+1) neighborhood of the
    input spectrum weighted by network-predicted complex taps.

    A conv2d with an identity kernel unfolds each bin's neighborhood into the
    channel dimension so the filtering reduces to an elementwise complex
    multiply followed by a channel sum.
    """

    def __init__(self, F_neighbors, T_pastframes, T_lookahead):
        '''
        inputs would be [B,F,T]
        '''
        super(DeepFilter, self).__init__()
        self.F_neighbors = F_neighbors
        self.T_pastframes = T_pastframes
        self.T_lookahead = T_lookahead
        assert self.T_pastframes >= self.T_lookahead
        t_width = self.T_pastframes + self.T_lookahead + 1
        f_width = self.F_neighbors * 2 + 1
        # identity kernel: conv2d with it performs an unfold (im2col)
        kernel = torch.eye(t_width * f_width)
        self.register_buffer('kernel', torch.reshape(kernel, [t_width * f_width, 1, f_width, t_width]))

    def _unfold(self, x):
        # x: [2B,1,F,T] -> [2B, f_width*t_width, F, T] neighborhood channels;
        # the left pad shifts alignment so only T_pastframes of history is used
        pad = (self.T_pastframes - self.T_lookahead, 0, 0, 0)
        x = nn.ConstantPad2d(pad, 0.0)(x)
        return F.conv2d(x, self.kernel, padding=[self.F_neighbors, self.T_lookahead])

    def forward(self, inputs, filters):
        '''
        inputs is [real, imag]: [ [B,F,T], [B,F,T] ]
        filters is [real, imag]: [ [B,F,T], [B,F,T] ]
        returns (real, imag) of the filtered spectrum, each [B,F,T]
        '''
        chunked_inputs = self._unfold(torch.cat(inputs, 0)[:, None])
        inputs_r, inputs_i = torch.chunk(chunked_inputs, 2, 0)
        chunked_filters = self._unfold(torch.cat(filters, 0)[:, None])
        filters_r, filters_i = torch.chunk(chunked_filters, 2, 0)
        # complex multiply (a+bi)(c+di); the original erroneously computed
        # the imaginary part as 2*inputs_r*filters_i
        outputs_r = inputs_r * filters_r - inputs_i * filters_i
        outputs_i = inputs_r * filters_i + inputs_i * filters_r
        # sum over the neighborhood taps
        outputs_r = torch.sum(outputs_r, 1)
        outputs_i = torch.sum(outputs_i, 1)
        return outputs_r, outputs_i
class DPRnn(torch.nn.Module):
    """Dual-path RNN block: a bidirectional GRU runs across frequency within
    each frame (intra path), then a unidirectional GRU runs across time for
    each frequency bin (inter path). Each path has a linear projection,
    LayerNorm, and a residual connection. Input/output: [B, C, F, T]."""

    def __init__(self, hidden_ch_F, hidden_ch_T, F_dim, input_ch):
        super(DPRnn, self).__init__()
        self.F_dim = F_dim
        self.input_size = input_ch
        self.hidden_F = hidden_ch_F
        self.hidden_T = hidden_ch_T
        self.intra_rnn = torch.nn.GRU(input_size=self.input_size, hidden_size=self.hidden_F // 2,
                                      bidirectional=True, batch_first=True)
        self.intra_fc = torch.nn.Linear(in_features=self.hidden_F, out_features=self.input_size)
        self.intra_ln = torch.nn.LayerNorm([F_dim, self.input_size])
        self.inter_rnn = torch.nn.GRU(input_size=self.input_size, hidden_size=self.hidden_T, batch_first=True)
        self.inter_fc = torch.nn.Linear(in_features=self.hidden_T, out_features=self.input_size)
        self.inter_ln = torch.nn.LayerNorm([F_dim, self.input_size])

    def forward(self, x):
        """x: [B, C, F, T] -> [B, C, F, T]."""
        batch, ch, n_freq, n_frames = x.size()
        feat = x.permute(0, 3, 2, 1)  # B,T,F,C
        # intra path: sequence axis = frequency, one sequence per frame
        intra_seq, _ = self.intra_rnn(feat.reshape(batch * n_frames, n_freq, ch))
        intra_proj = self.intra_fc(intra_seq).reshape(batch, n_frames, n_freq, ch)
        intra_out = feat + self.intra_ln(intra_proj)  # residual, B,T,F,C
        # inter path: sequence axis = time, one sequence per frequency bin
        inter_seq = intra_out.permute(0, 2, 1, 3).reshape(batch * n_freq, n_frames, ch)
        inter_seq, _ = self.inter_rnn(inter_seq)
        inter_proj = self.inter_fc(inter_seq).reshape(batch, n_freq, n_frames, ch)
        inter_out = self.inter_ln(inter_proj.permute(0, 2, 1, 3))  # B,T,F,C
        return (intra_out + inter_out).permute(0, 3, 2, 1)
class FTGRU(torch.nn.Module):
    """Two stacked frequency-time GRU stages. Each stage runs a bidirectional
    GRU along the frequency axis, then a unidirectional GRU along time, and
    closes with BatchNorm2d + PReLU. Input and output are [B, C, F, T]."""

    def __init__(self, hidden_ch_F, hidden_ch_T, input_ch):
        super(FTGRU, self).__init__()
        self.input_size = input_ch
        self.hidden_F = hidden_ch_F
        self.hidden_T = hidden_ch_T
        self.f_gru0 = torch.nn.GRU(input_size=self.input_size, hidden_size=self.hidden_F // 2,
                                   bidirectional=True, batch_first=True)
        self.t_gru0 = torch.nn.GRU(input_size=self.hidden_F, hidden_size=self.hidden_T, batch_first=True)
        self.norm0 = nn.BatchNorm2d(self.hidden_T)
        self.activation0 = nn.PReLU()
        self.f_gru1 = torch.nn.GRU(input_size=self.hidden_T, hidden_size=self.hidden_F // 2,
                                   bidirectional=True, batch_first=True)
        self.t_gru1 = torch.nn.GRU(input_size=self.hidden_F, hidden_size=input_ch, batch_first=True)
        self.norm1 = nn.BatchNorm2d(input_ch)
        self.activation1 = nn.PReLU()

    def _stage(self, feat, f_gru, t_gru, norm, act, batch, n_freq, n_frames):
        # feat: [B,T,F,C']; GRU along frequency, then along time
        f_out, _ = f_gru(feat.reshape(batch * n_frames, n_freq, -1))
        f_out = f_out.reshape(batch, n_frames, n_freq, -1).permute(0, 2, 1, 3)
        t_out, _ = t_gru(f_out.reshape(batch * n_freq, n_frames, -1))
        t_out = t_out.reshape(batch, n_freq, n_frames, -1).permute(0, 3, 1, 2)  # B,C',F,T
        return act(norm(t_out))

    def forward(self, x):
        """x: [B, C, F, T] -> [B, C, F, T]."""
        batch, ch, n_freq, n_frames = x.size()
        feat = x.permute(0, 3, 2, 1)  # B,T,F,C
        mid = self._stage(feat, self.f_gru0, self.t_gru0, self.norm0, self.activation0,
                          batch, n_freq, n_frames)
        feat1 = mid.permute(0, 3, 2, 1)  # B,T,F,hidden_T
        return self._stage(feat1, self.f_gru1, self.t_gru1, self.norm1, self.activation1,
                           batch, n_freq, n_frames)
class VAD(nn.Module):
    """Voice-activity head: a 1x1 conv down to 16 channels, a bidirectional
    GRU across frequency whose final hidden states summarize each frame, then
    two bias-free 1x1 Conv1d layers producing per-frame 2-class logits.
    Input [B, 32, F, T] -> output [B, T, 2]."""

    def __init__(self):
        super(VAD, self).__init__()
        self.conv0 = CausalConv(32, 16, (1, 1), (1, 1))
        self.f_gru = torch.nn.GRU(input_size=16, hidden_size=8, bidirectional=True, batch_first=True)
        self.conv1d_0 = nn.Conv1d(16, 16, kernel_size=1, stride=1, bias=False)
        self.norm_0 = nn.BatchNorm1d(16)
        self.activation_0 = nn.PReLU()
        self.conv1d_1 = nn.Conv1d(16, 2, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        feat = self.conv0(x)
        batch, ch, n_freq, n_frames = feat.size()
        # frequency is the GRU sequence axis, one sequence per frame
        seq = feat.permute(0, 3, 2, 1).reshape(batch * n_frames, n_freq, ch)
        _, last_hidden = self.f_gru(seq)  # [2, B*T, 8]: fwd + bwd final states
        n_dir, _, hid = last_hidden.size()
        summary = last_hidden.permute(1, 0, 2).reshape(batch, n_frames, n_dir, hid)
        summary = summary.reshape(batch, n_frames, -1).permute(0, 2, 1)  # B,16,T
        hidden = self.activation_0(self.norm_0(self.conv1d_0(summary)))
        return self.conv1d_1(hidden).permute(0, 2, 1)  # B,T,2
class TSPNN(nn.Module):
    """Two-stage progressive neural network (TSPNN) for acoustic echo cancellation.

    Stage 1 ("coarse") stacks compressed magnitude spectra of the far-end
    reference and the microphone, runs a gated conv encoder/decoder around an
    FTGRU bottleneck, and predicts a complex mask for the mic spectrum; a VAD
    head branches off the bottleneck. Stage 2 ("fine") repeats the scheme on
    [ref, coarse output, mic] magnitudes and applies the predicted complex
    taps to the coarse output with a DeepFilter.
    """

    def __init__(self):
        super(TSPNN, self).__init__()
        # 20 ms window / 10 ms hop (16 kHz) -> 161 frequency bins
        self.stft = ConvSTFT(320, 160, 320, 'hann', 'complex')
        self.istft = ConviSTFT(320, 160, 320, 'hann', 'complex')
        # ----- coarse stage encoder -----
        self.coarse_encoder_conv0 = CausalConv(2, 16, (5, 1), (1, 1))
        self.coarse_encoder_conv1 = CausalConv(16, 16, (1, 5), (1, 1))
        self.coarse_encoder_conv2 = CausalConv(16, 16, (6, 5), (2, 1))
        self.coarse_encoder_conv3 = CausalConv(16, 32, (4, 3), (2, 1))
        self.coarse_encoder_conv4 = CausalConv(32, 32, (6, 5), (2, 1))
        self.coarse_encoder_conv5 = CausalConv(32, 32, (5, 3), (2, 1))
        self.coarse_encoder_conv6 = CausalConv(32, 32, (3, 5), (2, 1))
        self.coarse_encoder_conv7 = CausalConv(32, 32, (3, 3), (1, 1))
        self.ftgru_coarse = FTGRU(64, 64, 32)
        self.vad = VAD()
        # ----- coarse stage decoder (mirrors the encoder) -----
        self.coarse_decoder_conv0 = CausalTransConv(32, 32, (3, 3), (1, 1))
        self.coarse_decoder_conv1 = CausalTransConv(32, 32, (3, 5), (2, 1))
        self.coarse_decoder_conv2 = CausalTransConv(32, 32, (5, 3), (2, 1))
        self.coarse_decoder_conv3 = CausalTransConv(32, 32, (6, 5), (2, 1), (1, 0))
        self.coarse_decoder_conv4 = CausalTransConv(32, 16, (4, 3), (2, 1), (1, 0))
        self.coarse_decoder_conv5 = CausalTransConv(16, 16, (6, 5), (2, 1), (1, 0))
        self.coarse_decoder_conv6 = CausalTransConv(16, 16, (1, 5), (1, 1))
        self.coarse_decoder_conv7 = CausalTransConv(16, 2, (5, 1), (1, 1))
        # 1x1 gate convs on [decoder, skip] concatenations (tanh-gated skips)
        self.coarse_decoder_gate_conv0 = nn.Conv2d(64, 32, (1, 1), (1, 1))
        self.coarse_decoder_gate_conv1 = nn.Conv2d(64, 32, (1, 1), (1, 1))
        self.coarse_decoder_gate_conv2 = nn.Conv2d(64, 32, (1, 1), (1, 1))
        self.coarse_decoder_gate_conv3 = nn.Conv2d(64, 32, (1, 1), (1, 1))
        self.coarse_decoder_gate_conv4 = nn.Conv2d(64, 32, (1, 1), (1, 1))
        self.coarse_decoder_gate_conv5 = nn.Conv2d(32, 16, (1, 1), (1, 1))
        self.coarse_decoder_gate_conv6 = nn.Conv2d(32, 16, (1, 1), (1, 1))
        self.coarse_decoder_gate_conv7 = nn.Conv2d(32, 16, (1, 1), (1, 1))
        self.coarse_dense = nn.Linear(in_features=161 * 2, out_features=161 * 2)
        # ----- fine stage (3 input spectra, wider channels) -----
        self.fine_encoder_conv0 = CausalConv(3, 16, (5, 1), (1, 1))
        self.fine_encoder_conv1 = CausalConv(16, 16, (1, 5), (1, 1))
        self.fine_encoder_conv2 = CausalConv(16, 32, (6, 5), (2, 1))
        self.fine_encoder_conv3 = CausalConv(32, 32, (4, 3), (2, 1))
        self.fine_encoder_conv4 = CausalConv(32, 64, (6, 5), (2, 1))
        self.fine_encoder_conv5 = CausalConv(64, 64, (5, 3), (2, 1))
        self.fine_encoder_conv6 = CausalConv(64, 64, (3, 5), (2, 1))
        self.fine_encoder_conv7 = CausalConv(64, 64, (3, 3), (1, 1))
        self.ftgru_fine = FTGRU(128, 128, 64)
        self.fine_decoder_conv0 = CausalTransConv(64, 64, (3, 3), (1, 1))
        self.fine_decoder_conv1 = CausalTransConv(64, 64, (3, 5), (2, 1))
        self.fine_decoder_conv2 = CausalTransConv(64, 64, (5, 3), (2, 1))
        self.fine_decoder_conv3 = CausalTransConv(64, 32, (6, 5), (2, 1), (1, 0))
        self.fine_decoder_conv4 = CausalTransConv(32, 32, (4, 3), (2, 1), (1, 0))
        self.fine_decoder_conv5 = CausalTransConv(32, 16, (6, 5), (2, 1), (1, 0))
        self.fine_decoder_conv6 = CausalTransConv(16, 16, (1, 5), (1, 1))
        self.fine_decoder_conv7 = CausalTransConv(16, 2, (5, 1), (1, 1))
        self.fine_decoder_gate_conv0 = nn.Conv2d(128, 64, (1, 1), (1, 1))
        self.fine_decoder_gate_conv1 = nn.Conv2d(128, 64, (1, 1), (1, 1))
        self.fine_decoder_gate_conv2 = nn.Conv2d(128, 64, (1, 1), (1, 1))
        self.fine_decoder_gate_conv3 = nn.Conv2d(128, 64, (1, 1), (1, 1))
        self.fine_decoder_gate_conv4 = nn.Conv2d(64, 32, (1, 1), (1, 1))
        self.fine_decoder_gate_conv5 = nn.Conv2d(64, 32, (1, 1), (1, 1))
        self.fine_decoder_gate_conv6 = nn.Conv2d(32, 16, (1, 1), (1, 1))
        self.fine_decoder_gate_conv7 = nn.Conv2d(32, 16, (1, 1), (1, 1))
        self.fine_dense = nn.Linear(in_features=161 * 2, out_features=161 * 2)
        self.df = DeepFilter(3, 3, 1)

    def forward(self, mic, ref):
        """mic/ref: [B, t] waveforms -> (coarse_real, coarse_imag,
        fine_real, fine_imag, vad_logits)."""
        return self.input_forward(mic, ref)

    def input_forward(self, mic, ref):
        # STFT -> stacked [real; imag], 161 bins each (fft 320)
        stft_mic = self.stft(mic)
        real_mic = stft_mic[:, :161]
        imag_mic = stft_mic[:, 161:]
        spec_mags_mic = torch.sqrt(real_mic ** 2 + imag_mic ** 2 + 1e-8)
        # power-law compression of magnitudes (exponent 0.3)
        compressed_mags_mic = torch.pow(spec_mags_mic, 0.3)
        stft_ref = self.stft(ref)
        real_ref = stft_ref[:, :161]
        imag_ref = stft_ref[:, 161:]
        compressed_mags_ref = torch.pow(real_ref ** 2 + imag_ref ** 2 + 1e-8, 0.3 * 0.5)
        # ----- coarse stage -----
        coarse_spec_mags = torch.stack([compressed_mags_ref, compressed_mags_mic], dim=1)  # (B,2,161,T)
        coarse_encoder_conv_0 = self.coarse_encoder_conv0(coarse_spec_mags)       # (B,16,161,T)
        coarse_encoder_conv_1 = self.coarse_encoder_conv1(coarse_encoder_conv_0)  # (B,16,161,T)
        coarse_encoder_conv_2 = self.coarse_encoder_conv2(coarse_encoder_conv_1)  # (B,16,81,T)
        coarse_encoder_conv_3 = self.coarse_encoder_conv3(coarse_encoder_conv_2)  # (B,32,41,T)
        coarse_encoder_conv_4 = self.coarse_encoder_conv4(coarse_encoder_conv_3)  # (B,32,21,T)
        coarse_encoder_conv_5 = self.coarse_encoder_conv5(coarse_encoder_conv_4)  # (B,32,11,T)
        coarse_encoder_conv_6 = self.coarse_encoder_conv6(coarse_encoder_conv_5)  # (B,32,6,T)
        coarse_encoder_conv_7 = self.coarse_encoder_conv7(coarse_encoder_conv_6)  # (B,32,6,T)
        ftgru_coarse_out = self.ftgru_coarse(coarse_encoder_conv_7)
        vad_out = self.vad(ftgru_coarse_out)
        # gated decoder: each step tanh-gates the previous output with the
        # matching encoder skip connection before the transposed conv
        coarse_decoder_gate_conv_0 = torch.cat([ftgru_coarse_out, coarse_encoder_conv_7], 1)
        coarse_decoder_gate_conv_0 = torch.tanh(self.coarse_decoder_gate_conv0(coarse_decoder_gate_conv_0))
        coarse_decoder_conv_0 = coarse_decoder_gate_conv_0 * ftgru_coarse_out
        coarse_decoder_conv_0 = self.coarse_decoder_conv0(coarse_decoder_conv_0)
        coarse_decoder_gate_conv_1 = torch.cat([coarse_decoder_conv_0, coarse_encoder_conv_6], 1)
        coarse_decoder_gate_conv_1 = torch.tanh(self.coarse_decoder_gate_conv1(coarse_decoder_gate_conv_1))
        coarse_decoder_conv_1 = coarse_decoder_gate_conv_1 * coarse_decoder_conv_0
        coarse_decoder_conv_1 = self.coarse_decoder_conv1(coarse_decoder_conv_1)
        coarse_decoder_gate_conv_2 = torch.cat([coarse_decoder_conv_1, coarse_encoder_conv_5], 1)
        coarse_decoder_gate_conv_2 = torch.tanh(self.coarse_decoder_gate_conv2(coarse_decoder_gate_conv_2))
        coarse_decoder_conv_2 = coarse_decoder_gate_conv_2 * coarse_decoder_conv_1
        coarse_decoder_conv_2 = self.coarse_decoder_conv2(coarse_decoder_conv_2)
        coarse_decoder_gate_conv_3 = torch.cat([coarse_decoder_conv_2, coarse_encoder_conv_4], 1)
        coarse_decoder_gate_conv_3 = torch.tanh(self.coarse_decoder_gate_conv3(coarse_decoder_gate_conv_3))
        coarse_decoder_conv_3 = coarse_decoder_gate_conv_3 * coarse_decoder_conv_2
        coarse_decoder_conv_3 = self.coarse_decoder_conv3(coarse_decoder_conv_3)
        coarse_decoder_gate_conv_4 = torch.cat([coarse_decoder_conv_3, coarse_encoder_conv_3], 1)
        coarse_decoder_gate_conv_4 = torch.tanh(self.coarse_decoder_gate_conv4(coarse_decoder_gate_conv_4))
        coarse_decoder_conv_4 = coarse_decoder_gate_conv_4 * coarse_decoder_conv_3
        coarse_decoder_conv_4 = self.coarse_decoder_conv4(coarse_decoder_conv_4)
        coarse_decoder_gate_conv_5 = torch.cat([coarse_decoder_conv_4, coarse_encoder_conv_2], 1)
        coarse_decoder_gate_conv_5 = torch.tanh(self.coarse_decoder_gate_conv5(coarse_decoder_gate_conv_5))
        coarse_decoder_conv_5 = coarse_decoder_gate_conv_5 * coarse_decoder_conv_4
        coarse_decoder_conv_5 = self.coarse_decoder_conv5(coarse_decoder_conv_5)
        coarse_decoder_gate_conv_6 = torch.cat([coarse_decoder_conv_5, coarse_encoder_conv_1], 1)
        coarse_decoder_gate_conv_6 = torch.tanh(self.coarse_decoder_gate_conv6(coarse_decoder_gate_conv_6))
        coarse_decoder_conv_6 = coarse_decoder_gate_conv_6 * coarse_decoder_conv_5
        coarse_decoder_conv_6 = self.coarse_decoder_conv6(coarse_decoder_conv_6)
        coarse_decoder_gate_conv_7 = torch.cat([coarse_decoder_conv_6, coarse_encoder_conv_0], 1)
        coarse_decoder_gate_conv_7 = torch.tanh(self.coarse_decoder_gate_conv7(coarse_decoder_gate_conv_7))
        coarse_decoder_conv_7 = coarse_decoder_gate_conv_7 * coarse_decoder_conv_6
        coarse_decoder_conv_7 = self.coarse_decoder_conv7(coarse_decoder_conv_7)
        # dense layer over stacked real/imag -> sigmoid complex mask
        coarse_mask_out = coarse_decoder_conv_7.permute(0, 3, 1, 2)
        B, T, C, D = coarse_mask_out.size()
        coarse_mask_out = torch.reshape(coarse_mask_out, [B, T, -1])
        coarse_mask_out = torch.sigmoid(self.coarse_dense(coarse_mask_out))
        coarse_mask_out = coarse_mask_out.permute(0, 2, 1)
        real_coarse_mask_out = coarse_mask_out[:, :161]
        imag_coarse_mask_out = coarse_mask_out[:, 161:]
        # complex mask applied to the mic spectrum
        coarse_enhanced_real = real_coarse_mask_out * real_mic - imag_coarse_mask_out * imag_mic
        coarse_enhanced_imag = real_coarse_mask_out * imag_mic + imag_coarse_mask_out * real_mic
        spec_mags_coarse_out = torch.sqrt(coarse_enhanced_real ** 2 + coarse_enhanced_imag ** 2 + 1e-8)
        compressed_coarse_mags = torch.pow(spec_mags_coarse_out, 0.3)
        # ----- fine stage -----
        fine_spec_mags = torch.stack([compressed_mags_ref, compressed_coarse_mags, compressed_mags_mic], dim=1)
        fine_encoder_conv_0 = self.fine_encoder_conv0(fine_spec_mags)
        fine_encoder_conv_1 = self.fine_encoder_conv1(fine_encoder_conv_0)
        fine_encoder_conv_2 = self.fine_encoder_conv2(fine_encoder_conv_1)
        fine_encoder_conv_3 = self.fine_encoder_conv3(fine_encoder_conv_2)
        fine_encoder_conv_4 = self.fine_encoder_conv4(fine_encoder_conv_3)
        fine_encoder_conv_5 = self.fine_encoder_conv5(fine_encoder_conv_4)
        fine_encoder_conv_6 = self.fine_encoder_conv6(fine_encoder_conv_5)
        fine_encoder_conv_7 = self.fine_encoder_conv7(fine_encoder_conv_6)
        ftgru_fine_out = self.ftgru_fine(fine_encoder_conv_7)
        fine_decoder_gate_conv_0 = torch.cat([ftgru_fine_out, fine_encoder_conv_7], 1)
        fine_decoder_gate_conv_0 = torch.tanh(self.fine_decoder_gate_conv0(fine_decoder_gate_conv_0))
        fine_decoder_conv_0 = fine_decoder_gate_conv_0 * ftgru_fine_out
        fine_decoder_conv_0 = self.fine_decoder_conv0(fine_decoder_conv_0)
        fine_decoder_gate_conv_1 = torch.cat([fine_decoder_conv_0, fine_encoder_conv_6], 1)
        fine_decoder_gate_conv_1 = torch.tanh(self.fine_decoder_gate_conv1(fine_decoder_gate_conv_1))
        fine_decoder_conv_1 = fine_decoder_gate_conv_1 * fine_decoder_conv_0
        fine_decoder_conv_1 = self.fine_decoder_conv1(fine_decoder_conv_1)
        fine_decoder_gate_conv_2 = torch.cat([fine_decoder_conv_1, fine_encoder_conv_5], 1)
        fine_decoder_gate_conv_2 = torch.tanh(self.fine_decoder_gate_conv2(fine_decoder_gate_conv_2))
        fine_decoder_conv_2 = fine_decoder_gate_conv_2 * fine_decoder_conv_1
        fine_decoder_conv_2 = self.fine_decoder_conv2(fine_decoder_conv_2)
        fine_decoder_gate_conv_3 = torch.cat([fine_decoder_conv_2, fine_encoder_conv_4], 1)
        fine_decoder_gate_conv_3 = torch.tanh(self.fine_decoder_gate_conv3(fine_decoder_gate_conv_3))
        fine_decoder_conv_3 = fine_decoder_gate_conv_3 * fine_decoder_conv_2
        fine_decoder_conv_3 = self.fine_decoder_conv3(fine_decoder_conv_3)
        fine_decoder_gate_conv_4 = torch.cat([fine_decoder_conv_3, fine_encoder_conv_3], 1)
        fine_decoder_gate_conv_4 = torch.tanh(self.fine_decoder_gate_conv4(fine_decoder_gate_conv_4))
        fine_decoder_conv_4 = fine_decoder_gate_conv_4 * fine_decoder_conv_3
        fine_decoder_conv_4 = self.fine_decoder_conv4(fine_decoder_conv_4)
        fine_decoder_gate_conv_5 = torch.cat([fine_decoder_conv_4, fine_encoder_conv_2], 1)
        fine_decoder_gate_conv_5 = torch.tanh(self.fine_decoder_gate_conv5(fine_decoder_gate_conv_5))
        fine_decoder_conv_5 = fine_decoder_gate_conv_5 * fine_decoder_conv_4
        fine_decoder_conv_5 = self.fine_decoder_conv5(fine_decoder_conv_5)
        fine_decoder_gate_conv_6 = torch.cat([fine_decoder_conv_5, fine_encoder_conv_1], 1)
        fine_decoder_gate_conv_6 = torch.tanh(self.fine_decoder_gate_conv6(fine_decoder_gate_conv_6))
        fine_decoder_conv_6 = fine_decoder_gate_conv_6 * fine_decoder_conv_5
        fine_decoder_conv_6 = self.fine_decoder_conv6(fine_decoder_conv_6)
        fine_decoder_gate_conv_7 = torch.cat([fine_decoder_conv_6, fine_encoder_conv_0], 1)
        fine_decoder_gate_conv_7 = torch.tanh(self.fine_decoder_gate_conv7(fine_decoder_gate_conv_7))
        fine_decoder_conv_7 = fine_decoder_gate_conv_7 * fine_decoder_conv_6
        fine_decoder_conv_7 = self.fine_decoder_conv7(fine_decoder_conv_7)
        fine_mask_out = fine_decoder_conv_7.permute(0, 3, 1, 2)
        B, T, C, D = fine_mask_out.size()
        fine_mask_out = torch.reshape(fine_mask_out, [B, T, -1])
        fine_mask_out = torch.sigmoid(self.fine_dense(fine_mask_out))
        fine_mask_out = fine_mask_out.permute(0, 2, 1)
        real_fine_mask_out = fine_mask_out[:, :161]
        imag_fine_mask_out = fine_mask_out[:, 161:]
        # deep filtering of the coarse output with the fine-stage complex taps
        df_inputs = [coarse_enhanced_real, coarse_enhanced_imag]
        df_mask = [real_fine_mask_out, imag_fine_mask_out]
        fine_enhanced_real, fine_enhanced_imag = self.df(df_inputs, df_mask)
        return coarse_enhanced_real, coarse_enhanced_imag, fine_enhanced_real, fine_enhanced_imag, vad_out

    def compute_loss(self, coarse_enhanced_real, coarse_enhanced_imag, fine_enhanced_real, fine_enhanced_imag, pref_vad, clean_real, clean_imag, vad_label):
        """Weighted sum of coarse/fine complex + magnitude L1 losses and VAD
        cross-entropy; returns a scalar tensor."""
        clean_spec_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-8)
        vad_ce = F.cross_entropy(pref_vad, vad_label.long())
        coarse_real_mae = F.l1_loss(coarse_enhanced_real, clean_real)
        coarse_imag_mae = F.l1_loss(coarse_enhanced_imag, clean_imag)
        coarse_spec_mags = torch.sqrt(coarse_enhanced_real ** 2 + coarse_enhanced_imag ** 2 + 1e-8)
        coarse_spec_mae = F.l1_loss(coarse_spec_mags, clean_spec_mags)
        coarse_loss = coarse_real_mae + coarse_imag_mae + coarse_spec_mae
        fine_real_mae = F.l1_loss(fine_enhanced_real, clean_real)
        fine_imag_mae = F.l1_loss(fine_enhanced_imag, clean_imag)
        fine_spec_mags = torch.sqrt(fine_enhanced_real ** 2 + fine_enhanced_imag ** 2 + 1e-8)
        fine_spec_mae = F.l1_loss(fine_spec_mags, clean_spec_mags)
        fine_loss = fine_real_mae + fine_imag_mae + fine_spec_mae
        # BUG FIX: original read `loss = 0,3*coarse_loss + ...` — the comma
        # built a tuple (0, 3*coarse_loss + ...) instead of weighting by 0.3
        loss = 0.3 * coarse_loss + 0.7 * fine_loss + 0.06 * vad_ce
        return loss
def get_parameter_number(model):
    """Return total and trainable parameter counts of *model* as a dict."""
    counts = {'Total': 0, 'Trainable': 0}
    for param in model.parameters():
        n = param.numel()
        counts['Total'] += n
        if param.requires_grad:
            counts['Trainable'] += n
    return counts
def print_parameter(model):
    """Print one tab-separated line per state-dict entry: name, shape, numel."""
    for name, tensor in model.state_dict().items():
        print('{}\t{}\t{}'.format(name, tensor.shape, tensor.numel()))
if __name__ == '__main__':
    # Smoke test: build the model, report its parameter count, run a forward
    # pass on random audio and evaluate the training loss.
    time_len = 16000*2  # 2 seconds at 16 kHz
    batch = 10
    clean = torch.randn(batch, time_len)
    mic = torch.randn(batch, time_len)
    ref = torch.randn(batch, time_len)
    model = TSPNN()
    result = get_parameter_number(model)
    # print_parameter(model)
    print('Number of parameter: \n\t total: {:.2f} M, '
          'trainable: {:.2f} M'.format(result['Total'] / 1e6, result['Trainable'] / 1e6))
    coarse_real, coarse_imag, fine_real, fine_imag, pred_vad = model(mic,ref)
    # clean-speech target spectrum: first 161 channels real, last 161 imaginary
    stft = ConvSTFT(320, 160, 320,'hann', 'complex')
    stft_clean = stft(clean)
    real_clean = stft_clean[:, :161]
    imag_clean = stft_clean[:, 161:]
    B, _, T = real_clean.size()
    # flatten (batch, time) so VAD logits and random labels line up for CE
    pred_vad = torch.reshape(pred_vad,[B*T,2])
    vad_label = torch.reshape(torch.randint(low=0, high=2, size=(B, T)),[B*T,1]).squeeze(1)
    loss = model.compute_loss(coarse_real, coarse_imag, fine_real, fine_imag, pred_vad,real_clean,imag_clean,vad_label)
    print('Hello world!')
楼上实现基本正确,fine stage的mask大小和df部分有点问题
确实,deepfilter有点问题,跑出来的输出不对,就像改变采样率了一样
@c8x1 @shenbuguanni
非常感谢指点,我跑出来的效果也不太好,能帮忙修改一下吗?对这个模型确实非常感兴趣。
@c8x1 @shenbuguanni 非常感谢指点,我跑出来的效果也不太好,能帮忙修改一下吗?对这个模型确实非常感兴趣。
@zuowanbushiwo 你的问题在于 filter 部分是不用做重复平移操作的。文章里说了,fine-stage 的 mask 输出维度是 2(complex) * F * T * (t_width*f_width)。所以 filter size 应该是 [B, t_width*f_width, F, T],这个改完结果就能正确了。
但这个模型的最大问题是coarse-stage并没有那么coarse,因为训练过程没啥干预操作。结果上看stage1 3成的算力干了99%的活,stage2做的事很少,只是让谐波稍微清晰,频点间回声稍微小了一丢丢罢了,我自己训出来的结果,包括看作者提供的结果都是类似的现象。改进点可能是把loss和训练目标做复杂点吧。其次,就实际使用来说抛弃linear filter也有利有弊。
@c8x1 @shenbuguanni 麻烦帮忙看一下 这个对吗?参数对上了,有1.27M , 主要修改了 self.fine_decoder_conv7 的outchannel, 和 DeepFilter 部分 同时在DeepFilter 前加了 chanel_dense 和 freq_dense , 如果不对请 帮忙修改一下。谢谢!
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.signal import get_window
def init_kernels(win_len, fft_len, win_type=None, invers=False):
    """Build fixed (i)STFT convolution kernels.

    Stacks the real rows of the rFFT basis over the imaginary rows and
    multiplies by the analysis window; with ``invers=True`` the pseudo-inverse
    gives a synthesis basis for a transposed convolution.

    Returns:
        (kernel, window): float32 tensors of shape
        [2*(fft_len//2+1), 1, win_len] and [1, win_len, 1].
    """
    if win_type is None or win_type == 'None':
        window = np.ones(win_len)
    else:
        # fftbins=True yields a periodic window, the right choice for STFT
        window = get_window(win_type, win_len, fftbins=True)  # **0.5
    fourier_basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    kernel = np.concatenate(
        [np.real(fourier_basis), np.imag(fourier_basis)], 1).T
    if invers:
        # pseudo-inverse turns the analysis basis into a synthesis basis
        kernel = np.linalg.pinv(kernel).T
    kernel = (kernel * window)[:, None, :]
    return (torch.from_numpy(kernel.astype(np.float32)),
            torch.from_numpy(window[None, :, None].astype(np.float32)))
class ConvSTFT(nn.Module):
    """STFT implemented as a strided 1-D convolution with a fixed DFT kernel.

    Args:
        win_len: analysis window length in samples.
        win_inc: hop size in samples.
        fft_len: FFT size; defaults to the next power of two >= win_len.
        win_type: scipy window name ('None'/None for rectangular).
        feature_type: 'complex' returns stacked [real; imag] channels,
            anything else returns a (magnitude, phase) pair.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConvSTFT, self).__init__()
        if fft_len is None:
            # np.int was removed in NumPy 1.24; the builtin int is equivalent
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, _ = init_kernels(win_len, self.fft_len, win_type)
        # buffer, not parameter: the DFT kernel is fixed and never trained
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        """inputs: [B, t] or [B, 1, t] waveform -> spectrum [B, fft_len+2, T]."""
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        # symmetric padding so every sample is covered by a full window
        inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real ** 2 + imag ** 2)
            phase = torch.atan2(imag, real)
            return mags, phase
class ConviSTFT(nn.Module):
    """Inverse STFT via transposed 1-D convolution with overlap-add
    window normalization; trims the padding added by ConvSTFT.

    Args mirror ConvSTFT. `forward` accepts either a stacked [real; imag]
    spectrum, or magnitudes plus a separate `phase` tensor.
    """

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            # np.int was removed in NumPy 1.24; the builtin int is equivalent
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, self.fft_len, win_type, invers=True)
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        # identity "enframe" kernel used to overlap-add the squared window
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """inputs: [B, fft_len+2, T] spectrum (or magnitudes when `phase` is
        given) -> reconstructed waveform [B, 1, t]."""
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
        # overlap-added squared window normalizes the synthesis amplitude
        t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        # remove the symmetric padding that ConvSTFT added
        outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]
        return outputs
class CausalConv(nn.Module):
    """2-D convolution made causal along time: Conv2d pads (k_t - 1) frames on
    both sides, then the trailing (look-ahead) frames are trimmed off, so the
    output at frame t depends only on frames <= t. Followed by BatchNorm2d
    and PReLU."""

    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(CausalConv, self).__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.left_pad = kernel_size[1] - 1
        # 'same'-style padding on frequency, full causal pad on time
        self.conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size[0] // 2, self.left_pad),
        )
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()

    def forward(self, x):
        n_frames = x.size(-1)
        # drop frames produced by the right padding -> causal in time
        y = self.conv(x)[..., :n_frames]
        return self.activation(self.norm(y))
class CausalTransConv(nn.Module):
    """Transposed 2-D convolution made causal in time by trimming the extra
    trailing frames it produces, followed by BatchNorm2d and PReLU."""

    def __init__(self, in_ch, out_ch, kernel_size, stride, output_padding=(0, 0)):
        super(CausalTransConv, self).__init__()
        self.trans_conv = nn.ConvTranspose2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size[0] // 2, 0),
            output_padding=output_padding,
        )
        self.norm = nn.BatchNorm2d(out_ch)
        self.activation = nn.PReLU()

    def forward(self, x):
        n_frames = x.size(-1)
        # keep only the first T frames so no future frames leak in
        y = self.trans_conv(x)[..., :n_frames]
        return self.activation(self.norm(y))
class DeepFilter(nn.Module):
    """Complex deep filtering with pre-expanded filter taps: the input spectrum
    is unfolded into its (f_width x t_width) neighborhood channels via an
    identity conv kernel, multiplied elementwise by network-predicted complex
    taps, and summed over the taps."""

    def __init__(self, F_neighbors, T_pastframes, T_lookahead):
        super(DeepFilter, self).__init__()
        self.F_neighbors = F_neighbors
        self.T_pastframes = T_pastframes
        self.T_lookahead = T_lookahead
        assert self.T_pastframes >= self.T_lookahead
        t_width = self.T_pastframes + self.T_lookahead + 1
        f_width = self.F_neighbors * 2 + 1
        # identity kernel: conv2d with it performs an unfold (im2col)
        kernel = torch.eye(t_width * f_width)
        self.register_buffer('kernel', torch.reshape(kernel, [t_width * f_width, 1, f_width, t_width]))

    def forward(self, inputs, filters_r, filters_i):
        '''
        inputs is [real, imag]: [ [B,F,T], [B,F,T] ]
        filters_r / filters_i: [B, f_width*t_width, F, T] complex taps
        returns (real, imag) of the filtered spectrum, each [B,F,T]
        '''
        chunked_inputs = torch.cat(inputs, 0)[:, None]
        # left pad shifts alignment so only T_pastframes of history is used
        pad = (self.T_pastframes - self.T_lookahead, 0, 0, 0)
        conv_pad = nn.ConstantPad2d(pad, 0.0)
        chunked_inputs = conv_pad(chunked_inputs)
        chunked_inputs = F.conv2d(
            chunked_inputs,
            self.kernel,
            padding=[self.F_neighbors, self.T_lookahead],
        )
        inputs_r, inputs_i = torch.chunk(chunked_inputs, 2, 0)
        # complex multiply (a+bi)(c+di); the original erroneously computed
        # the imaginary part as 2*inputs_r*filters_i
        outputs_r = inputs_r * filters_r - inputs_i * filters_i
        outputs_i = inputs_r * filters_i + inputs_i * filters_r
        outputs_r = torch.sum(outputs_r, 1)
        outputs_i = torch.sum(outputs_i, 1)
        return outputs_r, outputs_i
class FTGRU(torch.nn.Module):
    """Two stacked frequency-time GRU stages. Each stage runs a bidirectional
    GRU along the frequency axis, then a unidirectional GRU along time, and
    closes with BatchNorm2d + PReLU. Input and output are [B, C, F, T]."""

    def __init__(self, hidden_ch_F, hidden_ch_T, input_ch):
        super(FTGRU, self).__init__()
        self.input_size = input_ch
        self.hidden_F = hidden_ch_F
        self.hidden_T = hidden_ch_T
        self.f_gru0 = torch.nn.GRU(input_size=self.input_size, hidden_size=self.hidden_F // 2,
                                   bidirectional=True, batch_first=True)
        self.t_gru0 = torch.nn.GRU(input_size=self.hidden_F, hidden_size=self.hidden_T, batch_first=True)
        self.norm0 = nn.BatchNorm2d(self.hidden_T)
        self.activation0 = nn.PReLU()
        self.f_gru1 = torch.nn.GRU(input_size=self.hidden_T, hidden_size=self.hidden_F // 2,
                                   bidirectional=True, batch_first=True)
        self.t_gru1 = torch.nn.GRU(input_size=self.hidden_F, hidden_size=input_ch, batch_first=True)
        self.norm1 = nn.BatchNorm2d(input_ch)
        self.activation1 = nn.PReLU()

    def _stage(self, feat, f_gru, t_gru, norm, act, batch, n_freq, n_frames):
        # feat: [B,T,F,C']; GRU along frequency, then along time
        f_out, _ = f_gru(feat.reshape(batch * n_frames, n_freq, -1))
        f_out = f_out.reshape(batch, n_frames, n_freq, -1).permute(0, 2, 1, 3)
        t_out, _ = t_gru(f_out.reshape(batch * n_freq, n_frames, -1))
        t_out = t_out.reshape(batch, n_freq, n_frames, -1).permute(0, 3, 1, 2)  # B,C',F,T
        return act(norm(t_out))

    def forward(self, x):
        """x: [B, C, F, T] -> [B, C, F, T]."""
        batch, ch, n_freq, n_frames = x.size()
        feat = x.permute(0, 3, 2, 1)  # B,T,F,C
        mid = self._stage(feat, self.f_gru0, self.t_gru0, self.norm0, self.activation0,
                          batch, n_freq, n_frames)
        feat1 = mid.permute(0, 3, 2, 1)  # B,T,F,hidden_T
        return self._stage(feat1, self.f_gru1, self.t_gru1, self.norm1, self.activation1,
                           batch, n_freq, n_frames)
class VAD(nn.Module):
    """Voice-activity head: a 1x1 conv down to 16 channels, a bidirectional
    GRU across frequency whose final hidden states summarize each frame, then
    two bias-free 1x1 Conv1d layers producing per-frame 2-class logits.
    Input [B, 32, F, T] -> output [B, T, 2]."""

    def __init__(self):
        super(VAD, self).__init__()
        self.conv0 = CausalConv(32, 16, (1, 1), (1, 1))
        self.f_gru = torch.nn.GRU(input_size=16, hidden_size=8, bidirectional=True, batch_first=True)
        self.conv1d_0 = nn.Conv1d(16, 16, kernel_size=1, stride=1, bias=False)
        self.norm_0 = nn.BatchNorm1d(16)
        self.activation_0 = nn.PReLU()
        self.conv1d_1 = nn.Conv1d(16, 2, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        feat = self.conv0(x)
        batch, ch, n_freq, n_frames = feat.size()
        # frequency is the GRU sequence axis, one sequence per frame
        seq = feat.permute(0, 3, 2, 1).reshape(batch * n_frames, n_freq, ch)
        _, last_hidden = self.f_gru(seq)  # [2, B*T, 8]: fwd + bwd final states
        n_dir, _, hid = last_hidden.size()
        summary = last_hidden.permute(1, 0, 2).reshape(batch, n_frames, n_dir, hid)
        summary = summary.reshape(batch, n_frames, -1).permute(0, 2, 1)  # B,16,T
        hidden = self.activation_0(self.norm_0(self.conv1d_0(summary)))
        return self.conv1d_1(hidden).permute(0, 2, 1)  # B,T,2
class TSPNN(nn.Module):
    """Two-stage (coarse + fine) speech-enhancement network with a VAD head.

    The coarse stage estimates a complex ratio mask from the compressed
    mic/ref magnitude spectra; the fine stage refines the coarse output
    and produces deep-filter coefficients applied to the coarse-enhanced
    spectrum. Both stages share a gated U-Net-style encoder/decoder with
    an FTGRU bottleneck.
    """

    def __init__(self):
        super(TSPNN, self).__init__()
        # Analysis/synthesis: 320-pt FFT, hop 160, Hann window -> 161 bins.
        self.stft = ConvSTFT(320, 160, 320,'hann', 'complex')
        self.istft = ConviSTFT(320, 160, 320,'hann', 'complex')
        # Coarse-stage encoder (input: ref + mic compressed magnitudes).
        self.coarse_encoder_conv0 = CausalConv(2,16, (5,1),(1,1))
        self.coarse_encoder_conv1 = CausalConv(16,16,(1,5),(1,1))
        self.coarse_encoder_conv2 = CausalConv(16,16,(6,5),(2,1))
        self.coarse_encoder_conv3 = CausalConv(16,32,(4,3),(2,1))
        self.coarse_encoder_conv4 = CausalConv(32,32,(6,5),(2,1))
        self.coarse_encoder_conv5 = CausalConv(32,32,(5,3),(2,1))
        self.coarse_encoder_conv6 = CausalConv(32,32,(3,5),(2,1))
        self.coarse_encoder_conv7 = CausalConv(32,32,(3,3),(1,1))
        self.ftgru_coarse = FTGRU(64,64,32)
        self.vad = VAD()
        # Coarse-stage decoder with 1x1 gate convs on the skip connections.
        self.coarse_decoder_conv0 = CausalTransConv(32,32, (3,3),(1,1))
        self.coarse_decoder_conv1 = CausalTransConv(32,32, (3,5),(2,1))
        self.coarse_decoder_conv2 = CausalTransConv(32,32, (5,3),(2,1))
        self.coarse_decoder_conv3 = CausalTransConv(32,32, (6,5),(2,1),(1,0))
        self.coarse_decoder_conv4 = CausalTransConv(32,16, (4,3),(2,1),(1,0))
        self.coarse_decoder_conv5 = CausalTransConv(16,16, (6,5),(2,1),(1,0))
        self.coarse_decoder_conv6 = CausalTransConv(16,16, (1,5),(1,1))
        self.coarse_decoder_conv7 = CausalTransConv(16,2, (5,1),(1,1))
        self.coarse_decoder_gate_conv0 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv1 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv2 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv3 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv4 = nn.Conv2d(64,32, (1,1),(1,1))
        self.coarse_decoder_gate_conv5 = nn.Conv2d(32,16, (1,1),(1,1))
        self.coarse_decoder_gate_conv6 = nn.Conv2d(32,16, (1,1),(1,1))
        self.coarse_decoder_gate_conv7 = nn.Conv2d(32,16, (1,1),(1,1))
        # Coarse complex mask: 161 real + 161 imaginary components.
        self.coarse_dense = nn.Linear(in_features=161*2, out_features=161*2)
        # Fine-stage encoder (input: ref + coarse-enhanced + mic magnitudes).
        self.fine_encoder_conv0 = CausalConv(3,16, (5,1),(1,1))
        self.fine_encoder_conv1 = CausalConv(16,16,(1,5),(1,1))
        self.fine_encoder_conv2 = CausalConv(16,32,(6,5),(2,1))
        self.fine_encoder_conv3 = CausalConv(32,32,(4,3),(2,1))
        self.fine_encoder_conv4 = CausalConv(32,64,(6,5),(2,1))
        self.fine_encoder_conv5 = CausalConv(64,64,(5,3),(2,1))
        self.fine_encoder_conv6 = CausalConv(64,64,(3,5),(2,1))
        self.fine_encoder_conv7 = CausalConv(64,64,(3,3),(1,1))
        self.ftgru_fine = FTGRU(128,128,64)
        self.fine_decoder_conv0 = CausalTransConv(64,64, (3,3),(1,1))
        self.fine_decoder_conv1 = CausalTransConv(64,64, (3,5),(2,1))
        self.fine_decoder_conv2 = CausalTransConv(64,64, (5,3),(2,1))
        self.fine_decoder_conv3 = CausalTransConv(64,32, (6,5),(2,1),(1,0))
        self.fine_decoder_conv4 = CausalTransConv(32,32, (4,3),(2,1),(1,0))
        self.fine_decoder_conv5 = CausalTransConv(32,16, (6,5),(2,1),(1,0))
        self.fine_decoder_conv6 = CausalTransConv(16,16, (1,5),(1,1))
        self.fine_decoder_conv7 = CausalTransConv(16,70, (5,1),(1,1))
        self.fine_decoder_gate_conv0 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv1 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv2 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv3 = nn.Conv2d(128,64, (1,1),(1,1))
        self.fine_decoder_gate_conv4 = nn.Conv2d(64,32, (1,1),(1,1))
        self.fine_decoder_gate_conv5 = nn.Conv2d(64,32, (1,1),(1,1))
        self.fine_decoder_gate_conv6 = nn.Conv2d(32,16, (1,1),(1,1))
        self.fine_decoder_gate_conv7 = nn.Conv2d(32,16, (1,1),(1,1))
        self.chanel_dense = nn.Linear(in_features=70, out_features=70)
        self.freq_dense = nn.Linear(in_features=161*2, out_features=161*2)
        # NOTE(review): 70 fine-decoder channels reshaped to 35 taps x 2
        # (real/imag) before the deep filter — presumably matching
        # DeepFilter(3,3,1); verify against the DeepFilter definition.
        self.df = DeepFilter(3,3,1)

    def forward(self, mic, ref):
        """Delegate to :meth:`input_forward`."""
        return self.input_forward(mic, ref)

    def input_forward(self, mic, ref):
        """Run the coarse and fine stages on time-domain mic/ref signals.

        :param mic: (B, samples) microphone waveform
        :param ref: (B, samples) far-end reference waveform
        :return: tuple (coarse_enhanced_real, coarse_enhanced_imag,
                 fine_enhanced_real, fine_enhanced_imag, vad_out)
        """
        stft_mic = self.stft(mic)
        real_mic = stft_mic[:, :161]
        imag_mic = stft_mic[:, 161:]
        spec_mags_mic = torch.sqrt(real_mic ** 2 + imag_mic ** 2 + 1e-8)
        # Power-law compression flattens the magnitude dynamic range.
        compressed_mags_mic = torch.pow(spec_mags_mic, 0.3)
        stft_ref = self.stft(ref)
        real_ref = stft_ref[:, :161]
        imag_ref = stft_ref[:, 161:]
        compressed_mags_ref = torch.pow(real_ref**2 + imag_ref**2 + 1e-8, 0.3*0.5)
        coarse_spec_mags = torch.stack([compressed_mags_ref, compressed_mags_mic], dim=1) #(B,2,161,T)
        # --- coarse encoder ---
        coarse_encoder_conv_0 = self.coarse_encoder_conv0(coarse_spec_mags) #(B,16,161,T)
        coarse_encoder_conv_1 = self.coarse_encoder_conv1(coarse_encoder_conv_0) #(B,16,161,T)
        coarse_encoder_conv_2 = self.coarse_encoder_conv2(coarse_encoder_conv_1) #(B,16,81,T)
        coarse_encoder_conv_3 = self.coarse_encoder_conv3(coarse_encoder_conv_2) #(B,32,41,T)
        coarse_encoder_conv_4 = self.coarse_encoder_conv4(coarse_encoder_conv_3) #(B,32,21,T)
        coarse_encoder_conv_5 = self.coarse_encoder_conv5(coarse_encoder_conv_4) #(B,32,11,T)
        coarse_encoder_conv_6 = self.coarse_encoder_conv6(coarse_encoder_conv_5) #(B,32,6,T)
        coarse_encoder_conv_7 = self.coarse_encoder_conv7(coarse_encoder_conv_6) #(B,32,6,T)
        ftgru_coarse_out = self.ftgru_coarse(coarse_encoder_conv_7)
        vad_out = self.vad(ftgru_coarse_out)
        # --- coarse decoder: each step gates the skip connection with a
        # tanh(1x1 conv) of [decoder state, encoder skip] before upsampling.
        coarse_decoder_gate_conv_0 = torch.cat([ftgru_coarse_out, coarse_encoder_conv_7] , 1)
        coarse_decoder_gate_conv_0 = torch.tanh(self.coarse_decoder_gate_conv0(coarse_decoder_gate_conv_0))
        coarse_decoder_conv_0 = coarse_decoder_gate_conv_0 * ftgru_coarse_out
        coarse_decoder_conv_0 = self.coarse_decoder_conv0(coarse_decoder_conv_0)
        coarse_decoder_gate_conv_1 = torch.cat([coarse_decoder_conv_0, coarse_encoder_conv_6] , 1)
        coarse_decoder_gate_conv_1 = torch.tanh(self.coarse_decoder_gate_conv1(coarse_decoder_gate_conv_1))
        coarse_decoder_conv_1 = coarse_decoder_gate_conv_1 * coarse_decoder_conv_0
        coarse_decoder_conv_1 = self.coarse_decoder_conv1(coarse_decoder_conv_1)
        coarse_decoder_gate_conv_2 = torch.cat([coarse_decoder_conv_1, coarse_encoder_conv_5] , 1)
        coarse_decoder_gate_conv_2 = torch.tanh(self.coarse_decoder_gate_conv2(coarse_decoder_gate_conv_2))
        coarse_decoder_conv_2 = coarse_decoder_gate_conv_2 * coarse_decoder_conv_1
        coarse_decoder_conv_2 = self.coarse_decoder_conv2(coarse_decoder_conv_2)
        coarse_decoder_gate_conv_3 = torch.cat([coarse_decoder_conv_2, coarse_encoder_conv_4] , 1)
        coarse_decoder_gate_conv_3 = torch.tanh(self.coarse_decoder_gate_conv3(coarse_decoder_gate_conv_3))
        coarse_decoder_conv_3 = coarse_decoder_gate_conv_3 * coarse_decoder_conv_2
        coarse_decoder_conv_3 = self.coarse_decoder_conv3(coarse_decoder_conv_3)
        coarse_decoder_gate_conv_4 = torch.cat([coarse_decoder_conv_3, coarse_encoder_conv_3] , 1)
        coarse_decoder_gate_conv_4 = torch.tanh(self.coarse_decoder_gate_conv4(coarse_decoder_gate_conv_4))
        coarse_decoder_conv_4 = coarse_decoder_gate_conv_4 * coarse_decoder_conv_3
        coarse_decoder_conv_4 = self.coarse_decoder_conv4(coarse_decoder_conv_4)
        coarse_decoder_gate_conv_5 = torch.cat([coarse_decoder_conv_4, coarse_encoder_conv_2] , 1)
        coarse_decoder_gate_conv_5 = torch.tanh(self.coarse_decoder_gate_conv5(coarse_decoder_gate_conv_5))
        coarse_decoder_conv_5 = coarse_decoder_gate_conv_5 * coarse_decoder_conv_4
        coarse_decoder_conv_5 = self.coarse_decoder_conv5(coarse_decoder_conv_5)
        coarse_decoder_gate_conv_6 = torch.cat([coarse_decoder_conv_5, coarse_encoder_conv_1] , 1)
        coarse_decoder_gate_conv_6 = torch.tanh(self.coarse_decoder_gate_conv6(coarse_decoder_gate_conv_6))
        coarse_decoder_conv_6 = coarse_decoder_gate_conv_6 * coarse_decoder_conv_5
        coarse_decoder_conv_6 = self.coarse_decoder_conv6(coarse_decoder_conv_6)
        coarse_decoder_gate_conv_7 = torch.cat([coarse_decoder_conv_6, coarse_encoder_conv_0] , 1)
        coarse_decoder_gate_conv_7 = torch.tanh(self.coarse_decoder_gate_conv7(coarse_decoder_gate_conv_7))
        coarse_decoder_conv_7 = coarse_decoder_gate_conv_7 * coarse_decoder_conv_6
        coarse_decoder_conv_7 = self.coarse_decoder_conv7(coarse_decoder_conv_7)
        # Coarse complex mask and complex multiplication with the mic STFT.
        coarse_mask_out = coarse_decoder_conv_7.permute(0,3,1,2)
        B,T,C,D = coarse_mask_out.size()
        coarse_mask_out = torch.reshape(coarse_mask_out,[B,T, -1])
        coarse_mask_out = torch.sigmoid(self.coarse_dense(coarse_mask_out))
        coarse_mask_out = coarse_mask_out.permute(0,2,1)
        real_coarse_mask_out = coarse_mask_out[:, :161]
        imag_coarse_mask_out = coarse_mask_out[:, 161:]
        coarse_enhanced_real = real_coarse_mask_out * real_mic - imag_coarse_mask_out * imag_mic
        coarse_enhanced_imag = real_coarse_mask_out * imag_mic + imag_coarse_mask_out * real_mic
        spec_mags_coarse_out = torch.sqrt(coarse_enhanced_real ** 2 + coarse_enhanced_imag ** 2 + 1e-8)
        compressed_coarse_mags = torch.pow(spec_mags_coarse_out, 0.3)
        # --- fine stage: same topology, wider channels, extra input plane.
        fine_spec_mags = torch.stack([compressed_mags_ref, compressed_coarse_mags, compressed_mags_mic], dim=1)
        fine_encoder_conv_0 = self.fine_encoder_conv0(fine_spec_mags)
        fine_encoder_conv_1 = self.fine_encoder_conv1(fine_encoder_conv_0)
        fine_encoder_conv_2 = self.fine_encoder_conv2(fine_encoder_conv_1)
        fine_encoder_conv_3 = self.fine_encoder_conv3(fine_encoder_conv_2)
        fine_encoder_conv_4 = self.fine_encoder_conv4(fine_encoder_conv_3)
        fine_encoder_conv_5 = self.fine_encoder_conv5(fine_encoder_conv_4)
        fine_encoder_conv_6 = self.fine_encoder_conv6(fine_encoder_conv_5)
        fine_encoder_conv_7 = self.fine_encoder_conv7(fine_encoder_conv_6)
        ftgru_fine_out = self.ftgru_fine(fine_encoder_conv_7)
        fine_decoder_gate_conv_0 = torch.cat([ftgru_fine_out, fine_encoder_conv_7] , 1)
        fine_decoder_gate_conv_0 = torch.tanh(self.fine_decoder_gate_conv0(fine_decoder_gate_conv_0))
        fine_decoder_conv_0 = fine_decoder_gate_conv_0 * ftgru_fine_out
        fine_decoder_conv_0 = self.fine_decoder_conv0(fine_decoder_conv_0)
        fine_decoder_gate_conv_1 = torch.cat([fine_decoder_conv_0, fine_encoder_conv_6] , 1)
        fine_decoder_gate_conv_1 = torch.tanh(self.fine_decoder_gate_conv1(fine_decoder_gate_conv_1))
        fine_decoder_conv_1 = fine_decoder_gate_conv_1 * fine_decoder_conv_0
        fine_decoder_conv_1 = self.fine_decoder_conv1(fine_decoder_conv_1)
        fine_decoder_gate_conv_2 = torch.cat([fine_decoder_conv_1, fine_encoder_conv_5] , 1)
        fine_decoder_gate_conv_2 = torch.tanh(self.fine_decoder_gate_conv2(fine_decoder_gate_conv_2))
        fine_decoder_conv_2 = fine_decoder_gate_conv_2 * fine_decoder_conv_1
        fine_decoder_conv_2 = self.fine_decoder_conv2(fine_decoder_conv_2)
        fine_decoder_gate_conv_3 = torch.cat([fine_decoder_conv_2, fine_encoder_conv_4] , 1)
        fine_decoder_gate_conv_3 = torch.tanh(self.fine_decoder_gate_conv3(fine_decoder_gate_conv_3))
        fine_decoder_conv_3 = fine_decoder_gate_conv_3 * fine_decoder_conv_2
        fine_decoder_conv_3 = self.fine_decoder_conv3(fine_decoder_conv_3)
        fine_decoder_gate_conv_4 = torch.cat([fine_decoder_conv_3, fine_encoder_conv_3] , 1)
        fine_decoder_gate_conv_4 = torch.tanh(self.fine_decoder_gate_conv4(fine_decoder_gate_conv_4))
        fine_decoder_conv_4 = fine_decoder_gate_conv_4 * fine_decoder_conv_3
        fine_decoder_conv_4 = self.fine_decoder_conv4(fine_decoder_conv_4)
        fine_decoder_gate_conv_5 = torch.cat([fine_decoder_conv_4, fine_encoder_conv_2] , 1)
        fine_decoder_gate_conv_5 = torch.tanh(self.fine_decoder_gate_conv5(fine_decoder_gate_conv_5))
        fine_decoder_conv_5 = fine_decoder_gate_conv_5 * fine_decoder_conv_4
        fine_decoder_conv_5 = self.fine_decoder_conv5(fine_decoder_conv_5)
        fine_decoder_gate_conv_6 = torch.cat([fine_decoder_conv_5, fine_encoder_conv_1] , 1)
        fine_decoder_gate_conv_6 = torch.tanh(self.fine_decoder_gate_conv6(fine_decoder_gate_conv_6))
        fine_decoder_conv_6 = fine_decoder_gate_conv_6 * fine_decoder_conv_5
        fine_decoder_conv_6 = self.fine_decoder_conv6(fine_decoder_conv_6)
        fine_decoder_gate_conv_7 = torch.cat([fine_decoder_conv_6, fine_encoder_conv_0] , 1)
        fine_decoder_gate_conv_7 = torch.tanh(self.fine_decoder_gate_conv7(fine_decoder_gate_conv_7))
        fine_decoder_conv_7 = fine_decoder_gate_conv_7 * fine_decoder_conv_6
        fine_decoder_conv_7 = self.fine_decoder_conv7(fine_decoder_conv_7)
        # Split the 70 output channels into 35 filter taps x (real, imag)
        # and predict per-tap complex masks for the deep filter.
        chanel_dense_in = fine_decoder_conv_7.permute(0,3,2,1)
        chanel_dense_in = self.chanel_dense(chanel_dense_in)
        B,T,D,C = chanel_dense_in.size()
        freq_dense_in = chanel_dense_in.reshape(B,T,D,35,2)
        freq_dense_in = freq_dense_in.permute(0,3,1,2,4)
        freq_dense_in = freq_dense_in.reshape(B,35,T,-1)
        fine_mask_out = torch.sigmoid(self.freq_dense(freq_dense_in))
        real_fine_mask_out = fine_mask_out[..., :161].permute(0,1,3,2)
        imag_fine_mask_out = fine_mask_out[..., 161:].permute(0,1,3,2)
        df_inputs = [coarse_enhanced_real,coarse_enhanced_imag]
        fine_enhanced_real, fine_enhanced_imag = self.df(df_inputs,real_fine_mask_out, imag_fine_mask_out)
        return coarse_enhanced_real, coarse_enhanced_imag, fine_enhanced_real, fine_enhanced_imag, vad_out

    def compute_loss(self,coarse_enhanced_real, coarse_enhanced_imag, fine_enhanced_real, fine_enhanced_imag, pref_vad, clean_real, clean_imag, vad_label):
        """Weighted sum of coarse/fine spectral MAE terms and VAD CE.

        :param pref_vad: (N, 2) VAD logits; :param vad_label: (N,) targets.
        :return: scalar loss tensor
        """
        clean_spec_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-8)
        vad_ce = F.cross_entropy(pref_vad, vad_label.long())
        coarse_real_mae = F.l1_loss(coarse_enhanced_real, clean_real)
        coarse_imag_mae = F.l1_loss(coarse_enhanced_imag, clean_imag)
        coarse_spec_mags = torch.sqrt(coarse_enhanced_real ** 2 + coarse_enhanced_imag ** 2 + 1e-8)
        coarse_spec_mae = F.l1_loss(coarse_spec_mags, clean_spec_mags)
        coarse_loss = coarse_real_mae + coarse_imag_mae + coarse_spec_mae
        fine_real_mae = F.l1_loss(fine_enhanced_real, clean_real)
        fine_imag_mae = F.l1_loss(fine_enhanced_imag, clean_imag)
        fine_spec_mags = torch.sqrt(fine_enhanced_real ** 2 + fine_enhanced_imag ** 2 + 1e-8)
        fine_spec_mae = F.l1_loss(fine_spec_mags, clean_spec_mags)
        fine_loss = fine_real_mae + fine_imag_mae + fine_spec_mae
        # BUG FIX: was `loss = 0,3*coarse_loss + ...` — the comma made this
        # a tuple (0, 3*coarse_loss + ...) with a 10x-wrong coarse weight.
        loss = 0.3*coarse_loss + 0.7*fine_loss + 0.06*vad_ce
        return loss
def get_parameter_number(model):
    """Return {'Total': n, 'Trainable': m} parameter counts for *model*."""
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}
def print_parameter(model):
    """Print name, shape and element count of every entry in the state dict."""
    for key, value in model.state_dict().items():
        print('{}\t{}\t{}'.format(key, value.shape, value.numel()))
if __name__ == '__main__':
    # Smoke test: build the model, report its size, run one forward pass
    # on 2 s of random 16 kHz audio, and compute the training loss.
    batch = 10
    time_len = 16000*2
    clean = torch.randn(batch, time_len)
    mic = torch.randn(batch, time_len)
    ref = torch.randn(batch, time_len)

    model = TSPNN()
    result = get_parameter_number(model)
    # print_parameter(model)
    print('Number of parameter: \n\t total: {:.2f} M, '
          'trainable: {:.2f} M'.format(result['Total'] / 1e6, result['Trainable'] / 1e6))

    coarse_real, coarse_imag, fine_real, fine_imag, pred_vad = model(mic, ref)

    # Ground-truth spectra for the loss (same STFT configuration as TSPNN).
    stft = ConvSTFT(320, 160, 320,'hann', 'complex')
    stft_clean = stft(clean)
    real_clean = stft_clean[:, :161]
    imag_clean = stft_clean[:, 161:]
    B, _, T = real_clean.size()

    # Flatten VAD logits to (B*T, 2) and draw random 0/1 frame labels.
    pred_vad = pred_vad.reshape(B * T, 2)
    vad_label = torch.randint(low=0, high=2, size=(B, T)).reshape(B * T)
    loss = model.compute_loss(coarse_real, coarse_imag, fine_real, fine_imag,
                              pred_vad, real_clean, imag_clean, vad_label)
    print('Hello world!')
厉害 👍 不过如果想用作实时推理的话,看起来padding似乎有点问题,也不是因果系统。
有个问题请教大家,预测的complex ideal ratio mask的理论真值范围是(-inf, inf),而模型中使用sigmoid做激活只能建模(0,1),不会引起很大的偏差吗?
有个问题请教大家,预测的complex ideal ratio mask的理论真值范围是(-inf, inf),而模型中使用sigmoid做激活只能建模(0,1),不会引起很大的偏差吗?
@nicriverhoo hi,最近还在关注这篇work吗~