rust-ml / linfa

A Rust machine learning framework.

Exporting & Loading Trained Model?

Bastian1110 opened this issue · comments

Is there a method to save a trained (GaussianNb) model and then load it?

I'm just learning Rust and have managed to implement a Gaussian Naive Bayes classifier. Is there any way to use the "predict" method without having to train the whole model again?
I know that libraries like scikit-learn let you export models to .pkl files and load them later. Is there a similar feature in linfa?

Thank you so much!

The models should implement Serde's Serialize and Deserialize traits, so you can serialize them using something like ciborium.

I'm having a little problem, hehe. I'm trying to serialize the model (using the linfa_svm example), but I'm not sure my syntax is correct, since I get this error on the line where I use cbor!:

the trait bound `MultiClassModel<ndarray::ArrayBase<ndarray::data_repr::OwnedRepr<f64>, ndarray::dimension::dim::Dim<[usize; 2]>>, usize>: serde::ser::Serialize` is not satisfied
the following other types implement trait `serde::ser::Serialize`:

This is the code I'm using (linfa_svm/examples/winequality_multi.rs):

let model = train
    .one_vs_all()?
    .into_iter()
    .map(|(l, x)| (l, params.fit(&x).unwrap()))
    .collect::<MultiClassModel<_, _>>();
let pred = model.predict(&valid);

// Trying to serialize the model
let save_model = cbor!(model).unwrap();

Could you give me a more detailed example? I would appreciate it very much!

Seems like MultiClassModel and linfa-bayes have no support for Serde. Weird. We'll need to add that.

Oh! I will try with normal SVM then, thank you!

One last question:
I managed to serialize the SVM model (without MultiClassModel) to CBOR using ciborium, and I also managed to deserialize it in another file into a Value. The last step is to convert that Value into an Svm. Any idea how to do this?

Code for creating and exporting the model:

    // Imports needed by the serialization part of this snippet
    use ciborium::{cbor, value::Value};
    use std::{fs, path::Path};

    let model = Svm::<_, bool>::params().pos_neg_weights(50000., 5000.).gaussian_kernel(80.0).fit(&train)?;

    // Serialize the trained model with ciborium
    let value_model: Value = cbor!(model).unwrap();
    let mut vec_model: Vec<u8> = Vec::new();
    // Unwrap the result so a failed write is not silently ignored
    ciborium::ser::into_writer(&value_model, &mut vec_model).unwrap();

    // Export it to a .cbor file
    let path: &Path = Path::new("./model.cbor");
    fs::write(path, vec_model).unwrap();

Attempt to use the trained model in another .rs program:

    let mut file = File::open("./model.cbor").unwrap();
    let mut data: Vec<u8> = Vec::new();
    file.read_to_end(&mut data).unwrap();

    let model_value : Value = ciborium::de::from_reader::<Value, _>(&data[..]).unwrap();
    let model: Svm<_, bool> = model_value.deserialized().unwrap(); // Error 

But I keep getting an error when trying to convert from ciborium::Value to Svm. rust-analyzer suggests: consider specifying the generic argument: ::<Svm<_, bool>>. I guess I have to tell the Serde deserializer which Svm type to produce, but I don't know how to do that.

I know this has nothing to do with linfa, but I really think that exporting and importing the models can be very useful.

Thank you!

My bad, it turns out the example SVM model is an Svm<f64, bool>, not an Svm<_, bool>. I only changed the line to:

let model: Svm<f64, bool> = model_value.deserialized().unwrap();

And it works super cool!
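
For anyone landing here later: the fix above most likely works because deserialized() is generic over its return type, so the compiler needs a fully concrete target such as Svm<f64, bool>. With Svm<_, bool>, nothing else in the program pins down the float type. The same rule can be reproduced with only the standard library. This is just a minimal sketch: Model<F> and load are hypothetical stand-ins invented for illustration (they are not linfa or ciborium items), and load merely plays the role of a deserializer that is generic over what it returns.

```rust
use std::str::FromStr;

/// Hypothetical stand-in for a trained model that is generic over its float type.
/// Not a linfa type; it only mirrors the shape of `Svm<F, bool>`.
#[derive(Debug, PartialEq)]
struct Model<F> {
    weight: F,
}

/// Hypothetical loader that, like a deserializer, is generic over its return type.
/// It returns `None` when the text cannot be parsed as `F`.
fn load<F: FromStr>(text: &str) -> Option<Model<F>> {
    text.trim().parse::<F>().ok().map(|weight| Model { weight })
}

fn main() {
    // let m: Model<_> = load("0.5").unwrap();
    // ^ does not compile: nothing tells the compiler which `F` to use.

    // Spelling out the concrete type resolves the ambiguity.
    let m: Model<f64> = load("0.5").unwrap();
    println!("{:?}", m);

    // The turbofish form is an equivalent way to name the type.
    let m32 = load::<f32>("0.5").unwrap();
    println!("{:?}", m32);
}
```

In the thread's code, the annotation Svm<f64, bool> on the let binding does the same job as Model<f64> here.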