Skip to content

Commit

Permalink
feat: add query cancellation, remove async process
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Ionov committed May 25, 2024
1 parent 5fcd686 commit 0a60dd8
Show file tree
Hide file tree
Showing 24 changed files with 437 additions and 306 deletions.
1 change: 1 addition & 0 deletions src-tauri/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ tauri-plugin-store = { git = "https://github.com/tauri-apps/plugins-workspace",
sql_lexer = "0.9.4"
futures = "0.3.28"
tokio = { version = "1.32.0", features = ["full"] }
tokio-util = "^0.7.10"
tracing = "0.1.37"
tracing-subscriber = "0.3.17"
sqlparser = "0.46.0"
Expand Down
38 changes: 7 additions & 31 deletions src-tauri/src/bin/main.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]


use state::AppState;
use std::path::PathBuf;
use std::{fs, panic};
use tauri::{Manager, State};
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tracing::error;


use noir::{
database::init::initialize_database,
handlers::{connections, queries},
queues::query::{async_process_model, rs2js},
state::{self, AsyncState},
handlers::{connections, general, queries, task},
state::{self},
utils::{fs::get_app_path, init},
};

Expand All @@ -33,21 +28,12 @@ fn main() {
let path = get_app_path();
let ts = chrono::offset::Utc::now();
let dest = format!("{}/error.log", path.to_str().expect("Failed to get path"));
fs::write(PathBuf::from(dest), format!("{} - {:?}", ts, info)).expect("Failed to write error log");
fs::write(PathBuf::from(dest), format!("{} - {:?}", ts, info))
.expect("Failed to write error log");
}));

let (async_proc_input_tx, async_proc_input_rx) = mpsc::channel(1);
let (async_proc_output_tx, mut async_proc_output_rx) = mpsc::channel(1);

tauri::Builder::default()
.manage(AsyncState {
tasks: Mutex::new(async_proc_input_tx),
connections: Default::default(),
})
.manage(AppState {
db: Default::default(),
connections: Default::default(),
})
.manage(AppState::default())
.plugin(tauri_plugin_store::Builder::default().build())
.plugin(tauri_plugin_window_state::Builder::default().build())
.plugin(tauri_plugin_single_instance::init(|app, argv, cwd| {
Expand All @@ -64,18 +50,6 @@ fn main() {
let db = initialize_database().expect("Database initialize should succeed");
*app_state.db.lock().expect("Failed to lock db") = Some(db);

tauri::async_runtime::spawn(async move {
async_process_model(async_proc_input_rx, async_proc_output_tx).await
});

tauri::async_runtime::spawn(async move {
loop {
if let Some(output) = async_proc_output_rx.recv().await {
rs2js(output, &handle).await
}
}
});

Ok(())
})
.invoke_handler(tauri::generate_handler![
Expand Down Expand Up @@ -103,6 +77,8 @@ fn main() {
queries::download_json,
queries::download_csv,
queries::invalidate_query,
task::cancel_task_token,
general::request_port_forward,
])
.run(tauri::generate_context!())
.expect("error while running tauri application");
Expand Down
17 changes: 7 additions & 10 deletions src-tauri/src/engine/mysql/sql_to_json.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
use std::collections::HashMap;

use chrono::{DateTime, Utc};
use serde_json::{self, json, Value};
use sqlx::mysql::MySqlRow;
use sqlx::Decode;
use sqlx::{Column, Row, TypeInfo, ValueRef};

pub fn row_to_json(row: MySqlRow) -> Value {
json!(row
.columns()
.iter()
.map(|column| {
let value: Value = sql_to_json(&row, column);
(column.name().to_string(), value)
})
.collect::<HashMap<_, _>>())
let mut object = json!({});
for column in row.columns().iter() {
let value: Value = sql_to_json(&row, column);
let name = column.name().to_string();
object[name] = value;
}
object
}

pub fn sql_to_json(row: &MySqlRow, col: &sqlx::mysql::MySqlColumn) -> Value {
Expand Down
3 changes: 3 additions & 0 deletions src-tauri/src/engine/types/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub enum Mode {
Host,
Socket,
File,
Ssh,
}

impl fmt::Display for Mode {
Expand All @@ -60,6 +61,7 @@ impl fmt::Display for Mode {
Mode::Host => write!(f, "Host"),
Mode::Socket => write!(f, "Socket"),
Mode::File => write!(f, "File"),
Mode::Ssh => write!(f, "File"),
}
}
}
Expand All @@ -71,6 +73,7 @@ impl FromSql for Mode {
"Host" => Ok(Mode::Host),
"Socket" => Ok(Mode::Socket),
"File" => Ok(Mode::File),
"Ssh" => Ok(Mode::Ssh),
_ => Err(types::FromSqlError::InvalidType),
}
}
Expand Down
14 changes: 13 additions & 1 deletion src-tauri/src/engine/types/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::Result;
use serde_json::Value;

use super::config::{ConnectionConfig, ConnectionPool};
use super::result::ResultSet;
use super::result::{ResultSet, TableMetadata};
use crate::database::QueryType;
use crate::engine::exec;

Expand Down Expand Up @@ -35,6 +35,18 @@ impl InitiatedConnection {
exec::get_columns(self, table).await
}

pub async fn get_table_metadata(&self, table: &str) -> Result<TableMetadata> {
let foreign_keys = self.get_foreign_keys(table).await?;
let primary_key = self.get_primary_key(table).await?;
let columns = self.get_columns(Some(table)).await?;
Ok(TableMetadata {
table: table.to_string(),
foreign_keys: Some(foreign_keys),
primary_key: Some(primary_key),
columns: Some(columns),
})
}

pub async fn get_foreign_keys(&self, table: &str) -> Result<Vec<Value>> {
exec::get_foreign_keys(self, table).await
}
Expand Down
2 changes: 2 additions & 0 deletions src-tauri/src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
pub mod connections;
pub mod queries;
pub mod task;

83 changes: 59 additions & 24 deletions src-tauri/src/handlers/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@ use std::{fs::read_to_string, path::PathBuf};

use crate::{
database::QueryType,
queues::query::{QueryTask, QueryTaskEnqueueResult, QueryTaskStatus},
state::{AsyncState, ServiceAccess},
query::{Events, QueryTask, QueryTaskEnqueueResult, QueryTaskResult, QueryTaskStatus},
state::{AppState, ServiceAccess},
utils::{
self,
crypto::md5_hash,
error::{CommandResult, Error},
fs::paginate_file,
fs::{paginate_file, write_query},
},
};
use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlparser::{ast::Statement, dialect::dialect_from_str, parser::Parser};
use std::str;
use tauri::{command, AppHandle, State};
use tauri::{command, AppHandle, Manager, State};
use tokio_util::sync::CancellationToken;
use tracing::info;

fn get_query_type(s: Statement) -> QueryType {
Expand All @@ -32,7 +33,7 @@ fn get_query_type(s: Statement) -> QueryType {
#[command]
pub async fn enqueue_query(
app_handle: AppHandle,
async_state: State<'_, AsyncState>,
state: State<'_, AppState>,
conn_id: String,
tab_idx: usize,
sql: &str,
Expand All @@ -41,39 +42,73 @@ pub async fn enqueue_query(
) -> CommandResult<QueryTaskEnqueueResult> {
info!(sql, conn_id, tab_idx, "enqueue_query");
let conn = app_handle.acquire_connection(conn_id.clone());
// ignore sqlparser when dialect is sqlite and statements contain pragma
let statements = Parser::parse_sql(
dialect_from_str(conn.config.dialect.to_string())
.expect("Failed to get dialect")
.as_ref(),
sql,
)?;
)
.unwrap_or_default();
if statements.is_empty() {
return Err(Error::from(anyhow!("No statements found")));
}
let statements: Vec<(String, QueryType, String)> = statements
.into_iter()
.map(|s| {
let id = conn.config.id.to_string() + &tab_idx.to_string() + &s.to_string();
(s.to_string(), get_query_type(s), md5_hash(&id))
let query_type = get_query_type(s.clone());
let mut statement = s.to_string();
if auto_limit
&& !statement.to_lowercase().contains("limit")
&& query_type == QueryType::Select
{
statement = format!("{} LIMIT 1000", statement);
}
let id = conn.config.id.to_string() + &tab_idx.to_string() + &statement.to_string();
(statement, query_type, md5_hash(&id))
})
.collect();
let async_proc_input_tx = async_state.tasks.lock().await;
let enqueued_ids: Vec<String> = vec![];
let mut binding = state.cancel_tokens.lock().await;
for (idx, stmt) in statements.iter().enumerate() {
let (mut statement, t, id) = stmt.clone();
info!("Got statement {:?}", statement);
if enqueued_ids.contains(&id) {
continue;
}
if auto_limit && !statement.to_lowercase().contains("limit") && t == QueryType::Select {
statement = format!("{} LIMIT 1000", statement);
}
let task = QueryTask::new(conn.clone(), statement, t, id, tab_idx, idx, table.clone());
let res = async_proc_input_tx.send(task).await;
if let Err(e) = res {
return Err(Error::from(e));
}
let token = CancellationToken::new();
let task = QueryTask::new(
conn.clone(),
stmt.to_owned(),
tab_idx,
idx,
table.clone(),
token.clone(),
);
binding.insert(stmt.2.clone(), token);
let handle = app_handle.clone();
tokio::spawn(async move {
tokio::select! {
_ = task.cancel_token.cancelled() => {},
res = task.conn.execute_query(&task.query, task.query_type) => {
match res {
Ok(mut result_set) => {
if let Some(table) = task.table.clone() {
result_set.table = task.conn.get_table_metadata(&table).await.unwrap_or_default();
}
match write_query(&task.id, &result_set) {
Ok(path) => {
handle
.emit_all(Events::QueryFinished.as_str(), QueryTaskResult::success(task, result_set, path))
.expect("Failed to emit query_finished event");
},
Err(e) =>
handle
.emit_all(Events::QueryFinished.as_str(), QueryTaskResult::error(task, e))
.expect("Failed to emit query_finished event"),
}
}
Err(e) =>
handle
.emit_all(Events::QueryFinished.as_str(), QueryTaskResult::error(task, e))
.expect("Failed to emit query_finished event"),
}
}
}
});
}
Ok(QueryTaskEnqueueResult {
conn_id,
Expand Down
19 changes: 19 additions & 0 deletions src-tauri/src/handlers/task.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use tauri::AppHandle;
use tracing::info;

use crate::{
state::ServiceAccess,
utils::error::{CommandResult, Error},
};

#[tauri::command]
pub async fn cancel_task_token(app_handle: AppHandle, ids: Vec<String>) -> CommandResult<()> {
info!(?ids, "Cancelling task token");
for id in ids.iter() {
app_handle
.cancel_token(id.clone())
.await
.map_err(Error::from)?;
}
Ok(())
}
7 changes: 4 additions & 3 deletions src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod state;
pub mod database;
pub mod utils;
pub mod engine;
pub mod handlers;
pub mod query;
pub mod queues;
pub mod engine;
pub mod state;
pub mod utils;
Loading

0 comments on commit 0a60dd8

Please sign in to comment.