tazz4843 / whisper-rs

Rust bindings to https://github.com/ggerganov/whisper.cpp

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

re-export whisper-rs-sys to export raw api

bruceunx opened this issue · comments

When I implement real-time streaming output as follows:

First, define the callback to pass to `set_progress_callback`:

/// Number of segments that `callback` has already printed.
///
/// An atomic is used instead of `static mut` so that reading and updating the counter
/// needs no unsynchronized global mutable access.
static COUNT: std::sync::atomic::AtomicI32 = std::sync::atomic::AtomicI32::new(0);

/// Progress callback intended for `FullParams::set_progress_callback`.
///
/// Each time it is invoked it prints every segment in `state` that has not been printed
/// yet, as `[t0 -> t1]: text`, and then records the new segment count in `COUNT`.
///
/// The context, progress value and user-data arguments are ignored.
///
/// # Safety
///
/// `state` must be null or point to a `whisper_state` that is valid for the duration of
/// the call. A null `state` is ignored.
unsafe extern "C" fn callback(
    _ctx: *mut whisper_rs_sys::whisper_context,
    state: *mut whisper_rs_sys::whisper_state,
    _progress: std::os::raw::c_int,
    _user_data: *mut std::ffi::c_void,
) {
    use std::sync::atomic::Ordering;

    // Nothing can be queried from a null state; bail out instead of dereferencing it.
    if state.is_null() {
        return;
    }

    let num_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state);
    let first_unprinted = COUNT.load(Ordering::SeqCst);

    for i in first_unprinted..num_segments {
        let text_ptr = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i);
        if text_ptr.is_null() {
            continue;
        }
        // Lossy conversion: a segment that is not valid UTF-8 must not panic here,
        // because a panic must not unwind out of an `extern "C"` function.
        let text = CStr::from_ptr(text_ptr).to_string_lossy();
        let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i);
        let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i);
        println!("[{} -> {}]: {}", t0, t1, text);
    }

    COUNT.store(num_segments, Ordering::SeqCst);
}

Second, add the callback to the params:

    // Register `callback` as the progress callback on `params`.
    // NOTE(review): the call is wrapped in `unsafe`, presumably because the callback is a
    // raw `extern "C"` function that receives raw pointers — confirm against the
    // whisper-rs documentation for `set_progress_callback`.
    unsafe {
        params.set_progress_callback(Some(callback));
    }

To use this callback in a third-party application, we have to add both whisper-rs and whisper-rs-sys as dependencies, which is a little confusing. It would be better to re-export whisper-rs-sys from whisper-rs.

A full example looks like this:

#![allow(clippy::uninlined_format_args)]
use std::ffi::CStr;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};

/// Number of segments that `callback` has already printed.
///
/// An atomic is used instead of `static mut` so that reading and updating the counter
/// needs no unsynchronized global mutable access.
static COUNT: std::sync::atomic::AtomicI32 = std::sync::atomic::AtomicI32::new(0);

/// Progress callback intended for `FullParams::set_progress_callback`.
///
/// Each time it is invoked it prints every segment in `state` that has not been printed
/// yet, as `[t0 -> t1]: text`, and then records the new segment count in `COUNT`.
///
/// The context, progress value and user-data arguments are ignored.
///
/// # Safety
///
/// `state` must be null or point to a `whisper_state` that is valid for the duration of
/// the call. A null `state` is ignored.
unsafe extern "C" fn callback(
    _ctx: *mut whisper_rs_sys::whisper_context,
    state: *mut whisper_rs_sys::whisper_state,
    _progress: std::os::raw::c_int,
    _user_data: *mut std::ffi::c_void,
) {
    use std::sync::atomic::Ordering;

    // Nothing can be queried from a null state; bail out instead of dereferencing it.
    if state.is_null() {
        return;
    }

    let num_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state);
    let first_unprinted = COUNT.load(Ordering::SeqCst);

    for i in first_unprinted..num_segments {
        let text_ptr = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i);
        if text_ptr.is_null() {
            continue;
        }
        // Lossy conversion: a segment that is not valid UTF-8 must not panic here,
        // because a panic must not unwind out of an `extern "C"` function.
        let text = CStr::from_ptr(text_ptr).to_string_lossy();
        let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i);
        let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i);
        println!("[{} -> {}]: {}", t0, t1, text);
    }

    COUNT.store(num_segments, Ordering::SeqCst);
}

/// Loads a Whisper model, reads a WAV file and runs transcription on it, with `callback`
/// registered as the progress callback.
///
/// Usage: `<program> <path-to-model> <path-to-wav>`
///
/// # Panics
///
/// Panics if either argument is missing, or if loading the model, opening or decoding
/// the WAV file, converting the samples, or running the model fails.
fn main() {
    /// Upper bound on the number of samples handed to the model in one run.
    const MAX_SAMPLES: usize = 1_200_000;
    const USAGE: &str = "usage: <program> <path-to-model> <path-to-wav>";

    let mut args = std::env::args().skip(1);
    let path_to_model = args.next().expect(USAGE);
    let path_to_wav = args.next().expect(USAGE);

    // load a context and model
    println!("{}", &path_to_model);
    let ctx: WhisperContext =
        WhisperContext::new_with_params(&path_to_model, WhisperContextParameters::default())
            .expect("failed to load model");

    let mut state = ctx.create_state().expect("failed to create state");
    // create a params object
    let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

    params.set_language(Some("en"));
    params.set_n_threads(3);
    // SAFETY: `callback` is an `unsafe extern "C" fn` with the four-argument signature
    // passed here in the original example. NOTE(review): confirm the exact contract of
    // `set_progress_callback` against the whisper-rs documentation.
    unsafe {
        params.set_progress_callback(Some(callback));
    }

    // Open the audio file.
    let mut reader = hound::WavReader::open(&path_to_wav).expect("failed to open file");

    // Convert the audio to floating point samples.
    let sample = reader
        .samples::<i16>()
        .map(|s| s.expect("invalid sample"))
        .collect::<Vec<_>>();

    let mut audio = vec![0.0f32; sample.len()];

    whisper_rs::convert_integer_to_float_audio(&sample, &mut audio)
        .expect("failed to convert audio");

    // Clamp to the available length so that recordings shorter than `MAX_SAMPLES`
    // do not panic on an out-of-range slice.
    let clip_len = audio.len().min(MAX_SAMPLES);

    state
        .full(params, &audio[..clip_len])
        .expect("failed to run model");
}
commented

whisper-rs-sys doesn't follow the same versioning system as whisper-rs, and bumping whisper-rs-sys to a breaking version in a patch release of whisper-rs has happened before. As such, I added this locked behind the raw-api feature flag, and noted in the readme that enabling this no longer guarantees semver compliance. Done in commit 7714a10