huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.

Unable to reproduce results from the paper

MLMonkATGY opened this issue

Can the exact code from be used to reproduce the results from Table 16? I benchmarked distil-large-v2 on the distil-whisper/common_voice_13_0 dataset and found that the WER is a few percent higher than what was reported in the paper.


Hey @MLMonkATGY! Could you share the arguments you're passing to so that I can reproduce locally? I believe this is because the PyTorch script uses the BasicTextNormalizer whenever a language is specified:

normalizer = (
    BasicTextNormalizer()
    if data_args.language is not None
    else EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
)

Whereas in the original Flax scripts, we always used the EnglishNormalizer:

normalizer = EnglishTextNormalizer(tokenizer.english_spelling_normalizer)

You should be able to reproduce the results one-to-one if you use the Flax script. I'll also update the PyTorch script to use the EnglishNormalizer if the language used is English!
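To see why the choice of normalizer can move WER by several points, here is a minimal, self-contained sketch. The two normalizers below are heavily simplified, hypothetical stand-ins (they are not the `BasicTextNormalizer` / `EnglishTextNormalizer` implementations), and the tiny spelling and contraction tables are made up for illustration. The point is only that an English-aware normalizer maps spelling variants and contractions to a common form before scoring, while a basic one does not.

```python
import re

# Hypothetical, tiny lookup tables used only for this illustration.
SPELLING = {"colour": "color", "grey": "gray"}
CONTRACTIONS = {"isn't": "is not", "it's": "it is"}


def basic_normalize(text):
    """Toy 'basic' normalizer: lowercase, drop punctuation, collapse spaces."""
    text = text.lower()
    text = re.sub(r"[^\w\s']", " ", text)  # keep apostrophes, drop other symbols
    return " ".join(text.split())


def english_normalize(text):
    """Toy 'English' normalizer: basic cleanup plus contractions and spelling."""
    text = text.lower()
    for short, full in CONTRACTIONS.items():
        text = text.replace(short, full)
    words = basic_normalize(text).split()
    return " ".join(SPELLING.get(word, word) for word in words)


def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, 1):
        cur = [i]
        for j, hyp_word in enumerate(hyp, 1):
            cur.append(min(
                prev[j] + 1,                           # deletion
                cur[j - 1] + 1,                        # insertion
                prev[j - 1] + (ref_word != hyp_word),  # substitution or match
            ))
        prev = cur
    return prev[-1] / len(ref)


reference = "The colour isn't grey."
hypothesis = "the color is not gray"

wer_basic = word_error_rate(basic_normalize(reference), basic_normalize(hypothesis))
wer_english = word_error_rate(english_normalize(reference), english_normalize(hypothesis))
print(f"basic normalizer WER:   {wer_basic:.2f}")
print(f"english normalizer WER: {wer_english:.2f}")
```

On this contrived pair, the basic normalizer scores every word after "the" as an error, while the English-aware one scores the transcript as a perfect match. Real test sets are far less extreme, but the same effect is enough to explain a gap of a few percent.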

I used the following arguments for

python \
    --model_name_or_path "distil-whisper/distil-large-v2" \
    --dataset_name distil-whisper/common_voice_13_0 \
    --dataset_config_name en \
    --dataset_split_name test \
    --text_column_name text \
    --batch_size 128 \
    --dtype "bfloat16" \
    --generation_max_length 256 \
    --language "en" \
    --attn_implementation "flash_attention_2" \
    --streaming True

Hey @MLMonkATGY, after merging #132, I evaluated the model with the following:


python \
    --model_name_or_path "distil-whisper/distil-large-v2" \
    --dataset_name "distil-whisper/common_voice_13_0" \
    --dataset_config_name "en" \
    --dataset_split_name "test" \
    --text_column_name "text" \
    --batch_size 128 \
    --dtype "bfloat16" \
    --generation_max_length 256 \
    --language "en" \
    --streaming True

And got a WER of 13.0%.

This is within 0.1% of the 12.9% WER reported in the paper. This 0.1% difference is expected, since the paper WER results are in Flax on TPU, whereas the script is in PyTorch on GPU. There's an inherent difference in how matrix multiplications are implemented in both, giving a subtle difference in results. Note that all WER results from the paper are in Flax, so the comparison between large-v2 and distil-large-v2 is valid. Note also that all RTF values in the paper were computed in PyTorch on GPU, such that they're most applicable to downstream use cases. I hope that helps!
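As a rough illustration of how two correct matrix-multiplication implementations can disagree slightly, the sketch below computes the same dot product with two accumulation orders. Floating-point addition is not associative, so the order in which partial products are summed changes the last bits of the result. This is an assumption about one plausible source of the difference, not a description of what the Flax/TPU or PyTorch/GPU kernels actually do.

```python
def dot_left_to_right(a, b):
    """Accumulate the products starting from the first element."""
    total = 0.0
    for x, y in zip(a, b):
        total += x * y
    return total


def dot_right_to_left(a, b):
    """Accumulate the same products starting from the last element."""
    total = 0.0
    for x, y in zip(reversed(a), reversed(b)):
        total += x * y
    return total


a = [0.1, 0.2, 0.3]
b = [1.0, 1.0, 1.0]

forward = dot_left_to_right(a, b)
backward = dot_right_to_left(a, b)
print(repr(forward), repr(backward))
print("identical:", forward == backward)
```

The two results differ only in the final bits, but across millions of multiply-accumulate operations in reduced precision such as bfloat16, small differences like this can occasionally change which token is predicted, which is consistent with a 0.1% shift in WER.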

All in all, the PR #132 should now mean that evaluating models in English with the PyTorch script gives WER results that are within 0.1% of the WER results quoted in the paper (using the Flax script flax/

Thanks!