diff --git a/Cargo.lock b/Cargo.lock index 821d63e..3adbaa4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2140,6 +2140,7 @@ dependencies = [ "futures", "leptos", "parquet", + "serde", "serde_json", "wasm-bindgen-futures", "web-sys", diff --git a/Cargo.toml b/Cargo.toml index 5de97cc..c1f6964 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ chrono = { version = "0.4", features = ["wasmbind"] } futures = "0.3.31" zstd = { version = "*", features = ["wasm", "thin"], default-features = false } zstd-sys = { version = "=2.0.9", default-features = false } +serde = { version = "1.0", features = ["derive"] } [profile.release] strip = true diff --git a/src/main.rs b/src/main.rs index 17b2faa..e2120c2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -387,16 +387,23 @@ fn App() -> impl IntoView { }}
{move || { + file_bytes .get() .map(|_| { - view! { - + match file_content.get_untracked() { + Some(info) => { + view! { + + } + }, + None => view! {}.into_view(), } }) }} @@ -410,7 +417,7 @@ fn App() -> impl IntoView { } else { let physical_plan = physical_plan.get().unwrap(); view! { - + } .into_view() } diff --git a/src/query_input.rs b/src/query_input.rs index 905afec..8e10bbf 100644 --- a/src/query_input.rs +++ b/src/query_input.rs @@ -12,6 +12,10 @@ use datafusion::{ }; use leptos::*; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use serde_json::json; +use wasm_bindgen::{JsCast, JsValue}; +use wasm_bindgen_futures::JsFuture; +use web_sys::{js_sys, Headers, Request, RequestInit, RequestMode, Response}; use crate::ParquetInfo; @@ -78,26 +82,55 @@ pub fn QueryInput( set_sql_query: WriteSignal, file_name: ReadSignal, execute_query: Arc, + schema: SchemaRef, ) -> impl IntoView { - let key_down = execute_query.clone(); + let key_down_schema = schema.clone(); + let key_down_exec = execute_query.clone(); + let file_name_s = file_name.get_untracked(); + let key_down = move |ev: web_sys::KeyboardEvent| { + if ev.key() == "Enter" { + let input = sql_query.get_untracked(); + process_user_input( + input, + key_down_schema.clone(), + file_name_s.clone(), + key_down_exec.clone(), + set_sql_query.clone(), + ); + } + }; + + let key_down_exec = execute_query.clone(); + let button_press_schema = schema.clone(); + let file_name_s = file_name.get_untracked(); + let button_press = move |_ev: web_sys::MouseEvent| { + let input = sql_query.get_untracked(); + process_user_input( + input, + button_press_schema.clone(), + file_name_s.clone(), + key_down_exec.clone(), + set_sql_query.clone(), + ); + }; + + let default_query = format!("select * from \"{}\" limit 10", file_name.get_untracked()); + view! {
} } + +fn process_user_input( + input: String, + schema: SchemaRef, + file_name: String, + exec: Arc, + set_sql_query: WriteSignal, +) { + // if the input seems to be a SQL query, return it as is + if input.starts_with("select") || input.starts_with("SELECT") { + exec(input.clone()); + set_sql_query(input); + return; + } + + // otherwise, treat it as some natural language + + let schema_str = schema_to_brief_str(schema); + web_sys::console::log_1(&format!("Processing user input: {}", input).into()); + + let prompt = format!( + "the table name is: {}, the schema of the table is: {}. Field names should be quoted. Please generate a SQL query to answer the following question: {}", + file_name, schema_str, input + ); + web_sys::console::log_1(&prompt.clone().into()); + + spawn_local({ + let prompt = prompt.clone(); + async move { + let sql = match generate_sql_via_gemini(prompt).await { + Ok(response) => response, + Err(e) => { + web_sys::console::log_1(&e.into()); + return; + } + }; + web_sys::console::log_1(&sql.clone().into()); + set_sql_query(sql.clone()); + exec(sql); + } + }); +} + +fn schema_to_brief_str(schema: SchemaRef) -> String { + let fields = schema.fields(); + let field_strs = fields + .iter() + .map(|field| format!("{}: {}", field.name(), field.data_type())); + field_strs.collect::>().join(", ") +} + +// Asynchronous function to call the Gemini API +async fn generate_sql_via_gemini(prompt: String) -> Result { + // this is free tier key, who cares + let default_key = "AIzaSyDSEI9ixzvFYQx-e82poEtz8e0bM4omB0Q"; + + // Define the API endpoint + let url = format!( + "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={}", + default_key + ); + + // Build the JSON payload + let payload = json!({ + "contents": [ + { + "role": "user", + "parts": [ + { + "text": prompt + } + ] + } + ], + "generationConfig": { + "temperature": 1, + "topK": 40, + "topP": 0.95, + "maxOutputTokens": 8192, + "responseMimeType": "application/json", + "responseSchema": { + "type": "object", + "properties": { + "sql": { + "type": "string" + } + } + } + } + }); + + // Initialize Request + let opts = RequestInit::new(); + opts.set_method("POST"); + opts.set_mode(RequestMode::Cors); + + // Set headers + let headers = Headers::new().map_err(|e| format!("Failed to create headers: {:?}", e))?; + headers + .set("Content-Type", "application/json") + .map_err(|e| format!("Failed to set Content-Type: {:?}", e))?; + opts.set_headers(&headers); + + // Set body + let body = + serde_json::to_string(&payload).map_err(|e| format!("JSON serialization error: {}", e))?; + opts.set_body(&JsValue::from_str(&body)); + + // Create Request + let request = Request::new_with_str_and_init(&url, &opts) + .map_err(|e| format!("Request creation failed: {:?}", e))?; + + // Send the request + let window = web_sys::window().ok_or("No global `window` exists")?; + let response_value = JsFuture::from(window.fetch_with_request(&request)) + .await + .map_err(|e| format!("Fetch error: {:?}", e))?; + + // Convert the response to a WebSys Response object + let response: Response = response_value + .dyn_into() + .map_err(|e| format!("Response casting failed: {:?}", e))?; + + if !response.ok() { + return Err(format!( + "Network response was not ok: {}", + response.status() + )); + } + + // Parse the JSON response + let json = JsFuture::from( + response + .json() + .map_err(|e| format!("Failed to parse JSON: {:?}", e))?, + ) + .await + .map_err(|e| format!("JSON parsing error: {:?}", e))?; + + // Parse the response to extract just the SQL query + let json_value: serde_json::Value = serde_json::from_str( + &js_sys::JSON::stringify(&json) + .map_err(|e| format!("Failed to stringify JSON: {:?}", e))? + .as_string() + .ok_or("Failed to convert to string")?, + ) + .map_err(|e| format!("Failed to parse JSON value: {:?}", e))?; + + // Navigate the JSON structure to extract the SQL + let sql = json_value + .get("candidates") + .and_then(|c| c.get(0)) + .and_then(|c| c.get("content")) + .and_then(|c| c.get("parts")) + .and_then(|p| p.get(0)) + .and_then(|p| p.get("text")) + .and_then(|t| t.as_str()) + .ok_or("Failed to extract SQL from response")?; + + // Parse the inner JSON string to get the final SQL + let sql_obj: serde_json::Value = + serde_json::from_str(sql).map_err(|e| format!("Failed to parse SQL JSON: {:?}", e))?; + + let final_sql = sql_obj + .get("sql") + .and_then(|s| s.as_str()) + .ok_or("Failed to extract SQL field")? + .to_string(); + + Ok(final_sql) +} diff --git a/src/query_results.rs b/src/query_results.rs index 8e68ba5..4c6e728 100644 --- a/src/query_results.rs +++ b/src/query_results.rs @@ -14,6 +14,7 @@ use leptos::*; #[component] pub fn QueryResults( + sql_query: String, query_result: Vec, physical_plan: Arc, ) -> impl IntoView { @@ -21,6 +22,9 @@ pub fn QueryResults( view! {
+
+ {sql_query} +