diff --git a/Cargo.lock b/Cargo.lock
index 821d63e..3adbaa4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2140,6 +2140,7 @@ dependencies = [
"futures",
"leptos",
"parquet",
+ "serde",
"serde_json",
"wasm-bindgen-futures",
"web-sys",
diff --git a/Cargo.toml b/Cargo.toml
index 5de97cc..c1f6964 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ chrono = { version = "0.4", features = ["wasmbind"] }
futures = "0.3.31"
zstd = { version = "*", features = ["wasm", "thin"], default-features = false }
zstd-sys = { version = "=2.0.9", default-features = false }
+serde = { version = "1.0", features = ["derive"] }
[profile.release]
strip = true
diff --git a/src/main.rs b/src/main.rs
index 17b2faa..e2120c2 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -387,16 +387,23 @@ fn App() -> impl IntoView {
}}
{move || {
+
file_bytes
.get()
.map(|_| {
- view! {
-
+ match file_content.get_untracked() {
+ Some(info) => {
+ view! {
+
+ }
+ },
+ None => view! {}.into_view(),
}
})
}}
@@ -410,7 +417,7 @@ fn App() -> impl IntoView {
} else {
let physical_plan = physical_plan.get().unwrap();
view! {
-
+
}
.into_view()
}
diff --git a/src/query_input.rs b/src/query_input.rs
index 905afec..8e10bbf 100644
--- a/src/query_input.rs
+++ b/src/query_input.rs
@@ -12,6 +12,10 @@ use datafusion::{
};
use leptos::*;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
+use serde_json::json;
+use wasm_bindgen::{JsCast, JsValue};
+use wasm_bindgen_futures::JsFuture;
+use web_sys::{js_sys, Headers, Request, RequestInit, RequestMode, Response};
use crate::ParquetInfo;
@@ -78,26 +82,55 @@ pub fn QueryInput(
set_sql_query: WriteSignal
,
file_name: ReadSignal,
execute_query: Arc,
+ schema: SchemaRef,
) -> impl IntoView {
- let key_down = execute_query.clone();
+ let key_down_schema = schema.clone();
+ let key_down_exec = execute_query.clone();
+ let file_name_s = file_name.get_untracked();
+ let key_down = move |ev: web_sys::KeyboardEvent| {
+ if ev.key() == "Enter" {
+ let input = sql_query.get_untracked();
+ process_user_input(
+ input,
+ key_down_schema.clone(),
+ file_name_s.clone(),
+ key_down_exec.clone(),
+ set_sql_query.clone(),
+ );
+ }
+ };
+
+ let key_down_exec = execute_query.clone();
+ let button_press_schema = schema.clone();
+ let file_name_s = file_name.get_untracked();
+ let button_press = move |_ev: web_sys::MouseEvent| {
+ let input = sql_query.get_untracked();
+ process_user_input(
+ input,
+ button_press_schema.clone(),
+ file_name_s.clone(),
+ key_down_exec.clone(),
+ set_sql_query.clone(),
+ );
+ };
+
+ let default_query = format!("select * from \"{}\" limit 10", file_name.get_untracked());
+
view! {
}
}
+
+/// Returns `true` when `input` should be executed verbatim as a SQL query.
+///
+/// Leading whitespace is ignored and the `select` keyword is matched without regard to ASCII
+/// case, so `"  Select 1"` is recognised just like `"select 1"` and `"SELECT 1"`.
+fn looks_like_sql(input: &str) -> bool {
+    const KEYWORD: &str = "select";
+    input
+        .trim_start()
+        // `get` yields `None` for short input or a non-char-boundary cut instead of panicking.
+        .get(..KEYWORD.len())
+        .is_some_and(|head| head.eq_ignore_ascii_case(KEYWORD))
+}
+
+/// Handles text submitted from the query box.
+///
+/// * Input that looks like a `SELECT` statement is passed to `exec` unchanged and then stored
+///   through `set_sql_query`.
+/// * Anything else is treated as a natural-language question: a prompt is built from
+///   `file_name` (used as the table name), a brief rendering of `schema`, and the question
+///   itself, and sent to Gemini on a locally spawned task. On success the generated SQL is
+///   stored through `set_sql_query` and then passed to `exec`; on failure the error is written
+///   to the browser console and nothing is executed.
+fn process_user_input(
+    input: String,
+    schema: SchemaRef,
+    file_name: String,
+    exec: Arc,
+    set_sql_query: WriteSignal,
+) {
+    // If the input seems to be a SQL query, run it as is.
+    if looks_like_sql(&input) {
+        exec(input.clone());
+        set_sql_query(input);
+        return;
+    }
+
+    // Otherwise, treat it as natural language and ask the model for a query.
+    let schema_str = schema_to_brief_str(schema);
+    web_sys::console::log_1(&format!("Processing user input: {}", input).into());
+
+    let prompt = format!(
+        "the table name is: {}, the schema of the table is: {}. Field names should be quoted. Please generate a SQL query to answer the following question: {}",
+        file_name, schema_str, input
+    );
+    web_sys::console::log_1(&JsValue::from_str(&prompt));
+
+    // The request is asynchronous; the prompt is moved into the task rather than cloned.
+    spawn_local(async move {
+        match generate_sql_via_gemini(prompt).await {
+            Ok(sql) => {
+                web_sys::console::log_1(&JsValue::from_str(&sql));
+                set_sql_query(sql.clone());
+                exec(sql);
+            }
+            // Report failures at error level so they stand out from the debug logging above.
+            Err(e) => web_sys::console::error_1(&e.into()),
+        }
+    });
+}
+
+/// Renders `schema` as a compact, comma-separated list of `name: data_type` pairs.
+///
+/// Fields appear in schema order; an empty schema yields an empty string.
+fn schema_to_brief_str(schema: SchemaRef) -> String {
+    schema
+        .fields()
+        .iter()
+        .map(|field| format!("{}: {}", field.name(), field.data_type()))
+        .collect::<Vec<_>>()
+        .join(", ")
+}
+
+/// Gemini API key used when no `GEMINI_API_KEY` is supplied at build time.
+///
+/// NOTE(security): this literal is compiled into the shipped binary, so anyone with access to
+/// the build output can read and reuse it. Treat it as public, restrict or rotate it on the
+/// provider side, and prefer supplying `GEMINI_API_KEY` when building.
+const DEFAULT_GEMINI_API_KEY: &str = "AIzaSyDSEI9ixzvFYQx-e82poEtz8e0bM4omB0Q";
+
+/// `generateContent` endpoint of the model used for SQL generation.
+const GEMINI_ENDPOINT: &str =
+    "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent";
+
+/// Pulls the generated SQL out of a parsed `generateContent` response.
+///
+/// The first candidate's first text part is expected to be a JSON document of the form
+/// `{"sql": "<query>"}` (the shape requested through `responseSchema`), so that text is
+/// parsed a second time before the `sql` field is read.
+fn extract_sql_from_response(response: &serde_json::Value) -> Result<String, String> {
+    let text = response
+        .get("candidates")
+        .and_then(|c| c.get(0))
+        .and_then(|c| c.get("content"))
+        .and_then(|c| c.get("parts"))
+        .and_then(|p| p.get(0))
+        .and_then(|p| p.get("text"))
+        .and_then(|t| t.as_str())
+        .ok_or("Failed to extract SQL from response")?;
+
+    let sql_obj: serde_json::Value =
+        serde_json::from_str(text).map_err(|e| format!("Failed to parse SQL JSON: {:?}", e))?;
+
+    let sql = sql_obj
+        .get("sql")
+        .and_then(|s| s.as_str())
+        .ok_or("Failed to extract SQL field")?;
+
+    Ok(sql.to_string())
+}
+
+/// Asks the Gemini API to turn `prompt` into a SQL query and returns that query.
+///
+/// The request is a CORS `POST` issued through the browser's `fetch`. The API key is taken
+/// from the `GEMINI_API_KEY` environment variable at compile time when it is set, and falls
+/// back to [`DEFAULT_GEMINI_API_KEY`] otherwise.
+///
+/// # Errors
+///
+/// Returns a human-readable message when the request cannot be built or sent, when the
+/// response status is not successful, or when the response body does not have the expected
+/// shape.
+async fn generate_sql_via_gemini(prompt: String) -> Result<String, String> {
+    let api_key = option_env!("GEMINI_API_KEY").unwrap_or(DEFAULT_GEMINI_API_KEY);
+    let url = format!("{}?key={}", GEMINI_ENDPOINT, api_key);
+
+    // Build the JSON payload; the response is constrained to `{"sql": "<query>"}`.
+    let payload = json!({
+        "contents": [
+            {
+                "role": "user",
+                "parts": [
+                    {
+                        "text": prompt
+                    }
+                ]
+            }
+        ],
+        "generationConfig": {
+            "temperature": 1,
+            "topK": 40,
+            "topP": 0.95,
+            "maxOutputTokens": 8192,
+            "responseMimeType": "application/json",
+            "responseSchema": {
+                "type": "object",
+                "properties": {
+                    "sql": {
+                        "type": "string"
+                    }
+                }
+            }
+        }
+    });
+
+    // Assemble the request: method, mode, headers, body.
+    let opts = RequestInit::new();
+    opts.set_method("POST");
+    opts.set_mode(RequestMode::Cors);
+
+    let headers = Headers::new().map_err(|e| format!("Failed to create headers: {:?}", e))?;
+    headers
+        .set("Content-Type", "application/json")
+        .map_err(|e| format!("Failed to set Content-Type: {:?}", e))?;
+    opts.set_headers(&headers);
+
+    let body =
+        serde_json::to_string(&payload).map_err(|e| format!("JSON serialization error: {}", e))?;
+    opts.set_body(&JsValue::from_str(&body));
+
+    let request = Request::new_with_str_and_init(&url, &opts)
+        .map_err(|e| format!("Request creation failed: {:?}", e))?;
+
+    // Send the request and wait for the response headers.
+    let window = web_sys::window().ok_or("No global `window` exists")?;
+    let response_value = JsFuture::from(window.fetch_with_request(&request))
+        .await
+        .map_err(|e| format!("Fetch error: {:?}", e))?;
+
+    let response: Response = response_value
+        .dyn_into()
+        .map_err(|e| format!("Response casting failed: {:?}", e))?;
+
+    if !response.ok() {
+        return Err(format!(
+            "Network response was not ok: {}",
+            response.status()
+        ));
+    }
+
+    // `Response::json` only starts reading the body; the awaited promise does the parsing.
+    let body_promise = response
+        .json()
+        .map_err(|e| format!("Failed to read response body: {:?}", e))?;
+    let json = JsFuture::from(body_promise)
+        .await
+        .map_err(|e| format!("JSON parsing error: {:?}", e))?;
+
+    // Convert the JS value into a `serde_json::Value` by round-tripping through a string.
+    let json_text = js_sys::JSON::stringify(&json)
+        .map_err(|e| format!("Failed to stringify JSON: {:?}", e))?
+        .as_string()
+        .ok_or("Failed to convert to string")?;
+    let json_value: serde_json::Value = serde_json::from_str(&json_text)
+        .map_err(|e| format!("Failed to parse JSON value: {:?}", e))?;
+
+    extract_sql_from_response(&json_value)
+}
diff --git a/src/query_results.rs b/src/query_results.rs
index 8e68ba5..4c6e728 100644
--- a/src/query_results.rs
+++ b/src/query_results.rs
@@ -14,6 +14,7 @@ use leptos::*;
#[component]
pub fn QueryResults(
+ sql_query: String,
query_result: Vec,
physical_plan: Arc,
) -> impl IntoView {
@@ -21,6 +22,9 @@ pub fn QueryResults(
view! {