Issue with the performance of the ERes2Net_VOX model
Yaodada12 opened this issue · comments
I modified the inference code in speakerlab/bin/infer_sv.py and used the model from https://www.modelscope.cn/models/iic/speech_eres2net_large_sv_en_voxceleb_16k/files. For two test clips that sound very much like the same speaker, the similarity score comes out low. Taking 0_s.wav and 0_t.wav as an example, the highest score I obtained was 0.1775, using the first 3 s of each clip, yet the two voices clearly sound alike, so I am not sure where the problem is. The modified code is below:
```python
import os, sys
sys.path.append(os.getcwd())

import torch
import torchaudio

from speakerlab.process.processor import FBank
from speakerlab.utils.builder import dynamic_import

ERes2Net_VOX = {
    'obj': 'speakerlab.models.eres2net.ERes2Net.ERes2Net',
    'args': {
        'feat_dim': 80,
        'embedding_size': 192,
        'm_channels': 64,
    },
}


class SpeakerVerification:
    def __init__(self, device='cuda:1') -> None:
        self.device = device
        self.feature_extractor = FBank(80, sample_rate=16000, mean_nor=True)
        self.embedding_model = self.load_model()

    def load_model(self):
        pretrained_state = torch.load(
            "speakerlab/pretrained/pretrained_eres2net.ckpt", map_location='cpu')
        # build the model from the config dict and load the checkpoint
        model = ERes2Net_VOX
        embedding_model = dynamic_import(model['obj'])(**model['args'])
        embedding_model.load_state_dict(pretrained_state)
        embedding_model.to(self.device)
        embedding_model.eval()
        return embedding_model

    def compute_embedding(self, wav):
        # compute fbank features
        feat = self.feature_extractor(wav).unsqueeze(0).to(self.device)
        # compute embedding
        with torch.no_grad():
            embedding = self.embedding_model(feat)
        return embedding

    def run(self, wav_file1, wav_file2):
        if isinstance(wav_file1, str) or isinstance(wav_file2, str):
            wav1, _ = torchaudio.load(wav_file1)
            if wav1.shape[0] > 1:
                # keep only the first channel
                wav1 = wav1[0, :].unsqueeze(0)
            wav2, _ = torchaudio.load(wav_file2)
            if wav2.shape[0] > 1:
                wav2 = wav2[0, :].unsqueeze(0)
        else:
            wav1 = wav_file1
            wav2 = wav_file2
        # keep at most the first 3 s (48000 samples at 16 kHz)
        if wav1.shape[1] > 48000:
            wav1 = wav1[:, :48000]
        if wav2.shape[1] > 48000:
            wav2 = wav2[:, :48000]
        # extract embeddings and score them with cosine similarity
        print('[INFO]: Extracting embeddings...')
        embedding1 = self.compute_embedding(wav1)
        embedding2 = self.compute_embedding(wav2)
        similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        scores = similarity(embedding1, embedding2)
        print('[INFO]: The similarity score between two input wavs is %.4f' % scores)
        return scores


def test(wav_file1, wav_file2):
    sv = SpeakerVerification()
    sv.run(wav_file1, wav_file2)


if __name__ == '__main__':
    for i in range(10):
        test('0_s.wav', '0_t.wav')
```
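For context on how to read the 0.1775 score: a standalone sketch of the range `torch.nn.CosineSimilarity` can return, with random tensors standing in for the real 1×192 embeddings. Identical embeddings score 1.0 and opposite ones score -1.0, so a same-speaker pair landing near 0.18 is indeed unexpectedly low.

```python
import torch

# same settings as in the snippet above
cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

# random stand-ins for the real speaker embeddings (shape 1 x 192)
a = torch.randn(1, 192)

same = cos(a, a)        # identical embeddings -> ~1.0
opposite = cos(a, -a)   # opposite-direction embeddings -> ~-1.0
print(float(same), float(opposite))
```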
The model you are using is an academic model, trained and evaluated on VoxCeleb. We suggest switching to one of the general-purpose models instead: ERes2NetV2_COMMON, ERes2Net_COMMON, or CAMPPLUS_COMMON.
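Swapping models in the snippet above should only require replacing the config dict and the checkpoint path. A hypothetical sketch for CAMPPLUS_COMMON; the `obj` import path and the argument values are assumptions and should be verified against your speakerlab version:

```python
# Hypothetical config for CAMPPLUS_COMMON -- verify the import path and
# arguments against the speakerlab source before use.
CAMPPLUS_COMMON = {
    'obj': 'speakerlab.models.campplus.DTDNN.CAMPPlus',
    'args': {
        'feat_dim': 80,
        'embedding_size': 192,
    },
}

# Then, inside load_model():
#   model = CAMPPLUS_COMMON
# and point torch.load() at the matching CAM++ checkpoint.
```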
CAMPPLUS_COMMON
Thanks, the CAMPPLUS_COMMON model works well on English data.
All three models listed above are general-purpose; you can try them in turn and compare.
ERes2NetV2_COMMON
OK. Performance-wise, is ERes2NetV2_COMMON the best of the three in both accuracy and speed?