Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add set_update_hook on SqliteConnection #3260

Merged
merged 3 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ impl EstablishParams {
transaction_depth: 0,
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None
})
}
}
73 changes: 70 additions & 3 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::cmp::Ordering;
use std::ffi::CStr;
use std::fmt::Write;
use std::fmt::{self, Debug, Formatter};
use std::os::raw::{c_int, c_void};
Expand All @@ -8,7 +9,7 @@ use std::ptr::NonNull;
use futures_core::future::BoxFuture;
use futures_intrusive::sync::MutexGuard;
use futures_util::future;
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler};
use libsqlite3_sys::{sqlite3, sqlite3_progress_handler, sqlite3_update_hook};

pub(crate) use handle::ConnectionHandle;
use sqlx_core::common::StatementCache;
Expand All @@ -20,7 +21,7 @@ use sqlx_core::transaction::Transaction;
use crate::connection::establish::EstablishParams;
use crate::connection::worker::ConnectionWorker;
use crate::options::OptimizeOnClose;
use crate::statement::VirtualStatement;
use crate::statement::{SqliteOperation, VirtualStatement};
use crate::{Sqlite, SqliteConnectOptions};

pub(crate) mod collation;
Expand Down Expand Up @@ -58,6 +59,11 @@ pub struct LockedSqliteHandle<'a> {
pub(crate) struct Handler(NonNull<dyn FnMut() -> bool + Send + 'static>);
unsafe impl Send for Handler {}

pub(crate) struct UpdateHandler(
NonNull<dyn FnMut(SqliteOperation, String, String, i64) + Send + 'static>,
);
unsafe impl Send for UpdateHandler {}

pub(crate) struct ConnectionState {
pub(crate) handle: ConnectionHandle,

Expand All @@ -71,14 +77,25 @@ pub(crate) struct ConnectionState {
/// Stores the progress handler set on the current connection. If the handler returns `false`,
/// the query is interrupted.
progress_handler_callback: Option<Handler>,

update_hook_callback: Option<UpdateHandler>,
}

impl ConnectionState {
/// Drops the `progress_handler_callback` if it exists.
pub(crate) fn remove_progress_handler(&mut self) {
if let Some(mut handler) = self.progress_handler_callback.take() {
unsafe {
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, 0 as *mut _);
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, std::ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}

pub(crate) fn remove_update_hook(&mut self) {
if let Some(mut handler) = self.update_hook_callback.take() {
unsafe {
sqlite3_update_hook(self.handle.as_ptr(), None, std::ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
Expand Down Expand Up @@ -215,6 +232,31 @@ where
}
}

extern "C" fn update_hook<F>(
callback: *mut c_void,
operation: c_int,
database: *const i8,
table: *const i8,
rowid: i64,
) where
F: FnMut(SqliteOperation, String, String, i64),
{
unsafe {
let r = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let database_str = CStr::from_ptr(database).to_str().unwrap();
let table_str = CStr::from_ptr(table).to_str().unwrap();
(*callback)(
operation.into(),
database_str.to_owned(),
table_str.to_owned(),
gridbox marked this conversation as resolved.
Show resolved Hide resolved
rowid,
)
});
r.unwrap_or_default()
gridbox marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl LockedSqliteHandle<'_> {
/// Returns the underlying sqlite3* connection handle.
///
Expand Down Expand Up @@ -279,17 +321,42 @@ impl LockedSqliteHandle<'_> {
}
}

pub fn set_update_hook<F>(&mut self, callback: F)
where
F: FnMut(SqliteOperation, String, String, i64) + Send + 'static,
gridbox marked this conversation as resolved.
Show resolved Hide resolved
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_update_hook();
self.guard.update_hook_callback = Some(UpdateHandler(callback));

sqlite3_update_hook(
self.as_raw_handle().as_mut(),
Some(update_hook::<F>),
handler,
);
}
}

/// Removes the progress handler on a database connection. The method does nothing if no handler was set.
pub fn remove_progress_handler(&mut self) {
self.guard.remove_progress_handler();
}

pub fn remove_update_hook(&mut self) {
self.guard.remove_update_hook();
}
}

impl Drop for ConnectionState {
fn drop(&mut self) {
// explicitly drop statements before the connection handle is dropped
self.statements.clear();
self.remove_progress_handler();
self.remove_update_hook();
}
}

Expand Down
1 change: 1 addition & 0 deletions sqlx-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub use options::{
};
pub use query_result::SqliteQueryResult;
pub use row::SqliteRow;
pub use statement::SqliteOperation;
pub use statement::SqliteStatement;
pub use transaction::SqliteTransactionManager;
pub use type_info::SqliteTypeInfo;
Expand Down
20 changes: 20 additions & 0 deletions sqlx-sqlite/src/statement/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::column::ColumnIndex;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::{Sqlite, SqliteArguments, SqliteColumn, SqliteTypeInfo};
use libsqlite3_sys::{SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE};
use sqlx_core::{Either, HashMap};
use std::borrow::Cow;
use std::sync::Arc;
Expand Down Expand Up @@ -77,3 +78,22 @@ impl ColumnIndex<SqliteStatement<'_>> for &'_ str {
// }
// }
// }

#[derive(Debug, PartialEq, Eq)]
pub enum SqliteOperation {
Insert,
Update,
Delete,
Unknown,
}
gridbox marked this conversation as resolved.
Show resolved Hide resolved

impl From<i32> for SqliteOperation {
fn from(value: i32) -> Self {
match value {
SQLITE_INSERT => SqliteOperation::Insert,
SQLITE_UPDATE => SqliteOperation::Update,
SQLITE_DELETE => SqliteOperation::Delete,
_ => SqliteOperation::Unknown,
gridbox marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
64 changes: 63 additions & 1 deletion tests/sqlite/sqlite.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use futures::TryStreamExt;
use rand::{Rng, SeedableRng};
use rand_xoshiro::Xoshiro256PlusPlus;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions};
use sqlx::{
query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row,
SqliteConnection, SqlitePool, Statement, TypeInfo,
Expand Down Expand Up @@ -794,3 +794,65 @@ async fn test_multiple_set_progress_handler_calls_drop_old_handler() -> anyhow::
assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}

#[sqlx_macros::test]
async fn test_query_with_update_hook() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;

// Using this string as a canary to ensure the callback doesn't get called with the wrong data pointer.
let state = format!("test");
conn.lock_handle()
.await?
.set_update_hook(move |operation, database, table, rowid| {
assert_eq!(state, "test");
assert_eq!(operation, SqliteOperation::Insert);
assert_eq!(database, "main");
assert_eq!(table, "tweet");
assert_eq!(rowid, 3);
});

let _ = sqlx::query("INSERT INTO tweet ( id, text ) VALUES ( 3, 'Hello, World' )")
.execute(&mut conn)
.await?;

Ok(())
}

#[sqlx_macros::test]
async fn test_multiple_set_update_hook_calls_drop_old_handler() -> anyhow::Result<()> {
let ref_counted_object = Arc::new(0);
assert_eq!(1, Arc::strong_count(&ref_counted_object));

{
let mut conn = new::<Sqlite>().await?;

let o = ref_counted_object.clone();
conn.lock_handle()
.await?
.set_update_hook(move |_, _, _, _| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle()
.await?
.set_update_hook(move |_, _, _, _| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

let o = ref_counted_object.clone();
conn.lock_handle()
.await?
.set_update_hook(move |_, _, _, _| {
println!("{o:?}");
});
assert_eq!(2, Arc::strong_count(&ref_counted_object));

conn.lock_handle().await?.remove_update_hook();
}

assert_eq!(1, Arc::strong_count(&ref_counted_object));
Ok(())
}