clovaai / voxceleb_trainer

In defence of metric learning for speaker recognition

Cannot reproduce the 1.1771 EER (large gap) with the pretrained model from the README

cc98gg opened this issue

The README states:

A larger model trained with online data augmentation, described in [2], can be downloaded from here.

The following script should return: EER 1.1771.

python ./trainSpeakerNet.py --eval --model ResNetSE34V2 --log_input True --encoder_type ASP --n_mels 64 --trainfunc softmaxproto --save_path exps/test --eval_frames 400  --initial_model baseline_v2_ap.model

However, with the pretrained baseline_v2_ap model and the same script, I only get EER 1.7073, MinDCF 0.12625. Is there a mistake in the model or the code?
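For reference, the EER discussed throughout this thread is the operating point at which the false-acceptance rate and the false-rejection rate are equal. The following is a minimal numpy sketch of how it can be computed from trial scores. `compute_eer` is a hypothetical helper, not the repository's own scoring function, which may interpolate the ROC curve differently.

```python
import numpy as np


def compute_eer(scores, labels):
    """Return the equal error rate in percent.

    scores: similarity scores (higher means "same speaker").
    labels: 1 for target (same-speaker) trials, 0 for non-target trials.
    Hypothetical helper: sweeps every observed score as a threshold and
    returns the mean of FAR and FRR where the two are closest.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n_target = np.sum(labels == 1)
    n_nontarget = np.sum(labels == 0)

    best_gap, eer = np.inf, None
    for threshold in np.unique(scores):  # np.unique returns sorted values
        accept = scores >= threshold
        far = np.sum(accept & (labels == 0)) / n_nontarget  # impostors accepted
        frr = np.sum(~accept & (labels == 1)) / n_target    # targets rejected
        if abs(far - frr) < best_gap:
            best_gap, eer = abs(far - frr), (far + frr) / 2
    return 100.0 * eer


print(compute_eer([0.9, 0.4, 0.6, 0.1], [1, 1, 0, 0]))  # 50.0
```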

I also hit the same problem on the master branch, although I successfully reproduced this result a year or two ago.
Comparing the two versions of the code, I found some differences.
First, switch PyTorch to version 1.7.1, since that is the version in which I reproduced the result. Between PyTorch 1.7.1 and 1.10.0, the behavior of torch.nn.functional.pairwise_distance, which is used to compute the scores, has changed.
Second, try adding the following line to the forward method in ResNetSE34V2.py. The older code read the audio with the wavfile module instead of the soundfile module, and wavfile does not normalize the samples (16-bit audio stays in the int16 range), so scaling by 32768 restores the input range the model was trained with.

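    def forward(self, x):
        # soundfile returns samples in [-1, 1]; rescale to the int16 range that wavfile.read returned
        x *= 32768.0
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x = self.torchfb(x)+1e-6
                if self.log_input: x = x.log()
                x = self.instancenorm(x).unsqueeze(1)
        # ... rest of forward() unchanged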
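The scale still matters even though instance normalization follows. The `+1e-6` floor is added before the log, so it swamps the tiny power values of a [-1, 1] signal but is negligible at int16 scale. The toy numpy sketch below illustrates this. Squared samples stand in for the mel power spectrogram, and the helper names are hypothetical.

```python
import numpy as np


def soundfile_style(raw_int16):
    # soundfile.read() (default dtype='float64') scales 16-bit PCM to [-1, 1),
    # which for int16 data amounts to dividing the raw samples by 32768.
    return raw_int16.astype(np.float64) / 32768.0


def log_then_instancenorm(power, eps=1e-6, norm_eps=1e-5):
    # Toy front end: add 1e-6, take the log, then normalize each band (row)
    # over time, as InstanceNorm1d does on a [n_mels, n_frames] input.
    logp = np.log(power + eps)
    mean = logp.mean(axis=1, keepdims=True)
    var = logp.var(axis=1, keepdims=True)
    return (logp - mean) / np.sqrt(var + norm_eps)


rng = np.random.default_rng(0)
# A quiet toy "signal" as raw int16 values (what scipy.io.wavfile.read returns).
raw = rng.integers(-20, 21, size=(4, 50)).astype(np.int16)
x = soundfile_style(raw)

# Squared amplitude stands in for the mel power spectrogram in this toy.
feat_float = log_then_instancenorm(x ** 2)                 # without rescaling
feat_rescaled = log_then_instancenorm((x * 32768.0) ** 2)  # with x *= 32768.0

print(np.array_equal(x * 32768.0, raw))        # True
print(np.allclose(feat_float, feat_rescaled))  # False: the 1e-6 floor dominates
```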

As @chmod740 noted, this appears to be a BC-breaking (backward-incompatible) change that depends on the torch version.

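Starting with torch 1.10.0, torch.nn.functional.pairwise_distance reduces over the last dimension of its (broadcast) inputs. In versions 1.7.0 to 1.9.1 it reduced over dimension 1, which is what turned the [num_eval, nOut, 1] and [1, nOut, num_eval] inputs in the scoring code into a [num_eval, num_eval] matrix of distances between all segment pairs.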
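The following sketch makes the effect visible independently of the installed torch version. It assumes the only change is the dimension the norm is taken over, and it computes both reductions explicitly on the same broadcast tensor instead of calling pairwise_distance.

```python
import torch

num_eval, n_out = 10, 512
ref_feat = torch.randn(num_eval, n_out)
com_feat = torch.randn(num_eval, n_out)

# The scoring line broadcasts [10, 512, 1] against [1, 512, 10].
diff = ref_feat.unsqueeze(-1) - com_feat.unsqueeze(-1).transpose(0, 2)
print(diff.shape)  # torch.Size([10, 512, 10])

# Reducing over dim 1 (the embedding axis) gives every ref/com segment pair.
all_pairs = diff.pow(2).sum(dim=1).sqrt()
print(all_pairs.shape)  # torch.Size([10, 10])

# Reducing over the last dim mixes segments and keeps the embedding axis instead.
last_dim = diff.pow(2).sum(dim=-1).sqrt()
print(last_dim.shape)  # torch.Size([10, 512])
```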

Using the torch.cdist function instead appears to restore the old behavior, for instance:

    #dist = F.pairwise_distance(ref_feat.unsqueeze(-1), com_feat.unsqueeze(-1).transpose(0,2)).detach().cpu().numpy();
    dist = torch.cdist(ref_feat.reshape((num_eval, -1)), com_feat.reshape((num_eval, -1))).detach().cpu().numpy();
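The following sketch checks that the cdist line yields the same all-pairs matrix as the old broadcast-and-reduce-over-dim-1 computation, using random embeddings. One difference to keep in mind is that pairwise_distance adds a small eps (1e-6 by default) to the difference and cdist does not, so scores may differ in the last decimals.

```python
import torch

num_eval, n_out = 10, 512
ref_feat = torch.randn(num_eval, n_out)
com_feat = torch.randn(num_eval, n_out)

# cdist on [num_eval, nOut] inputs returns the [num_eval, num_eval] matrix of
# Euclidean distances between every reference and every comparison segment.
dist_cdist = torch.cdist(ref_feat.reshape((num_eval, -1)), com_feat.reshape((num_eval, -1)))

# The same matrix via explicit broadcasting, reducing over the embedding axis (dim 1).
diff = ref_feat.unsqueeze(-1) - com_feat.unsqueeze(-1).transpose(0, 2)
dist_manual = diff.pow(2).sum(dim=1).sqrt()

print(torch.allclose(dist_cdist, dist_manual, atol=1e-4))  # True
```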

Result using torch.nn.functional.pairwise_distance() in PyTorch 1.10: [screenshot]

Result using torch.cdist() in PyTorch 1.10: [screenshot]

Thank you @msh9184.
I tried this. It works for the VoxCeleb test list, but with another test set it throws the following error:

[screenshot of the error]

However, this test set works correctly with the original line:
    dist = F.pairwise_distance(ref_feat.unsqueeze(-1), com_feat.unsqueeze(-1).transpose(0,2)).detach().cpu().numpy();

Do you have any idea why this happens? @msh9184 and @chmod740

My guess is that num_eval is set to 10, but your embedding has shape [512] rather than [10, 512]. In that case reshape((num_eval, -1)) fails, because 512 elements cannot be split into 10 equal rows.

Could you try either (1) setting --num_eval 1, or (2) changing the cdist input shape, e.g., torch.cdist(ref_feat.unsqueeze(0), com_feat.unsqueeze(0))?

It works correctly as long as the inputs to torch.cdist have the shape [num_eval, nOut] (e.g., [10, 512] or [1, 512]).
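Option (2) generalizes to a small helper that reshapes whatever it receives into [num_segments, nOut] before calling torch.cdist, so both single-segment and multi-segment embeddings work. This is a sketch only. `score_pair` is hypothetical, and using the negative mean distance as the trial score is an assumption.

```python
import torch


def score_pair(ref_feat, com_feat):
    """Hypothetical helper: score two utterances from their segment embeddings.

    Accepts [nOut] (one segment) or [num_eval, nOut] tensors, reshapes both to
    2-D so torch.cdist always sees [num_segments, nOut], and returns the
    negative mean distance (assumed score: higher = more likely same speaker).
    """
    ref_2d = ref_feat.reshape(-1, ref_feat.shape[-1])
    com_2d = com_feat.reshape(-1, com_feat.shape[-1])
    dist = torch.cdist(ref_2d, com_2d)  # [n_ref, n_com]
    return -dist.mean().item()


single = score_pair(torch.randn(512), torch.randn(512))         # [1, 1] distances
multi = score_pair(torch.randn(10, 512), torch.randn(10, 512))  # [10, 10] distances
```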

Thank you @msh9184, this solved my problem. I will also investigate the performance difference between taking 10 subsamples (num_eval 10) and a single subsample.

Please let me know if you have any comments on that.

You can refer to the Test Time Augmentation (TTA) described in this paper.
In my opinion, TTA can serve as a compensation technique when scoring utterances of unbalanced durations.
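This kind of TTA amounts to cutting several fixed-length crops spread evenly over an utterance, embedding each crop, and averaging the all-pairs distances. That averaging is why the embeddings have the shape [num_eval, nOut]. The numpy sketch below shows only the cropping step. `evenly_spaced_crops` is hypothetical, and the wrap-padding of short utterances is an assumption.

```python
import numpy as np


def evenly_spaced_crops(audio, crop_len, num_eval=10):
    """Hypothetical helper: cut num_eval crops of crop_len samples whose start
    points are evenly spaced over the utterance; short audio is wrap-padded."""
    if len(audio) < crop_len:
        audio = np.pad(audio, (0, crop_len - len(audio)), mode="wrap")
    starts = np.linspace(0, len(audio) - crop_len, num=num_eval).astype(int)
    return np.stack([audio[s:s + crop_len] for s in starts])


audio = np.arange(16000 * 6, dtype=np.float32)  # 6 s of dummy samples at 16 kHz
crops = evenly_spaced_crops(audio, crop_len=16000 * 4, num_eval=10)
print(crops.shape)  # (10, 64000)
```

The first crop starts at the beginning of the utterance and the last one ends at its final sample, so with num_eval 1 only the first crop is scored.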

Dear all in this thread,

I tried to replicate the larger baseline model (baseline_v2_ap.model) as described in the paper "Clova Baseline System for the VoxCeleb Speaker Recognition Challenge 2020". I tried to incorporate every minute detail as mentioned in the paper. This is the config file that I wrote:

model: 'ResNetSE34V2'
log_input: True
encoder_type: 'ASP'
train_path: 'data/voxceleb2'
test_path: 'data/voxceleb1'
train_list: 'data/train_list.txt'
test_list: 'data/test_list.txt'
trainfunc: 'softmaxproto'
save_path: 'exps/exp_softmaxproto_v2'
max_frames: 200
eval_frames: 400
n_mels: 64
nClasses: 5994
batch_size: 75
nPerSpeaker: 2
nOut: 512
lr: 0.001
lr_decay: 0.75 # the paper reduces the learning rate by 25% every 3 epochs
nDataLoaderThread: 2
augment: True
test_interval: 3
weight_decay: 0.00005
distributed: True
mixedprec: True
max_epoch: 36
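The `lr_decay` comment in this config relies on the decay being applied every `test_interval` epochs. The sketch below assumes the trainer's step scheduler behaves like `torch.optim.lr_scheduler.StepLR` with `step_size=test_interval` and `gamma=lr_decay`, an assumption worth checking in the scheduler module. Under that assumption, the learning rate for this config evolves as follows:

```python
import torch

# Assumption: lr_decay is applied every test_interval epochs via StepLR,
# mirroring the config above (lr 0.001, lr_decay 0.75, test_interval 3).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.001, weight_decay=0.00005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.75)

lrs = []
for epoch in range(36):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used during this epoch
    optimizer.step()  # a real loop would train for one epoch here
    scheduler.step()

print(f"{lrs[-1]:.2e}")  # 4.22e-05
```

By epoch 36 the rate has been cut 11 times, so any change in what counts as one epoch also changes the whole schedule.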

I am not getting the claimed EER of 1.177% after 36 epochs. I also tried a batch size of 150 speakers with nPerSpeaker 2 for 500 epochs, which did not reach the claimed performance either.

With batch_size=75 and nPerSpeaker=2, I get an EER of 1.97% on the VoxCeleb1 test set after 36 epochs.
With batch_size=150, nPerSpeaker=2, I am getting an EER of 1.91% after 36 epochs, and this config achieves a minimum EER of 1.77% when trained for up to 500 epochs. Can anyone please let me know what I may be missing here?

When I evaluate the pre-trained baseline model (baseline_v2_ap.model) from the link, I do get the claimed 1.177% EER.
I am using PyTorch 1.8.2 with cudatoolkit 11.1 on a DGX machine with 8x 32GB Tesla V100 GPUs.

Regards,
Shreyas Ramoji (LEAP Lab, IISc)

@iiscleap I think you can refer to the comment in #61.
It seems that the author used a batch size of 180 instead of 75 to train the released model (baseline_v2_ap.model).

Thank you @msh9184 ! Will try this out and get back.

@iiscleap Here are my results and configuration: scores.txt, config.txt
I followed the recipe in this paper using PyTorch 1.10.1 with CUDA 11.5 on four 24GB NVIDIA RTX 3090 Ti GPUs.

Hi all,

Please correct me if I am wrong. I dug up some older versions of this repository and realized that when the report "Clova Baseline System for the VoxCeleb Speaker Recognition Challenge 2020" was written, the distributed sampler was quite different from the current one. Previously, one epoch of distributed training on 8 GPUs meant that each GPU saw all of the training samples. In the current version, with the distributed flag set to True, the dataset is split into 8 parts across the 8 GPUs. This means the learning rate decay interval of 3 epochs in the older version corresponds to 3x8=24 epochs in the current version, and the 36 epochs of the older version correspond to 36x8=288 epochs.
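The epoch arithmetic above can be sanity-checked with PyTorch's stock DistributedSampler. It serves here only as a stand-in for the repository's own sampler, under the assumption that the current version shards the data the same way. The dataset size is illustrative.

```python
from torch.utils.data.distributed import DistributedSampler

world_size = 8
dataset = list(range(1_000_000))  # illustrative stand-in for the training list

# Current behavior (as described above): each rank gets a 1/world_size shard.
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=0, shuffle=False)
per_gpu_now = len(sampler)  # 125000 samples per rank per epoch
per_gpu_old = len(dataset)  # old behavior: every rank saw the whole dataset

epochs_per_old_epoch = per_gpu_old // per_gpu_now  # 8
old_decay_interval, old_max_epoch = 3, 36
print(old_decay_interval * epochs_per_old_epoch, old_max_epoch * epochs_per_old_epoch)  # 24 288
```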

Did I get this right?

-Shreyas Ramoji

Thanks for the report and the valid fix.
I'm incorporating this in #154.