rust-ml / linfa

A Rust machine learning framework.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

How to do clustering grid search with multiple CPUs / GPUs?

jamesaphoenix opened this issue · comments

Currently I'm building a Wasm project that will expose some clustering functionality to the browser.

Questions:

  • Do we have grid search functionality? Or should I simply loop over multiple model.fit calls sequentially?
  • What's the easiest way to implement this for 3x clustering techniques?

I'm looking to use all of these:
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/dbscan.rs
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/kmeans.rs
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/optics.rs

  • Also, do I need to implement multi-core processing myself, similar to joblib in Python? Or is this handled by linfa?

Thanks in advance, and great package btw!

Here is my current library, to provide some context:

use linfa::traits::Fit;
use linfa::traits::Predict;
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use linfa_nn::distance::LInfDist;
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};
use serde_json;
use wasm_bindgen::prelude::*;

// Data types:
/// One input record: a keyword together with its embedding vector.
///
/// Deserialized from the JSON passed to `cluster_embeddings_with_kmeans` and
/// serialized back as part of each `EnrichedEmbedding`.
#[derive(Serialize, Deserialize)]
struct Embedding {
    // The keyword this embedding belongs to.
    keyword: String,
    // The embedding values; each record becomes one row of the matrix that is clustered.
    embeddings: Vec<f64>,
}

/// An input `Embedding` annotated with the result of clustering.
#[derive(Serialize, Deserialize)]
struct EnrichedEmbedding {
    // The original input record, unchanged.
    embedding: Embedding,
    // Index of the cluster the record was assigned to.
    cluster: usize,
    // Placeholder flag: `cluster_embeddings_with_kmeans` currently always sets it to `false`.
    is_main_keyword_in_cluster: bool,
}

// Functions imported from the JavaScript host.
#[wasm_bindgen]
extern "C" {
    // Binds `log` to the host's `console.log` (via `js_namespace = console`),
    // used for progress messages during clustering.
    #[wasm_bindgen(js_namespace = console)]
    fn log(s: &str);
}

#[wasm_bindgen]
/// Returns the greeting `Hello, <name>!` for the given `name`.
pub fn greet(name: &str) -> String {
    // Concatenate the fixed prefix, the caller-supplied name and the suffix.
    ["Hello, ", name, "!"].concat()
}

// TODO - If there are no keywords then raise an error:

/// Clusters keyword embeddings with k-means and returns the assignments as JSON.
///
/// `json_embeddings` must be a JSON array of objects of the form
/// `{"keyword": "...", "embeddings": [1.0, 2.0, ...]}`. Every `embeddings`
/// vector must be non-empty and all of them must have the same length.
/// `n_clusters` is the number of clusters to fit and must be between 1 and the
/// number of embeddings.
///
/// The model is fitted with a fixed RNG seed (42), the `LInfDist` distance and
/// at most 1000 iterations, so repeated calls with the same input use the same
/// initial random state.
///
/// On success, returns a JSON array with one entry per input embedding, in
/// input order. Each entry holds the original `embedding`, its `cluster` index
/// and `is_main_keyword_in_cluster` (currently always `false`).
///
/// # Errors
///
/// Returns a `JsValue` string describing the problem when:
/// - `json_embeddings` is not valid JSON of the expected shape,
/// - there are no embeddings, or more than 100,000 of them,
/// - an embedding vector is empty or its length differs from the first one,
/// - `n_clusters` is 0 or larger than the number of embeddings,
/// - building the matrix, fitting the model or serializing the result fails.
#[wasm_bindgen]
pub fn cluster_embeddings_with_kmeans(
    json_embeddings: &str,
    n_clusters: usize,
) -> Result<String, JsValue> {
    // Upper bound on the number of input records accepted in one call.
    const MAX_EMBEDDINGS: usize = 100_000;

    // Deserialize JSON embeddings:
    let embeddings: Vec<Embedding> = serde_json::from_str(json_embeddings).map_err(js_error)?;

    let rows = embeddings.len();
    if rows > MAX_EMBEDDINGS {
        return Err(JsValue::from_str(
            "The number of embeddings is too large. Please use a smaller dataset.",
        ));
    }

    // The first record defines the expected dimensionality of every record.
    let cols = match embeddings.first() {
        Some(first) => first.embeddings.len(),
        None => {
            return Err(JsValue::from_str(
                "The number of embeddings is 0. Please provide some embeddings.",
            ))
        }
    };

    if cols == 0 {
        return Err(JsValue::from_str(
            "Embeddings must contain at least one value.",
        ));
    }

    // The matrix constructor below only checks the total element count, so
    // ragged rows whose lengths happen to add up to `rows * cols` would be
    // accepted and silently misaligned. Reject them explicitly.
    if let Some(mismatch) = embeddings.iter().find(|e| e.embeddings.len() != cols) {
        return Err(JsValue::from_str(&format!(
            "All embeddings must have the same length: expected {} values but keyword \"{}\" has {}.",
            cols,
            mismatch.keyword,
            mismatch.embeddings.len()
        )));
    }

    // k-means cannot place more centroids than there are points to cluster.
    if n_clusters == 0 || n_clusters > rows {
        return Err(JsValue::from_str(&format!(
            "The number of clusters must be between 1 and the number of embeddings ({}), got {}.",
            rows, n_clusters
        )));
    }

    // Convert embeddings to a row-per-record ndarray without cloning each vector.
    let mut flattened: Vec<f64> = Vec::with_capacity(rows * cols);
    for embedding in &embeddings {
        flattened.extend_from_slice(&embedding.embeddings);
    }
    let array = Array2::from_shape_vec((rows, cols), flattened).map_err(js_error)?;
    let dataset = DatasetBase::from(array);

    // Use the console binding rather than `println!`: stdout is not visible
    // when this runs in the browser.
    log(&format!("Number of embeddings: {}", rows));
    log("Clustering embeddings in Rust...");

    // Fixed seed so the initial random state is the same on every call.
    let rng = Xoshiro256Plus::seed_from_u64(42);

    // Cluster embeddings in Rust:
    let model = KMeans::params_with(n_clusters, rng, LInfDist)
        .max_n_iterations(1000)
        .fit(&dataset)
        .map_err(js_error)?;

    log("Finished clustering embeddings in Rust");
    log("Assigning points to clusters...");

    // Assign each point to a cluster using the set of centroids found using `fit`.
    // Only the assignments are needed; the records are already held in `embeddings`.
    let DatasetBase { targets, .. } = model.predict(dataset);

    // Pair each original embedding with its cluster assignment (same order).
    let enriched_embeddings: Vec<EnrichedEmbedding> = embeddings
        .into_iter()
        .zip(targets.iter())
        .map(|(embedding, &cluster)| EnrichedEmbedding {
            embedding,
            cluster,
            is_main_keyword_in_cluster: false, // Placeholder logic here
        })
        .collect();

    // Serialize the enriched embeddings
    serde_json::to_string(&enriched_embeddings).map_err(js_error)
}

/// Converts any displayable error into a `JsValue` holding its message.
fn js_error<E: std::fmt::Display>(err: E) -> JsValue {
    JsValue::from_str(&err.to_string())
}

/// Tests for `greet` and `cluster_embeddings_with_kmeans`.
#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    #[test]
    fn testing_greeting() {
        assert_eq!(greet("world"), "Hello, world!");
    }

    /// Two points and two clusters: every point must get a valid cluster index.
    #[wasm_bindgen_test]
    fn test_cluster_embeddings() {
        let mock_json = r#"
            [
                {
                    "keyword": "rust",
                    "embeddings": [0.1, 0.2, 0.3]
                },
                {
                    "keyword": "wasm",
                    "embeddings": [0.4, 0.5, 0.6]
                }
            ]
        "#;

        let n_clusters = 2; // For simplicity, choose a small number of clusters

        // Call the function with the mocked JSON and the number of clusters
        let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);

        // Check that the function succeeded
        assert!(result.is_ok());

        // Deserialize the result to verify its structure
        let enriched_embeddings: Vec<EnrichedEmbedding> =
            serde_json::from_str(&result.unwrap()).unwrap();

        // Verify that each embedding has been assigned a cluster
        assert_eq!(enriched_embeddings.len(), 2);
        for enriched_embedding in enriched_embeddings {
            assert!(enriched_embedding.cluster < n_clusters);
        }
    }

    /// An empty JSON array must be rejected.
    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_no_embeddings() {
        let mock_json = r#"
            []
        "#;

        let n_clusters = 2; // For simplicity, choose a small number of clusters
        let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);
        assert!(result.is_err());
    }

    /// One record more than the 100,000 limit must be rejected by the size
    /// check itself, not by a JSON parse error.
    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_large_dataset() {
        let single_embedding = r#"{"keyword": "rust", "embeddings": [0.1, 0.2, 0.3]}"#;

        // Join with commas so the payload is a valid JSON array.
        let mock_json = format!("[{}]", vec![single_embedding; 100_001].join(","));
        let n_clusters = 2; // For simplicity, choose a small number of clusters

        // Call the function with the generated JSON and the number of clusters
        let result = cluster_embeddings_with_kmeans(&mock_json, n_clusters);

        // Check that the function failed, and that it failed for the right reason:
        let error = result.expect_err("datasets above the size limit must be rejected");
        let message = error.as_string().unwrap_or_default();
        assert!(
            message.contains("too large"),
            "unexpected error message: {}",
            message
        );
    }

    /// A mid-sized dataset (3,000 points) must cluster successfully.
    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_3k_embeddings() {
        // Spread the points over two well-separated groups so that k-means is
        // not asked to split a set of identical points.
        let entries: Vec<String> = (0..3000u32)
            .map(|i| {
                let base = if i % 2 == 0 { 0.0 } else { 10.0 };
                let x = base + f64::from(i) * 1e-4;
                format!(
                    r#"{{"keyword": "kw{}", "embeddings": [{}, {}, {}]}}"#,
                    i,
                    x,
                    x + 0.1,
                    x + 0.2
                )
            })
            .collect();
        let mock_json = format!("[{}]", entries.join(","));

        let n_clusters = 2; // For simplicity, choose a small number of clusters

        // Call the function with the generated JSON and the number of clusters
        let result = cluster_embeddings_with_kmeans(&mock_json, n_clusters);
        assert!(result.is_ok());

        // Every input point must come back with a valid cluster index.
        let enriched_embeddings: Vec<EnrichedEmbedding> =
            serde_json::from_str(&result.unwrap()).unwrap();
        assert_eq!(enriched_embeddings.len(), 3000);
        for enriched_embedding in enriched_embeddings {
            assert!(enriched_embedding.cluster < n_clusters);
        }
    }
}

Bump on this?