Skip to content

Commit

Permalink
AI super power
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Nov 24, 2024
1 parent e804c8f commit 8e11b10
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ chrono = { version = "0.4", features = ["wasmbind"] }
futures = "0.3.31"
zstd = { version = "*", features = ["wasm", "thin"], default-features = false }
zstd-sys = { version = "=2.0.9", default-features = false }
serde = { version = "1.0", features = ["derive"] }

[profile.release]
strip = true
Expand Down
23 changes: 15 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,23 @@ fn App() -> impl IntoView {
}}
<div class="mt-4">
{move || {

file_bytes
.get()
.map(|_| {
view! {
<QueryInput
sql_query=sql_query
set_sql_query=set_sql_query
file_name=file_name
execute_query=Arc::new(execute_query)
/>
match file_content.get_untracked() {
Some(info) => {
view! {
<QueryInput
sql_query=sql_query
set_sql_query=set_sql_query
file_name=file_name
execute_query=Arc::new(execute_query)
schema=info.schema
/>
}
},
None => view! {}.into_view(),
}
})
}}
Expand All @@ -410,7 +417,7 @@ fn App() -> impl IntoView {
} else {
let physical_plan = physical_plan.get().unwrap();
view! {
<QueryResults query_result=result physical_plan=physical_plan />
<QueryResults sql_query=sql_query.get_untracked() query_result=result physical_plan=physical_plan />
}
.into_view()
}
Expand Down
224 changes: 214 additions & 10 deletions src/query_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use datafusion::{
};
use leptos::*;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use serde_json::json;
use wasm_bindgen::{JsCast, JsValue};
use wasm_bindgen_futures::JsFuture;
use web_sys::{js_sys, Headers, Request, RequestInit, RequestMode, Response};

use crate::ParquetInfo;

Expand Down Expand Up @@ -78,30 +82,230 @@ pub fn QueryInput(
set_sql_query: WriteSignal<String>,
file_name: ReadSignal<String>,
execute_query: Arc<dyn Fn(String)>,
schema: SchemaRef,
) -> impl IntoView {
let key_down = execute_query.clone();
let key_down_schema = schema.clone();
let key_down_exec = execute_query.clone();
let file_name_s = file_name.get_untracked();
let key_down = move |ev: web_sys::KeyboardEvent| {
if ev.key() == "Enter" {
let input = sql_query.get_untracked();
process_user_input(
input,
key_down_schema.clone(),
file_name_s.clone(),
key_down_exec.clone(),
set_sql_query.clone(),
);
}
};

let key_down_exec = execute_query.clone();
let button_press_schema = schema.clone();
let file_name_s = file_name.get_untracked();
let button_press = move |_ev: web_sys::MouseEvent| {
let input = sql_query.get_untracked();
process_user_input(
input,
button_press_schema.clone(),
file_name_s.clone(),
key_down_exec.clone(),
set_sql_query.clone(),
);
};

let default_query = format!("select * from \"{}\" limit 10", file_name.get_untracked());

view! {
<div class="flex gap-2 items-center">
<input
type="text"
placeholder=move || {
format!("select * from \"{}\" limit 10", file_name.get())
placeholder=default_query
on:input=move |ev| {
let value = event_target_value(&ev);
set_sql_query(value);
}
prop:value=sql_query
on:input=move |ev| set_sql_query(event_target_value(&ev))
on:keydown=move |ev| {
if ev.key() == "Enter" {
key_down(sql_query.get());
}
}
on:keydown=key_down
class="flex-1 px-3 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500"
/>
<button
on:click=move |_| execute_query(sql_query.get())
on:click=button_press
class="px-4 py-2 bg-green-500 text-white rounded-md hover:bg-green-600 whitespace-nowrap"
>
"Run Query"
</button>
</div>
}
}

/// Routes the text typed by the user either straight to the query engine or
/// through the Gemini text-to-SQL step.
///
/// * `input` - raw text taken from the query box.
/// * `schema` - schema of the loaded table; summarised into the LLM prompt.
/// * `file_name` - used as the table name in the LLM prompt.
/// * `exec` - callback that runs a SQL string.
/// * `set_sql_query` - signal setter that receives the SQL being executed.
///
/// Input that looks like SQL (see [`looks_like_sql`]) is executed unchanged.
/// Anything else is treated as a natural-language question: a prompt is built,
/// sent to Gemini on a spawned local task, and the returned SQL is stored in
/// the signal and executed. Failures of the Gemini call are logged to the
/// browser console and nothing is executed.
fn process_user_input(
    input: String,
    schema: SchemaRef,
    file_name: String,
    exec: Arc<dyn Fn(String)>,
    set_sql_query: WriteSignal<String>,
) {
    // If the input seems to be a SQL query, run it as is.
    if looks_like_sql(&input) {
        exec(input.clone());
        set_sql_query(input);
        return;
    }

    // Otherwise, treat it as natural language and ask the model for SQL.
    let schema_str = schema_to_brief_str(schema);
    web_sys::console::log_1(&format!("Processing user input: {}", input).into());

    let prompt = format!(
        "the table name is: {}, the schema of the table is: {}. Field names should be quoted. Please generate a SQL query to answer the following question: {}",
        file_name, schema_str, input
    );
    web_sys::console::log_1(&prompt.as_str().into());

    // The prompt is moved into the task; no extra copy is needed.
    spawn_local(async move {
        let sql = match generate_sql_via_gemini(prompt).await {
            Ok(sql) => sql,
            Err(e) => {
                web_sys::console::log_1(&e.into());
                return;
            }
        };
        web_sys::console::log_1(&sql.as_str().into());
        set_sql_query(sql.clone());
        exec(sql);
    });
}

/// Returns `true` when `input` starts with the `select` keyword, ignoring
/// leading whitespace and ASCII case (`"select"`, `"SELECT"`, `"  Select"`).
fn looks_like_sql(input: &str) -> bool {
    const KEYWORD: &str = "select";
    input
        .trim_start()
        // `get` yields `None` for inputs shorter than the keyword or when the
        // cut would split a multi-byte character, so this never panics.
        .get(..KEYWORD.len())
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case(KEYWORD))
}

/// Renders a schema as a single line of `name: type` pairs separated by
/// `", "`, in the order the fields appear in `schema`.
fn schema_to_brief_str(schema: SchemaRef) -> String {
    let columns = schema.fields();
    let mut described = Vec::with_capacity(columns.len());
    for column in columns.iter() {
        described.push(format!("{}: {}", column.name(), column.data_type()));
    }
    described.join(", ")
}

/// Asks the Gemini `generateContent` endpoint to turn `prompt` into SQL.
///
/// The request's `generationConfig` asks for a JSON reply shaped like
/// `{"sql": "..."}`; on success the value of that `sql` field is returned.
///
/// # Errors
///
/// Returns a human-readable message when the request cannot be built or sent,
/// when the HTTP status is not OK, or when the reply does not have the
/// expected shape.
async fn generate_sql_via_gemini(prompt: String) -> Result<String, String> {
    // NOTE(review): this API key is hard-coded in the source and travels in the
    // request URL, so anyone can read and reuse it. It is kept as-is to
    // preserve behaviour; consider making it user-supplied or proxying the call.
    let default_key = "AIzaSyDSEI9ixzvFYQx-e82poEtz8e0bM4omB0Q";

    // Define the API endpoint
    let url = format!(
        "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={}",
        default_key
    );

    // Build the JSON payload
    let payload = json!({
        "contents": [
            {
                "role": "user",
                "parts": [
                    {
                        "text": prompt
                    }
                ]
            }
        ],
        "generationConfig": {
            "temperature": 1,
            "topK": 40,
            "topP": 0.95,
            "maxOutputTokens": 8192,
            "responseMimeType": "application/json",
            "responseSchema": {
                "type": "object",
                "properties": {
                    "sql": {
                        "type": "string"
                    }
                }
            }
        }
    });

    // Initialize the request options
    let opts = RequestInit::new();
    opts.set_method("POST");
    opts.set_mode(RequestMode::Cors);

    // Set headers
    let headers = Headers::new().map_err(|e| format!("Failed to create headers: {:?}", e))?;
    headers
        .set("Content-Type", "application/json")
        .map_err(|e| format!("Failed to set Content-Type: {:?}", e))?;
    opts.set_headers(&headers);

    // Set body
    let body =
        serde_json::to_string(&payload).map_err(|e| format!("JSON serialization error: {}", e))?;
    opts.set_body(&JsValue::from_str(&body));

    // Create the request
    let request = Request::new_with_str_and_init(&url, &opts)
        .map_err(|e| format!("Request creation failed: {:?}", e))?;

    // Send the request
    let window = web_sys::window().ok_or("No global `window` exists")?;
    let response_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(|e| format!("Fetch error: {:?}", e))?;

    // Convert the fetch result into a `Response`
    let response: Response = response_value
        .dyn_into()
        .map_err(|e| format!("Response casting failed: {:?}", e))?;

    if !response.ok() {
        return Err(format!(
            "Network response was not ok: {}",
            response.status()
        ));
    }

    // Read the body as text and parse it once with serde_json, instead of
    // parsing in JS, re-stringifying, and parsing again on the Rust side.
    let response_text = JsFuture::from(
        response
            .text()
            .map_err(|e| format!("Failed to read response body: {:?}", e))?,
    )
    .await
    .map_err(|e| format!("Failed to read response body: {:?}", e))?
    .as_string()
    .ok_or("Failed to convert to string")?;

    extract_sql_from_gemini_response(&response_text)
}

/// Pulls the generated SQL out of a raw Gemini `generateContent` reply.
///
/// The text at `candidates[0].content.parts[0].text` is itself a JSON document
/// (`{"sql": "..."}`), so it is parsed a second time to reach the `sql` field.
fn extract_sql_from_gemini_response(body: &str) -> Result<String, String> {
    let envelope: serde_json::Value = serde_json::from_str(body)
        .map_err(|e| format!("Failed to parse JSON value: {:?}", e))?;

    // Navigate the JSON structure to the model's text output
    let generated = envelope
        .get("candidates")
        .and_then(|c| c.get(0))
        .and_then(|c| c.get("content"))
        .and_then(|c| c.get("parts"))
        .and_then(|p| p.get(0))
        .and_then(|p| p.get("text"))
        .and_then(|t| t.as_str())
        .ok_or("Failed to extract SQL from response")?;

    // Parse the inner JSON string to get the final SQL
    let sql_obj: serde_json::Value = serde_json::from_str(generated)
        .map_err(|e| format!("Failed to parse SQL JSON: {:?}", e))?;

    sql_obj
        .get("sql")
        .and_then(|s| s.as_str())
        .map(str::to_owned)
        .ok_or_else(|| "Failed to extract SQL field".to_string())
}
4 changes: 4 additions & 0 deletions src/query_results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ use leptos::*;

#[component]
pub fn QueryResults(
sql_query: String,
query_result: Vec<RecordBatch>,
physical_plan: Arc<dyn ExecutionPlan>,
) -> impl IntoView {
let (active_tab, set_active_tab) = create_signal("results".to_string());

view! {
<div class="mt-4 p-4 bg-white border border-gray-300 rounded-md">
<div class="mb-4 p-3 bg-gray-50 rounded border border-gray-200 font-mono text-sm overflow-x-auto">
{sql_query}
</div>
<div class="mb-4 border-b border-gray-300">
<button
class=move || format!(
Expand Down

0 comments on commit 8e11b10

Please sign in to comment.