rust-ml / linfa

A Rust machine learning framework.

How can I save a logistic regression model for serving?

MrDataPsycho opened this issue

Hi,
I have trained a logistic regression model as follows:

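use linfa::prelude::*;
use linfa_logistic::LogisticRegression;

fn train() {
    let dataset = get_dataset();
    // fit() yields a FittedLogisticRegression, which is dropped when train() returns
    let model = LogisticRegression::default()
        .max_iterations(500)
        .gradient_tolerance(0.0001)
        .fit(&dataset)
        .expect("Can not train the model");
}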

Now I want to save the model for serving. How can I do that? Here is my dependency list:

[dependencies]
linfa = "0.6.1"
linfa-logistic = "0.6.1"
csv = "1.2.0"
ndarray = "0.15.6"

The models implement the Serde traits, so you should be able to use any Serde-compatible serialization format you like. See #290.
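To make the save/load idea concrete without pulling in linfa or a Serde format, here is a dependency-free sketch. ToyFittedModel, to_bytes and from_bytes are hypothetical names, and the hand-rolled little-endian layout only stands in for a real Serde format such as CBOR or MessagePack. What it illustrates: you persist the fitted model (learned weights and intercept), and you load it back into that same fitted type before predicting.

```rust
use std::convert::TryInto;

/// Hypothetical fitted binary classifier: it holds only the learned parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct ToyFittedModel {
    pub weights: Vec<f64>,
    pub intercept: f64,
}

impl ToyFittedModel {
    /// Class 1 when the linear score is positive (sigmoid > 0.5), else class 0.
    pub fn predict(&self, row: &[f64]) -> i32 {
        let dot: f64 = self.weights.iter().zip(row).map(|(w, x)| w * x).sum();
        if self.intercept + dot > 0.0 { 1 } else { 0 }
    }

    /// Layout: weight count (u64), each weight (f64), then the intercept (f64), all little-endian.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(8 * (self.weights.len() + 2));
        out.extend_from_slice(&(self.weights.len() as u64).to_le_bytes());
        for w in &self.weights {
            out.extend_from_slice(&w.to_le_bytes());
        }
        out.extend_from_slice(&self.intercept.to_le_bytes());
        out
    }

    /// Inverse of to_bytes; returns None if the buffer has the wrong length.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let mut words = bytes.chunks_exact(8);
        let n = u64::from_le_bytes(words.next()?.try_into().ok()?) as usize;
        // Reject buffers whose length does not match the declared weight count
        let expected = n.checked_add(2).and_then(|k| k.checked_mul(8));
        if expected != Some(bytes.len()) {
            return None;
        }
        let mut weights = Vec::with_capacity(n);
        for _ in 0..n {
            weights.push(f64::from_le_bytes(words.next()?.try_into().ok()?));
        }
        let intercept = f64::from_le_bytes(words.next()?.try_into().ok()?);
        Some(ToyFittedModel { weights, intercept })
    }
}
```

With ciborium the same round trip would be ciborium::ser::into_writer(&model, &mut buf) on save and ciborium::de::from_reader into the fitted linfa type on load, with that type's float and label parameters matching the ones used in training.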

Hi, I followed that issue to reproduce it with a logistic regression model. This is what the code looks like:

use std::fs;
use std::path::Path;

use ciborium::cbor;
use linfa::prelude::*;
use linfa_logistic::LogisticRegression;
use log::info;

fn train() {
    let dataset = get_dataset();
    info!("Step: Start training the model.");
    let model = LogisticRegression::default()
        .max_iterations(500)
        .gradient_tolerance(0.0001)
        .fit(&dataset)
        .expect("Can not train the model");

    // Convert the fitted model into a CBOR Value, then encode it into a byte buffer
    let value_model = cbor!(model).unwrap();
    let mut vec_model = Vec::new();
    ciborium::ser::into_writer(&value_model, &mut vec_model).unwrap();
    // let prediction = model.predict(&dataset.records);
    // println!("{:?}", prediction);

    // Make sure the output directory exists before writing the file
    let write_path = Path::new("model").join("model.cbor");
    fs::create_dir_all("model").unwrap();
    fs::write(&write_path, vec_model).unwrap();
    info!("Model saved at {:?}", write_path);
}

use std::fs::File;
use std::io::Read;
use std::path::Path;

use ciborium::value;
use linfa::prelude::*;
use linfa_logistic::LogisticRegression;

fn load_model() {
    let dataset = get_dataset();
    let mut data: Vec<u8> = Vec::new();
    let path = Path::new("model").join("model.cbor");
    let mut file = File::open(&path).unwrap();
    file.read_to_end(&mut data).unwrap();
    let model_value = ciborium::de::from_reader::<value::Value, _>(&data[..]).unwrap();
    let model: LogisticRegression<f64> = model_value.deserialized().unwrap();
    // This does not compile: LogisticRegression holds only hyperparameters and has no predict()
    let result = model.predict(&dataset.records);
}

But I cannot use the loaded model: the compiler says the predict method is not available for it. That is true, LogisticRegression has no predict method, so there might be some other function I need to call. Can anyone help me?
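The compiler is right here. In linfa, LogisticRegression holds the hyperparameters, and fit returns a separate FittedLogisticRegression<F, C> that carries the learned coefficients and implements predict. Below is a dependency-free sketch of that split. ToyParams and ToyFitted are hypothetical names, the model has a single feature and bool labels for brevity, and plain batch gradient descent stands in for whatever optimizer linfa actually uses.

```rust
/// Hyperparameters only: there is nothing to predict with yet.
#[derive(Debug, Clone)]
pub struct ToyParams {
    max_iterations: usize,
    gradient_tolerance: f64,
    learning_rate: f64,
}

/// Learned state: only ToyParams::fit creates it, and only it has predict().
#[derive(Debug, Clone)]
pub struct ToyFitted {
    weight: f64,
    intercept: f64,
}

impl Default for ToyParams {
    fn default() -> Self {
        ToyParams { max_iterations: 100, gradient_tolerance: 1e-4, learning_rate: 0.5 }
    }
}

impl ToyParams {
    pub fn max_iterations(mut self, n: usize) -> Self { self.max_iterations = n; self }
    pub fn gradient_tolerance(mut self, tol: f64) -> Self { self.gradient_tolerance = tol; self }

    /// Fit a one-feature logistic regression by batch gradient descent on the mean log-loss.
    pub fn fit(&self, x: &[f64], y: &[bool]) -> Result<ToyFitted, String> {
        if x.is_empty() || x.len() != y.len() {
            return Err("records and targets must be non-empty and equally long".into());
        }
        let n = x.len() as f64;
        let (mut w, mut b) = (0.0_f64, 0.0_f64);
        for _ in 0..self.max_iterations {
            let (mut gw, mut gb) = (0.0_f64, 0.0_f64);
            for (&xi, &yi) in x.iter().zip(y) {
                // p - y is the derivative of the log-loss with respect to the linear score
                let p = sigmoid(w * xi + b);
                let target = if yi { 1.0 } else { 0.0 };
                gw += (p - target) * xi / n;
                gb += (p - target) / n;
            }
            // Stop early once the gradient norm falls below the tolerance
            if (gw * gw + gb * gb).sqrt() < self.gradient_tolerance {
                break;
            }
            w -= self.learning_rate * gw;
            b -= self.learning_rate * gb;
        }
        Ok(ToyFitted { weight: w, intercept: b })
    }
}

impl ToyFitted {
    pub fn predict(&self, x: &[f64]) -> Vec<bool> {
        x.iter().map(|&xi| sigmoid(self.weight * xi + self.intercept) > 0.5).collect()
    }
}

fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}
```

Because only fit can construct the fitted type, predict cannot be called on an untrained model, and for the same reason the deserialization target has to be the fitted type rather than the parameter type.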

I also tried the FittedLogisticRegression type, but could not load the model. I might be doing something really stupid, but it is frustrating to search everywhere on the internet for a solution: I have not found a single example that shows the whole process of building a logistic regression model. Once my solution works, I hope to write a complete document on the following steps with linfa:

  • Training a model
  • Validating and testing a model
  • Saving a model
  • Loading a model and making predictions with it

use std::fs::File;
use std::io::Read;
use std::path::Path;

use ciborium::value;
use linfa::prelude::*;
use linfa_logistic::FittedLogisticRegression;
use log::info;

fn load_model() {
    let dataset = get_dataset();
    let mut data: Vec<u8> = Vec::new();
    let path = Path::new("model").join("model.cbor");
    let mut file = File::open(&path).unwrap();
    file.read_to_end(&mut data).unwrap();
    let model_value = ciborium::de::from_reader::<value::Value, _>(&data[..]).unwrap();
    // This panics at runtime: the label type bool does not match what was saved
    let model: FittedLogisticRegression<f32, bool> = model_value.deserialized().unwrap();
    model.predict(dataset.records);
    info!("Model loading was also successful!");
}

Here is the error:
thread 'main' panicked at 'called Result::unwrap() on an Err value: Custom("invalid type: integer 1, expected bool")', src/bin/stages/train.rs:114:81
note: run with RUST_BACKTRACE=1 environment variable to display a backtrace

What is the type of your dataset's targets? i32? Did you try:

let model: FittedLogisticRegression<f32, i32> = model_value.deserialized().unwrap();
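That suggestion works because the type parameters are part of what gets deserialized. A fitted binary classifier has to remember which label is the positive class and which is the negative one, so (as an assumption about the saved layout) the CBOR file contains those labels exactly as they appeared in the training targets: integers such as 1 and 0 if the targets were i32. Asking for bool then fails. The dependency-free sketch below mimics this with a hypothetical Value enum (standing in for a decoded self-describing format like CBOR) and a hypothetical FromValue trait (standing in for Serde's Deserialize); its error text only imitates the panic message above.

```rust
use std::convert::TryFrom;

/// Hypothetical self-describing value, standing in for a decoded CBOR item.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    Integer(i64),
    Bool(bool),
}

/// Hypothetical stand-in for Serde's Deserialize: build Self from a decoded Value.
pub trait FromValue: Sized {
    fn from_value(v: &Value) -> Result<Self, String>;
}

fn describe(v: &Value) -> String {
    match v {
        Value::Integer(i) => format!("integer {}", i),
        Value::Bool(b) => format!("boolean {}", b),
    }
}

impl FromValue for bool {
    fn from_value(v: &Value) -> Result<Self, String> {
        match v {
            Value::Bool(b) => Ok(*b),
            other => Err(format!("invalid type: {}, expected bool", describe(other))),
        }
    }
}

impl FromValue for i32 {
    fn from_value(v: &Value) -> Result<Self, String> {
        match v {
            Value::Integer(i) => i32::try_from(*i).map_err(|e| e.to_string()),
            other => Err(format!("invalid type: {}, expected i32", describe(other))),
        }
    }
}

/// Toy stand-in for the class labels a fitted model stores, generic over label type C.
#[derive(Debug, PartialEq)]
pub struct ToyLabels<C> {
    pub positive: C,
    pub negative: C,
}

/// Decode saved [positive, negative] labels into whatever label type the caller asks for.
pub fn load_labels<C: FromValue>(saved: &[Value; 2]) -> Result<ToyLabels<C>, String> {
    Ok(ToyLabels {
        positive: C::from_value(&saved[0])?,
        negative: C::from_value(&saved[1])?,
    })
}
```

In the real code the fix is the one suggested above: deserialize into FittedLogisticRegression<F, C> with F and C equal to the float and target types used in training.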

There is also an example of serialization/deserialization using rmp-serde in the tests here

Hi, it looks like it's working. Thanks a lot for the help! I need to understand linfa's type system better. Closing the issue.