Cannot use `whisper-*.en` models in bumblebee
John-Goff opened this issue
With any of `whisper-medium.en`, `whisper-small.en`, or `whisper-tiny.en`, the following code does not work:
```elixir
Mix.install([
  {:bumblebee, "~> 0.4.2"},
  {:nx, "~> 0.6.4"},
  {:exla, "~> 0.6.4"},
  {:axon, "~> 0.6.0"}
])

{[{:file, path}], _, _} =
  OptionParser.parse(System.argv(), strict: [file: :string], aliases: [f: :file])

{:ok, model_info} = Bumblebee.load_model({:hf, "openai/whisper-tiny.en"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "openai/whisper-tiny.en"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/whisper-tiny.en"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "openai/whisper-tiny.en"})

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 100)

serving =
  Bumblebee.Audio.speech_to_text_whisper(model_info, featurizer, tokenizer, generation_config,
    chunk_num_seconds: 30,
    timestamps: :segments,
    stream: true,
    compile: [batch_size: 4],
    defn_options: [compiler: EXLA]
  )

Nx.Serving.run(serving, {:file, Path.expand(path)})
|> Enum.reduce("", fn chunk, acc -> acc <> chunk.text end)
|> IO.puts()
```
It fails with this error:

```
** (RuntimeError) invalid task :transcribe, expected one of:
    (bumblebee 0.4.2) lib/bumblebee/audio/speech_to_text_whisper.ex:210: Bumblebee.Audio.SpeechToTextWhisper.forced_token_ids/2
    (bumblebee 0.4.2) lib/bumblebee/audio/speech_to_text_whisper.ex:156: Bumblebee.Audio.SpeechToTextWhisper.generate_opts/2
    (bumblebee 0.4.2) lib/bumblebee/audio/speech_to_text_whisper.ex:50: Bumblebee.Audio.SpeechToTextWhisper.speech_to_text_whisper/5
    transcribe.exs:18: (file)
```
Changing the model to `whisper-medium`, `whisper-small`, or `whisper-tiny` works properly.
`distil-whisper/distil-small.en` and related models fail with the same error.
For these models you need to set `task: nil` explicitly. Bumblebee `main` already has an improved error message, which just hasn't been released yet :)

```
the generation config does not have any tasks defined. If you are dealing with a monolingual model, set :task to nil. Otherwise you may need to update generation_config.extra_config.task_to_token_id
```
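The `.en` checkpoints are English-only, so the fix for the script above is to pass `task: nil` to `Bumblebee.Audio.speech_to_text_whisper/5`. If one script should handle both kinds of checkpoint, you can choose the option from the repository name. The module below is a hypothetical helper (`WhisperTask` is not part of Bumblebee). It relies on an assumed naming convention: monolingual checkpoints end in `.en`, as `openai/whisper-tiny.en` and `distil-whisper/distil-small.en` do. A checkpoint that breaks this convention would get the wrong value, so treat this as a sketch.

```elixir
defmodule WhisperTask do
  @moduledoc """
  Hypothetical helper (not part of Bumblebee): picks the `:task` option for
  `Bumblebee.Audio.speech_to_text_whisper/5` from a Hugging Face repo name.

  Assumes English-only (monolingual) checkpoints use the `.en` suffix,
  e.g. "openai/whisper-tiny.en" or "distil-whisper/distil-small.en".
  """

  @doc """
  Returns `nil` for `.en` checkpoints, which define no tasks in their
  generation config, and `:transcribe` for multilingual ones.
  """
  @spec for_repo(String.t()) :: :transcribe | nil
  def for_repo(repo) when is_binary(repo) do
    if String.ends_with?(repo, ".en"), do: nil, else: :transcribe
  end
end
```

In the original script, this means adding `task: WhisperTask.for_repo("openai/whisper-tiny.en")` (equivalently `task: nil`) alongside `chunk_num_seconds`, `timestamps`, and the other serving options. For multilingual checkpoints the helper returns `:transcribe`, which leaves their behavior unchanged: the error above shows that `:transcribe` is the task the serving used when none was given.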