From b4cf1a0bf0c60decd047c60a90341f4f442804b2 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 21 Jul 2021 19:07:23 -0700 Subject: [PATCH] fix(sqlite): run `sqlite3_reset()` in `StatementWorker` this avoids possible race conditions without using a mutex --- sqlx-core/src/sqlite/connection/executor.rs | 34 +++++---- sqlx-core/src/sqlite/statement/handle.rs | 83 ++------------------- sqlx-core/src/sqlite/statement/virtual.rs | 8 +- sqlx-core/src/sqlite/statement/worker.rs | 71 ++++++++++++++++-- tests/sqlite/sqlite.rs | 34 +++++++++ 5 files changed, 133 insertions(+), 97 deletions(-) diff --git a/sqlx-core/src/sqlite/connection/executor.rs b/sqlx-core/src/sqlite/connection/executor.rs index 45c4768b55..09bbc6cf5b 100644 --- a/sqlx-core/src/sqlite/connection/executor.rs +++ b/sqlx-core/src/sqlite/connection/executor.rs @@ -4,7 +4,7 @@ use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::logger::QueryLogger; use crate::sqlite::connection::describe::describe; -use crate::sqlite::statement::{StatementHandle, VirtualStatement}; +use crate::sqlite::statement::{StatementHandle, StatementWorker, VirtualStatement}; use crate::sqlite::{ Sqlite, SqliteArguments, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteStatement, SqliteTypeInfo, @@ -16,7 +16,8 @@ use libsqlite3_sys::sqlite3_last_insert_rowid; use std::borrow::Cow; use std::sync::Arc; -fn prepare<'a>( +async fn prepare<'a>( + worker: &mut StatementWorker, statements: &'a mut StatementCache, statement: &'a mut Option, query: &str, @@ -39,7 +40,7 @@ fn prepare<'a>( if exists { // as this statement has been executed before, we reset before continuing // this also causes any rows that are from the statement to be inflated - statement.reset(); + statement.reset(worker).await?; } Ok(statement) @@ -61,21 +62,25 @@ fn bind( /// A structure holding sqlite statement handle and resetting the /// statement when it is dropped. -struct StatementResetter { +struct StatementResetter<'a> { handle: Arc, + worker: &'a mut StatementWorker, } -impl StatementResetter { - fn new(handle: &Arc) -> Self { +impl<'a> StatementResetter<'a> { + fn new(worker: &'a mut StatementWorker, handle: &Arc) -> Self { Self { + worker, handle: Arc::clone(handle), } } } -impl Drop for StatementResetter { +impl Drop for StatementResetter<'_> { fn drop(&mut self) { - self.handle.reset(); + // this method is designed to eagerly send the reset command + // so we don't need to await or spawn it + let _ = self.worker.reset(&self.handle); } } @@ -105,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let stmt = prepare(statements, statement, sql, persistent)?; + let stmt = prepare(worker, statements, statement, sql, persistent).await?; // keep track of how many arguments we have bound let mut num_arguments = 0; @@ -115,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // is dropped. `StatementResetter` will reliably reset the // statement even if the stream returned from `fetch_many` // is dropped early. - let _resetter = StatementResetter::new(stmt); + let resetter = StatementResetter::new(worker, stmt); // bind values to the statement num_arguments += bind(stmt, &arguments, num_arguments)?; @@ -127,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // invoke [sqlite3_step] on the dedicated worker thread // this will move us forward one row or finish the statement - let s = worker.step(stmt).await?; + let s = resetter.worker.step(stmt).await?; match s { Either::Left(changes) => { @@ -190,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let virtual_stmt = prepare(statements, statement, sql, persistent)?; + let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?; // keep track of how many arguments we have bound let mut num_arguments = 0; @@ -218,7 +223,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { logger.increment_rows(); - virtual_stmt.reset(); + virtual_stmt.reset(worker).await?; return Ok(Some(row)); } } @@ -240,11 +245,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { handle: ref mut conn, ref mut statements, ref mut statement, + ref mut worker, .. } = self; // prepare statement object (or checkout from cache) - let statement = prepare(statements, statement, sql, true)?; + let statement = prepare(worker, statements, statement, sql, true).await?; let mut parameters = 0; let mut columns = None; diff --git a/sqlx-core/src/sqlite/statement/handle.rs b/sqlx-core/src/sqlite/statement/handle.rs index 8b5b9fc772..e796d48c5b 100644 --- a/sqlx-core/src/sqlite/statement/handle.rs +++ b/sqlx-core/src/sqlite/statement/handle.rs @@ -1,13 +1,11 @@ -use either::Either; use std::ffi::c_void; use std::ffi::CStr; -use std::hint; + use std::os::raw::{c_char, c_int}; use std::ptr; use std::ptr::NonNull; use std::slice::from_raw_parts; use std::str::{from_utf8, from_utf8_unchecked}; -use std::sync::atomic::{AtomicU8, Ordering}; use libsqlite3_sys::{ sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64, @@ -27,7 +25,7 @@ use crate::sqlite::type_info::DataType; use crate::sqlite::{SqliteError, SqliteTypeInfo}; #[derive(Debug)] -pub(crate) struct StatementHandle(NonNull, Lock); +pub(crate) struct StatementHandle(NonNull); // access to SQLite3 statement handles are safe to send and share between threads // as long as the `sqlite3_step` call is serialized. @@ -37,7 +35,11 @@ unsafe impl Sync for StatementHandle {} impl StatementHandle { pub(super) fn new(ptr: NonNull) -> Self { - Self(ptr, Lock::new()) + Self(ptr) + } + + pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt { + self.0.as_ptr() } #[inline] @@ -288,41 +290,13 @@ impl StatementHandle { Ok(from_utf8(self.column_blob(index))?) } - pub(crate) fn step(&self) -> Result, Error> { - self.1.enter_step(); - - let status = unsafe { sqlite3_step(self.0.as_ptr()) }; - let result = match status { - SQLITE_ROW => Ok(Either::Right(())), - SQLITE_DONE => Ok(Either::Left(self.changes())), - _ => Err(self.last_error().into()), - }; - - if self.1.exit_step() { - unsafe { sqlite3_reset(self.0.as_ptr()) }; - self.1.exit_reset(); - } - - result - } - - pub(crate) fn reset(&self) { - if !self.1.enter_reset() { - // reset or step already in progress - return; - } - - unsafe { sqlite3_reset(self.0.as_ptr()) }; - - self.1.exit_reset(); - } - pub(crate) fn clear_bindings(&self) { unsafe { sqlite3_clear_bindings(self.0.as_ptr()) }; } } impl Drop for StatementHandle { fn drop(&mut self) { + // SAFETY: we have exclusive access to the `StatementHandle` here unsafe { // https://sqlite.org/c3ref/finalize.html let status = sqlite3_finalize(self.0.as_ptr()); @@ -338,44 +312,3 @@ impl Drop for StatementHandle { } } } - -const RESET: u8 = 0b0000_0001; -const STEP: u8 = 0b0000_0010; - -// Lock to synchronize calls to `step` and `reset`. -#[derive(Debug)] -struct Lock(AtomicU8); - -impl Lock { - fn new() -> Self { - Self(AtomicU8::new(0)) - } - - // If this returns `true` reset can be performed, otherwise reset must be delayed until the - // current step finishes and `exit_step` is called. - fn enter_reset(&self) -> bool { - self.0.fetch_or(RESET, Ordering::Acquire) == 0 - } - - fn exit_reset(&self) { - self.0.fetch_and(!RESET, Ordering::Release); - } - - fn enter_step(&self) { - // NOTE: spin loop should be fine here as we are only waiting for a `reset` to finish which - // should be quick. - while self - .0 - .compare_exchange(0, STEP, Ordering::Acquire, Ordering::Relaxed) - .is_err() - { - hint::spin_loop(); - } - } - - // If this returns `true` it means a previous attempt to reset was delayed and must be - // performed now (followed by `exit_reset`). - fn exit_step(&self) -> bool { - self.0.fetch_and(!STEP, Ordering::Release) & RESET != 0 - } -} diff --git a/sqlx-core/src/sqlite/statement/virtual.rs b/sqlx-core/src/sqlite/statement/virtual.rs index 85141337b5..3da6d33d64 100644 --- a/sqlx-core/src/sqlite/statement/virtual.rs +++ b/sqlx-core/src/sqlite/statement/virtual.rs @@ -3,7 +3,7 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::sqlite::connection::ConnectionHandle; -use crate::sqlite::statement::StatementHandle; +use crate::sqlite::statement::{StatementHandle, StatementWorker}; use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue}; use crate::HashMap; use bytes::{Buf, Bytes}; @@ -176,7 +176,7 @@ impl VirtualStatement { ))) } - pub(crate) fn reset(&mut self) { + pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> { self.index = 0; for (i, handle) in self.handles.iter().enumerate() { @@ -185,9 +185,11 @@ impl VirtualStatement { // Reset A Prepared Statement Object // https://www.sqlite.org/c3ref/reset.html // https://www.sqlite.org/c3ref/clear_bindings.html - handle.reset(); + worker.reset(handle).await?; handle.clear_bindings(); } + + Ok(()) } } diff --git a/sqlx-core/src/sqlite/statement/worker.rs b/sqlx-core/src/sqlite/statement/worker.rs index 60e44a1115..4888eac610 100644 --- a/sqlx-core/src/sqlite/statement/worker.rs +++ b/sqlx-core/src/sqlite/statement/worker.rs @@ -6,6 +6,9 @@ use futures_channel::oneshot; use std::sync::{Arc, Weak}; use std::thread; +use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW}; +use std::future::Future; + // Each SQLite connection has a dedicated thread. // TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce @@ -21,6 +24,10 @@ enum StatementWorkerCommand { statement: Weak, tx: oneshot::Sender, Error>>, }, + Reset { + statement: Weak, + tx: oneshot::Sender<()>, + }, } impl StatementWorker { @@ -31,13 +38,37 @@ impl StatementWorker { for cmd in rx { match cmd { StatementWorkerCommand::Step { statement, tx } => { - let resp = if let Some(statement) = statement.upgrade() { - statement.step() + let statement = if let Some(statement) = statement.upgrade() { + statement } else { - // Statement is already finalized. - Err(Error::WorkerCrashed) + // statement is already finalized, the sender shouldn't be expecting a response + continue; + }; + + // SAFETY: only the `StatementWorker` calls this function + let status = unsafe { sqlite3_step(statement.as_ptr()) }; + let result = match status { + SQLITE_ROW => Ok(Either::Right(())), + SQLITE_DONE => Ok(Either::Left(statement.changes())), + _ => Err(statement.last_error().into()), }; - let _ = tx.send(resp); + + let _ = tx.send(result); + } + StatementWorkerCommand::Reset { statement, tx } => { + if let Some(statement) = statement.upgrade() { + // SAFETY: this must be the only place we call `sqlite3_reset` + unsafe { sqlite3_reset(statement.as_ptr()) }; + + // `sqlite3_reset()` always returns either `SQLITE_OK` + // or the last error code for the statement, + // which should have already been handled; + // so it's assumed the return value is safe to ignore. + // + // https://www.sqlite.org/c3ref/reset.html + + let _ = tx.send(()); + } } } } @@ -61,4 +92,34 @@ impl StatementWorker { rx.await.map_err(|_| Error::WorkerCrashed)? } + + /// Send a command to the worker to execute `sqlite3_reset()` next. + /// + /// This method is written to execute the sending of the command eagerly so + /// you do not need to await the returned future unless you want to. + /// + /// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error + /// in the statement execution which should have already been handled from `step()`. + pub(crate) fn reset( + &mut self, + statement: &Arc, + ) -> impl Future> { + // execute the sending eagerly so we don't need to spawn the future + let (tx, rx) = oneshot::channel(); + + let send_res = self + .tx + .send(StatementWorkerCommand::Reset { + statement: Arc::downgrade(statement), + tx, + }) + .map_err(|_| Error::WorkerCrashed); + + async move { + send_res?; + + // wait for the response + rx.await.map_err(|_| Error::WorkerCrashed) + } + } } diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 1334e493c1..f762673773 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -536,3 +536,37 @@ async fn it_resets_prepared_statement_after_fetch_many() -> anyhow::Result<()> { Ok(()) } + +// https://github.com/launchbadge/sqlx/issues/1300 +#[sqlx_macros::test] +async fn concurrent_resets_dont_segfault() { + use sqlx::{sqlite::SqliteConnectOptions, ConnectOptions}; + use std::{str::FromStr, time::Duration}; + use tokio::time; + + let mut conn = SqliteConnectOptions::from_str(":memory:") + .unwrap() + .connect() + .await + .unwrap(); + + sqlx::query("CREATE TABLE stuff (name INTEGER, value INTEGER)") + .execute(&mut conn) + .await + .unwrap(); + + let task = tokio::spawn(async move { + for i in 0..1000 { + sqlx::query("INSERT INTO stuff (name, value) VALUES (?, ?)") + .bind(i) + .bind(0) + .execute(&mut conn) + .await + .unwrap(); + } + }); + + time::sleep(Duration::from_millis(1)).await; + + task.abort(); +}