How to do clustering grid search with multiple CPUs / GPUs? #337
Here is my current library, to provide some context:
use linfa::traits::Fit;
use linfa::traits::Predict;
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use linfa_nn::distance::LInfDist;
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};
use serde_json;
use wasm_bindgen::prelude::*;
// Data types:
#[derive(Serialize, Deserialize)]
struct Embedding {
keyword: String,
embeddings: Vec<f64>,
}
#[derive(Serialize, Deserialize)]
struct EnrichedEmbedding {
embedding: Embedding,
cluster: usize,
is_main_keyword_in_cluster: bool,
}
#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console)]
fn log(s: &str);
}
#[wasm_bindgen]
pub fn greet(name: &str) -> String {
format!("Hello, {}!", name)
}
// TODO - If there are no keywords then raise an error:
#[wasm_bindgen]
pub fn cluster_embeddings_with_kmeans(
json_embeddings: &str,
n_clusters: usize,
) -> Result<String, JsValue> {
let rng = Xoshiro256Plus::seed_from_u64(42);
// Deserialize JSON embeddings:
let embeddings: Vec<Embedding> =
serde_json::from_str(json_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))?;
// Use the console binding; `println!` output is not shown in the browser console.
log(&format!("Number of embeddings: {}", embeddings.len()));
// If there are more than 100,000 embeddings:
if embeddings.len() > 100000 {
return Err(JsValue::from_str(
"The number of embeddings is too large. Please use a smaller dataset.",
));
}
if embeddings.is_empty() {
return Err(JsValue::from_str(
"The number of embeddings is 0. Please provide some embeddings.",
));
}
// Convert embeddings to ndarray
let rows = embeddings.len();
let cols = embeddings[0].embeddings.len();
let flattened: Vec<f64> = embeddings
.iter()
.flat_map(|e| e.embeddings.iter().copied())
.collect();
let array = Array2::from_shape_vec((rows, cols), flattened)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
let dataset = DatasetBase::from(array);
log("Clustering embeddings in Rust...");
// Cluster embeddings in Rust:
let model = KMeans::params_with(n_clusters, rng, LInfDist)
.max_n_iterations(1000)
.fit(&dataset)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
log("Finished clustering embeddings in Rust");
log("Assigning points to clusters...");
// Assign each point to a cluster using the set of centroids found using `fit`
let dataset = model.predict(dataset);
// Only the cluster assignments are needed from the predicted dataset
let DatasetBase { targets, .. } = dataset;
// Assuming you want to correlate the original embeddings with their cluster assignments
let enriched_embeddings: Vec<EnrichedEmbedding> = embeddings
.into_iter()
.zip(targets.iter())
.map(|(embedding, &cluster)| {
EnrichedEmbedding {
embedding,
cluster: cluster as usize,
is_main_keyword_in_cluster: false, // Placeholder logic here
}
})
.collect();
// Serialize the enriched embeddings
serde_json::to_string(&enriched_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
use wasm_bindgen_test::*;
#[test]
fn testing_greeting() {
assert_eq!(greet("world"), "Hello, world!");
}
#[wasm_bindgen_test]
fn test_cluster_embeddings() {
let mock_json = r#"
[
{
"keyword": "rust",
"embeddings": [0.1, 0.2, 0.3]
},
{
"keyword": "wasm",
"embeddings": [0.4, 0.5, 0.6]
}
]
"#;
let n_clusters = 2; // For simplicity, choose a small number of clusters
// Call the function with the mocked JSON and the number of clusters
let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);
// Check that the function succeeded
assert!(result.is_ok());
// Deserialize the result to verify its structure
let enriched_embeddings: Vec<EnrichedEmbedding> =
serde_json::from_str(&result.unwrap()).unwrap();
// Verify that each embedding has been assigned a cluster
assert_eq!(enriched_embeddings.len(), 2);
for enriched_embedding in enriched_embeddings {
assert!(enriched_embedding.cluster < n_clusters);
}
}
#[wasm_bindgen_test]
fn test_cluster_embeddings_with_no_embeddings() {
let mock_json = r#"
[]
"#;
let n_clusters = 2; // For simplicity, choose a small number of clusters
let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);
assert!(result.is_err())
}
#[wasm_bindgen_test]
fn test_cluster_embeddings_with_large_dataset() {
// Build a valid JSON array with more than 100k embeddings to trigger the size-limit error:
let single_embedding = r#"{"keyword": "rust", "embeddings": [0.1, 0.2, 0.3]}"#;
let mut mock_json_new = String::from("[");
for i in 0..100_001 {
if i > 0 {
mock_json_new.push(',');
}
mock_json_new.push_str(single_embedding);
}
mock_json_new.push(']');
let n_clusters = 2; // For simplicity, choose a small number of clusters
// Call the function with the oversized JSON and the number of clusters
let result = cluster_embeddings_with_kmeans(&mock_json_new, n_clusters);
// Check that the function failed:
assert!(result.is_err());
}
#[wasm_bindgen_test]
fn test_cluster_embeddings_with_3k_embeddings() {
let mut mock_json_new = String::from("[");
let single_embedding = r#"{"keyword": "rust", "embeddings": [0.1, 0.2, 0.3]}"#;
for i in 0..3000 {
if i > 0 {
mock_json_new.push(',');
}
mock_json_new.push_str(single_embedding);
}
mock_json_new.push(']');
let n_clusters = 2; // For simplicity, choose a small number of clusters
// Call the function with the mocked JSON and the number of clusters
let result = cluster_embeddings_with_kmeans(&mock_json_new, n_clusters);
assert!(result.is_ok());
}
}
Bump on this?
Currently I'm building a wasm project that will expose some clustering functionality to the browser.
Questions:
I'm looking to use all of these:
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/dbscan.rs
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/kmeans.rs
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/optics.rs
Thanks in advance, and great package btw!
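For the multi-CPU part of the question in the title, here is a minimal sketch of one possible approach, not a linfa API: run one fit per candidate number of clusters on its own thread and compare a score afterwards. Everything named here (`grid_search`, `kmeans_inertia`, `nearest`, `sq_dist`) is hypothetical, and the toy Lloyd's k-means only stands in for a real `KMeans::params_with(...).fit(...)` call so that the example needs nothing beyond the standard library. It also assumes non-empty data, rows of equal length, and every candidate `k` being at least 1. GPUs are not covered.

```rust
use std::thread;

/// Squared Euclidean distance between two points of equal dimension.
fn sq_dist(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Index of the centroid nearest to `point` (first one wins on ties).
fn nearest(point: &[f64], centroids: &[Vec<f64>]) -> usize {
    let mut best = 0;
    for (i, c) in centroids.iter().enumerate() {
        if sq_dist(point, c) < sq_dist(point, &centroids[best]) {
            best = i;
        }
    }
    best
}

/// Toy Lloyd's k-means (hypothetical stand-in for a real `fit` call).
/// Returns the inertia: the sum of squared distances from each point
/// to its nearest centroid. Assumes `k >= 1` and non-empty `data`.
fn kmeans_inertia(data: &[Vec<f64>], k: usize, iters: usize) -> f64 {
    let dim = data[0].len();
    // Deterministic initialisation: k evenly spaced rows become the centroids.
    let mut centroids: Vec<Vec<f64>> =
        (0..k).map(|i| data[i * data.len() / k].clone()).collect();
    for _ in 0..iters {
        let mut sums = vec![vec![0.0_f64; dim]; k];
        let mut counts = vec![0usize; k];
        // Assignment step: accumulate each point into its nearest centroid.
        for p in data {
            let c = nearest(p, &centroids);
            counts[c] += 1;
            for (s, x) in sums[c].iter_mut().zip(p) {
                *s += *x;
            }
        }
        // Update step: move each non-empty centroid to the mean of its points.
        for c in 0..k {
            if counts[c] > 0 {
                for j in 0..dim {
                    centroids[c][j] = sums[c][j] / counts[c] as f64;
                }
            }
        }
    }
    data.iter()
        .map(|p| sq_dist(p, &centroids[nearest(p, &centroids)]))
        .sum()
}

/// Evaluate every candidate `k` on its own scoped thread and return
/// `(k, inertia)` pairs in the same order as `ks`.
fn grid_search(data: &[Vec<f64>], ks: &[usize]) -> Vec<(usize, f64)> {
    thread::scope(|scope| {
        let handles: Vec<_> = ks
            .iter()
            .map(|&k| scope.spawn(move || (k, kmeans_inertia(data, k, 50))))
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("worker thread panicked"))
            .collect()
    })
}
```

Calling `grid_search(&data, &[2, 3, 4, 5])` returns one `(k, inertia)` pair per candidate, which can then be compared, for example with an elbow heuristic. Inertia always shrinks as `k` grows, so picking the smallest value is not a meaningful selection rule by itself. One caveat for the wasm use case: as far as I know, `std::thread` does not work on `wasm32-unknown-unknown` without extra setup, so in the browser this pattern would more likely mean running one fit per Web Worker, or doing the grid search natively.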