diff --git a/Cargo.lock b/Cargo.lock index 5c46410509..adb5504c2d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1205,6 +1205,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linked-hash-map" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dd5a6d5999d9907cda8ed67bbd137d3af8085216c2ac62de5be860bd41f304a" + [[package]] name = "lock_api" version = "0.3.4" @@ -1234,6 +1240,15 @@ dependencies = [ "scoped-tls 0.1.2", ] +[[package]] +name = "lru-cache" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" +dependencies = [ + "linked-hash-map", +] + [[package]] name = "maplit" version = "1.0.2" @@ -2281,6 +2296,7 @@ dependencies = [ "libc", "libsqlite3-sys", "log", + "lru-cache", "md-5", "memchr", "num-bigint", diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 0a5a692f23..d821976362 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -84,3 +84,4 @@ url = { version = "2.1.1", default-features = false } uuid = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] } whoami = "0.8.1" stringprep = "0.1.2" +lru-cache = "0.1.2" diff --git a/sqlx-core/src/common/mod.rs b/sqlx-core/src/common/mod.rs new file mode 100644 index 0000000000..f9698f28c2 --- /dev/null +++ b/sqlx-core/src/common/mod.rs @@ -0,0 +1,3 @@ +mod statement_cache; + +pub(crate) use statement_cache::StatementCache; diff --git a/sqlx-core/src/common/statement_cache.rs b/sqlx-core/src/common/statement_cache.rs new file mode 100644 index 0000000000..f0f108cf5f --- /dev/null +++ b/sqlx-core/src/common/statement_cache.rs @@ -0,0 +1,61 @@ +use lru_cache::LruCache; + +/// A cache for prepared statements. When full, the least recently used +/// statement gets removed. +#[derive(Debug)] +pub struct StatementCache { + inner: LruCache, +} + +impl StatementCache { + /// Create a new cache with the given capacity. + pub fn new(capacity: usize) -> Self { + Self { + inner: LruCache::new(capacity), + } + } + + /// Returns a mutable reference to the value corresponding to the given key + /// in the cache, if any. + pub fn get_mut(&mut self, k: &str) -> Option<&mut T> { + self.inner.get_mut(k) + } + + /// Inserts a new statement to the cache, returning the least recently used + /// statement id if the cache is full, or if inserting with an existing key, + /// the replaced existing statement. + pub fn insert(&mut self, k: &str, v: T) -> Option { + let mut lru_item = None; + + if self.inner.capacity() == self.len() && !self.inner.contains_key(k) { + lru_item = self.remove_lru(); + } else if self.contains_key(k) { + lru_item = self.inner.remove(k); + } + + self.inner.insert(k.into(), v); + + lru_item + } + + /// The number of statements in the cache. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Removes the least recently used item from the cache. + pub fn remove_lru(&mut self) -> Option { + self.inner.remove_lru().map(|(_, v)| v) + } + + /// Clear all cached statements from the cache. + #[cfg(any(feature = "sqlite"))] + pub fn clear(&mut self) { + self.inner.clear(); + } + + /// True if cache has a value for the given key. + pub fn contains_key(&mut self, k: &str) -> bool { + self.inner.contains_key(k) + } +} diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index e6eed54fbf..ae1a0a3b4a 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use futures_core::future::BoxFuture; use futures_core::Future; -use crate::database::Database; +use crate::database::{Database, HasStatementCache}; use crate::error::{BoxDynError, Error}; use crate::transaction::Transaction; @@ -64,6 +64,23 @@ pub trait Connection: Send { }) } + /// The number of statements currently cached in the connection. + fn cached_statements_size(&self) -> usize + where + Self::Database: HasStatementCache, + { + 0 + } + + /// Removes all statements from the cache, closing them on the server if + /// needed. + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> + where + Self::Database: HasStatementCache, + { + Box::pin(async move { Ok(()) }) + } + #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>>; diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index fc400d4b9d..388c45baac 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -74,3 +74,5 @@ pub trait HasArguments<'q> { /// The concrete type used as a buffer for arguments while encoding. type ArgumentBuffer: Default; } + +pub trait HasStatementCache {} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index cda49dbfde..118b42c8d9 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -37,6 +37,7 @@ pub mod transaction; #[macro_use] pub mod encode; +mod common; pub mod database; pub mod decode; pub mod describe; diff --git a/sqlx-core/src/mysql/connection/establish.rs b/sqlx-core/src/mysql/connection/establish.rs index 34f6bf6c2b..bda5a39ac9 100644 --- a/sqlx-core/src/mysql/connection/establish.rs +++ b/sqlx-core/src/mysql/connection/establish.rs @@ -1,6 +1,6 @@ use bytes::Bytes; -use hashbrown::HashMap; +use crate::common::StatementCache; use crate::error::Error; use crate::mysql::connection::{tls, MySqlStream, COLLATE_UTF8MB4_UNICODE_CI, MAX_PACKET_SIZE}; use crate::mysql::protocol::connect::{ @@ -98,7 +98,7 @@ impl MySqlConnection { Ok(Self { stream, - cache_statement: HashMap::new(), + cache_statement: StatementCache::new(options.statement_cache_capacity), scratch_row_columns: Default::default(), scratch_row_column_names: Default::default(), }) diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 5110f3c3f5..7b28f36013 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -15,7 +15,7 @@ use crate::mysql::connection::stream::Busy; use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::response::Status; use crate::mysql::protocol::statement::{ - BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, + BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, StmtClose, }; use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow}; use crate::mysql::protocol::Packet; @@ -26,8 +26,8 @@ use crate::mysql::{ impl MySqlConnection { async fn prepare(&mut self, query: &str) -> Result { - if let Some(&statement) = self.cache_statement.get(query) { - return Ok(statement); + if let Some(statement) = self.cache_statement.get_mut(query) { + return Ok(*statement); } // https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html @@ -60,8 +60,10 @@ impl MySqlConnection { self.stream.maybe_recv_eof().await?; } - self.cache_statement - .insert(query.to_owned(), ok.statement_id); + // in case of the cache being full, close the least recently used statement + if let Some(statement) = self.cache_statement.insert(query, ok.statement_id) { + self.stream.send_packet(StmtClose { statement }).await?; + } Ok(ok.statement_id) } diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs index a3d1d35f6d..0e8ef38731 100644 --- a/sqlx-core/src/mysql/connection/mod.rs +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -6,10 +6,12 @@ use futures_core::future::BoxFuture; use futures_util::FutureExt; use hashbrown::HashMap; +use crate::common::StatementCache; use crate::connection::{Connect, Connection}; use crate::error::Error; use crate::executor::Executor; use crate::ext::ustr::UStr; +use crate::mysql::protocol::statement::StmtClose; use crate::mysql::protocol::text::{Ping, Quit}; use crate::mysql::row::MySqlColumn; use crate::mysql::{MySql, MySqlConnectOptions}; @@ -34,7 +36,7 @@ pub struct MySqlConnection { pub(crate) stream: MySqlStream, // cache by query string to the statement id - cache_statement: HashMap, + cache_statement: StatementCache, // working memory for the active row's column information // this allows us to re-use these allocations unless the user is persisting the @@ -75,6 +77,20 @@ impl Connection for MySqlConnection { self.stream.wait_until_ready().boxed() } + fn cached_statements_size(&self) -> usize { + self.cache_statement.len() + } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + while let Some(statement) = self.cache_statement.remove_lru() { + self.stream.send_packet(StmtClose { statement }).await?; + } + + Ok(()) + }) + } + #[doc(hidden)] fn should_flush(&self) -> bool { !self.stream.wbuf.is_empty() diff --git a/sqlx-core/src/mysql/database.rs b/sqlx-core/src/mysql/database.rs index ef97686578..6bb8083202 100644 --- a/sqlx-core/src/mysql/database.rs +++ b/sqlx-core/src/mysql/database.rs @@ -1,4 +1,4 @@ -use crate::database::{Database, HasArguments, HasValueRef}; +use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::mysql::value::{MySqlValue, MySqlValueRef}; use crate::mysql::{ MySqlArguments, MySqlConnection, MySqlRow, MySqlTransactionManager, MySqlTypeInfo, @@ -33,3 +33,5 @@ impl HasArguments<'_> for MySql { type ArgumentBuffer = Vec; } + +impl HasStatementCache for MySql {} diff --git a/sqlx-core/src/mysql/options.rs b/sqlx-core/src/mysql/options.rs index 3720f444bd..35eaad33a4 100644 --- a/sqlx-core/src/mysql/options.rs +++ b/sqlx-core/src/mysql/options.rs @@ -68,6 +68,14 @@ impl FromStr for MySqlSslMode { /// mysql://[host][/database][?properties] /// ``` /// +/// ## Properties +/// +/// |Parameter|Default|Description| +/// |---------|-------|-----------| +/// | `ssl-mode` | `PREFERRED` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`MySqlSslMode`]. | +/// | `ssl-ca` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. | +/// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. | +/// /// # Example /// /// ```rust,no_run @@ -92,6 +100,8 @@ impl FromStr for MySqlSslMode { /// # }) /// # } /// ``` +/// +/// [`MySqlSslMode`]: enum.MySqlSslMode.html #[derive(Debug, Clone)] pub struct MySqlConnectOptions { pub(crate) host: String, @@ -101,6 +111,7 @@ pub struct MySqlConnectOptions { pub(crate) database: Option, pub(crate) ssl_mode: MySqlSslMode, pub(crate) ssl_ca: Option, + pub(crate) statement_cache_capacity: usize, } impl Default for MySqlConnectOptions { @@ -120,6 +131,7 @@ impl MySqlConnectOptions { database: None, ssl_mode: MySqlSslMode::Preferred, ssl_ca: None, + statement_cache_capacity: 100, } } @@ -190,6 +202,17 @@ impl MySqlConnectOptions { self.ssl_ca = Some(file_name.as_ref().to_owned()); self } + + /// Sets the capacity of the connection's statement cache in a number of stored + /// distinct statements. Caching is handled using LRU, meaning when the + /// amount of queries hits the defined limit, the oldest statement will get + /// dropped. + /// + /// The default cache capacity is 100 statements. + pub fn statement_cache_capacity(mut self, capacity: usize) -> Self { + self.statement_cache_capacity = capacity; + self + } } impl FromStr for MySqlConnectOptions { @@ -231,6 +254,10 @@ impl FromStr for MySqlConnectOptions { options = options.ssl_ca(&*value); } + "statement-cache-capacity" => { + options = options.statement_cache_capacity(value.parse()?); + } + _ => {} } } diff --git a/sqlx-core/src/mysql/protocol/statement/mod.rs b/sqlx-core/src/mysql/protocol/statement/mod.rs index 5ad292f560..9ae6b3c909 100644 --- a/sqlx-core/src/mysql/protocol/statement/mod.rs +++ b/sqlx-core/src/mysql/protocol/statement/mod.rs @@ -2,8 +2,10 @@ mod execute; mod prepare; mod prepare_ok; mod row; +mod stmt_close; pub(crate) use execute::Execute; pub(crate) use prepare::Prepare; pub(crate) use prepare_ok::PrepareOk; pub(crate) use row::BinaryRow; +pub(crate) use stmt_close::StmtClose; diff --git a/sqlx-core/src/mysql/protocol/statement/stmt_close.rs b/sqlx-core/src/mysql/protocol/statement/stmt_close.rs new file mode 100644 index 0000000000..13f095f9b5 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/statement/stmt_close.rs @@ -0,0 +1,16 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-stmt-close.html + +#[derive(Debug)] +pub struct StmtClose { + pub statement: u32, +} + +impl Encode<'_, Capabilities> for StmtClose { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x19); // COM_STMT_CLOSE + buf.extend(&self.statement.to_le_bytes()); + } +} diff --git a/sqlx-core/src/postgres/connection/establish.rs b/sqlx-core/src/postgres/connection/establish.rs index 7b160ad16d..84219b410c 100644 --- a/sqlx-core/src/postgres/connection/establish.rs +++ b/sqlx-core/src/postgres/connection/establish.rs @@ -1,5 +1,6 @@ use hashbrown::HashMap; +use crate::common::StatementCache; use crate::error::Error; use crate::io::Decode; use crate::postgres::connection::{sasl, stream::PgStream, tls}; @@ -138,7 +139,7 @@ impl PgConnection { transaction_status, pending_ready_for_query_count: 0, next_statement_id: 1, - cache_statement: HashMap::with_capacity(10), + cache_statement: StatementCache::new(options.statement_cache_capacity), cache_type_oid: HashMap::new(), cache_type_info: HashMap::new(), scratch_row_columns: Default::default(), diff --git a/sqlx-core/src/postgres/connection/executor.rs b/sqlx-core/src/postgres/connection/executor.rs index ed58ed56fa..e18a9c7217 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -9,8 +9,8 @@ use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::postgres::message::{ - self, Bind, CommandComplete, DataRow, Flush, MessageFormat, ParameterDescription, Parse, Query, - RowDescription, + self, Bind, Close, CommandComplete, DataRow, Flush, MessageFormat, ParameterDescription, Parse, + Query, RowDescription, }; use crate::postgres::type_info::PgType; use crate::postgres::{PgArguments, PgConnection, PgRow, PgValueFormat, Postgres}; @@ -88,15 +88,20 @@ async fn recv_desc_rows(conn: &mut PgConnection) -> Result Result { - if let Some(statement) = self.cache_statement.get(query) { + if let Some(statement) = self.cache_statement.get_mut(query) { return Ok(*statement); } let statement = prepare(self, query, arguments).await?; - self.cache_statement.insert(query.to_owned(), statement); + if let Some(statement) = self.cache_statement.insert(query, statement) { + self.stream.write(Close::Statement(statement)); + self.stream.write(Flush); + self.stream.flush().await?; + } Ok(statement) } diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index e3d1e22066..885ba92f56 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -5,6 +5,7 @@ use futures_core::future::BoxFuture; use futures_util::{FutureExt, TryFutureExt}; use hashbrown::HashMap; +use crate::common::StatementCache; use crate::connection::{Connect, Connection}; use crate::error::Error; use crate::executor::Executor; @@ -12,7 +13,7 @@ use crate::ext::ustr::UStr; use crate::io::Decode; use crate::postgres::connection::stream::PgStream; use crate::postgres::message::{ - Message, MessageFormat, ReadyForQuery, Terminate, TransactionStatus, + Close, Flush, Message, MessageFormat, ReadyForQuery, Terminate, TransactionStatus, }; use crate::postgres::row::PgColumn; use crate::postgres::{PgConnectOptions, PgTypeInfo, Postgres}; @@ -46,7 +47,7 @@ pub struct PgConnection { next_statement_id: u32, // cache statement by query string to the id and columns - cache_statement: HashMap, + cache_statement: StatementCache, // cache user-defined types by id <-> info cache_type_info: HashMap, @@ -119,6 +120,28 @@ impl Connection for PgConnection { self.execute("/* SQLx ping */").map_ok(|_| ()).boxed() } + fn cached_statements_size(&self) -> usize { + self.cache_statement.len() + } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + let mut needs_flush = false; + + while let Some(statement) = self.cache_statement.remove_lru() { + self.stream.write(Close::Statement(statement)); + needs_flush = true; + } + + if needs_flush { + self.stream.write(Flush); + self.stream.flush().await?; + } + + Ok(()) + }) + } + #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { self.wait_until_ready().boxed() diff --git a/sqlx-core/src/postgres/database.rs b/sqlx-core/src/postgres/database.rs index 8b1614120e..8b3af756e6 100644 --- a/sqlx-core/src/postgres/database.rs +++ b/sqlx-core/src/postgres/database.rs @@ -1,4 +1,4 @@ -use crate::database::{Database, HasArguments, HasValueRef}; +use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::postgres::arguments::PgArgumentBuffer; use crate::postgres::value::{PgValue, PgValueRef}; use crate::postgres::{PgArguments, PgConnection, PgRow, PgTransactionManager, PgTypeInfo}; @@ -32,3 +32,5 @@ impl HasArguments<'_> for Postgres { type ArgumentBuffer = PgArgumentBuffer; } + +impl HasStatementCache for Postgres {} diff --git a/sqlx-core/src/postgres/message/close.rs b/sqlx-core/src/postgres/message/close.rs new file mode 100644 index 0000000000..07e7795008 --- /dev/null +++ b/sqlx-core/src/postgres/message/close.rs @@ -0,0 +1,32 @@ +use crate::io::Encode; +use crate::postgres::io::PgBufMutExt; + +const CLOSE_PORTAL: u8 = b'P'; +const CLOSE_STATEMENT: u8 = b'S'; + +#[derive(Debug)] +#[allow(dead_code)] +pub enum Close { + Statement(u32), + Portal(u32), +} + +impl Encode<'_> for Close { + fn encode_with(&self, buf: &mut Vec, _: ()) { + // 15 bytes for 1-digit statement/portal IDs + buf.reserve(20); + buf.push(b'C'); + + buf.put_length_prefixed(|buf| match self { + Close::Statement(id) => { + buf.push(CLOSE_STATEMENT); + buf.put_statement_name(*id); + } + + Close::Portal(id) => { + buf.push(CLOSE_PORTAL); + buf.put_portal_name(Some(*id)); + } + }) + } +} diff --git a/sqlx-core/src/postgres/message/mod.rs b/sqlx-core/src/postgres/message/mod.rs index 7cb1eb49ea..87f11feb69 100644 --- a/sqlx-core/src/postgres/message/mod.rs +++ b/sqlx-core/src/postgres/message/mod.rs @@ -6,6 +6,7 @@ use crate::io::Decode; mod authentication; mod backend_key_data; mod bind; +mod close; mod command_complete; mod data_row; mod describe; @@ -28,6 +29,7 @@ mod terminate; pub use authentication::{Authentication, AuthenticationSasl}; pub use backend_key_data::BackendKeyData; pub use bind::Bind; +pub use close::Close; pub use command_complete::CommandComplete; pub use data_row::DataRow; pub use describe::Describe; diff --git a/sqlx-core/src/postgres/options.rs b/sqlx-core/src/postgres/options.rs index b59a332f53..d1640fa953 100644 --- a/sqlx-core/src/postgres/options.rs +++ b/sqlx-core/src/postgres/options.rs @@ -69,6 +69,15 @@ impl FromStr for PgSslMode { /// postgresql://[user[:password]@][host][:port][/dbname][?param1=value1&...] /// ``` /// +/// ## Parameters +/// +/// |Parameter|Default|Description| +/// |---------|-------|-----------| +/// | `sslmode` | `prefer` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`PgSqlSslMode`]. | +/// | `sslrootcert` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. | +/// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. | +/// +/// /// The URI scheme designator can be either `postgresql://` or `postgres://`. /// Each of the URI parts is optional. /// @@ -106,6 +115,8 @@ impl FromStr for PgSslMode { /// # }) /// # } /// ``` +/// +/// [`PgSqlSslMode`]: enum.PgSslMode.html #[derive(Debug, Clone)] pub struct PgConnectOptions { pub(crate) host: String, @@ -115,6 +126,7 @@ pub struct PgConnectOptions { pub(crate) database: Option, pub(crate) ssl_mode: PgSslMode, pub(crate) ssl_root_cert: Option, + pub(crate) statement_cache_capacity: usize, } impl Default for PgConnectOptions { @@ -162,6 +174,7 @@ impl PgConnectOptions { .ok() .and_then(|v| v.parse().ok()) .unwrap_or_default(), + statement_cache_capacity: 100, } } @@ -285,6 +298,17 @@ impl PgConnectOptions { self.ssl_root_cert = Some(cert.as_ref().to_path_buf()); self } + + /// Sets the capacity of the connection's statement cache in a number of stored + /// distinct statements. Caching is handled using LRU, meaning when the + /// amount of queries hits the defined limit, the oldest statement will get + /// dropped. + /// + /// The default cache capacity is 100 statements. + pub fn statement_cache_capacity(mut self, capacity: usize) -> Self { + self.statement_cache_capacity = capacity; + self + } } fn default_host(port: u16) -> String { @@ -345,6 +369,10 @@ impl FromStr for PgConnectOptions { options = options.ssl_root_cert(&*value); } + "statement-cache-capacity" => { + options = options.statement_cache_capacity(value.parse()?); + } + _ => {} } } diff --git a/sqlx-core/src/sqlite/connection/establish.rs b/sqlx-core/src/sqlite/connection/establish.rs index dbebd1652b..4c640bfb62 100644 --- a/sqlx-core/src/sqlite/connection/establish.rs +++ b/sqlx-core/src/sqlite/connection/establish.rs @@ -1,7 +1,6 @@ use std::io; use std::ptr::{null, null_mut}; -use hashbrown::HashMap; use libsqlite3_sys::{ sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_PRIVATECACHE, @@ -12,7 +11,10 @@ use sqlx_rt::blocking; use crate::error::Error; use crate::sqlite::connection::handle::ConnectionHandle; use crate::sqlite::statement::StatementWorker; -use crate::sqlite::{SqliteConnectOptions, SqliteConnection, SqliteError}; +use crate::{ + common::StatementCache, + sqlite::{SqliteConnectOptions, SqliteConnection, SqliteError}, +}; pub(super) async fn establish(options: &SqliteConnectOptions) -> Result { let mut filename = options @@ -87,7 +89,7 @@ pub(super) async fn establish(options: &SqliteConnectOptions) -> Result( conn: &mut ConnectionHandle, - statements: &'a mut HashMap, + statements: &'a mut StatementCache, statement: &'a mut Option, query: &str, persistent: bool, @@ -28,7 +29,7 @@ fn prepare<'a>( if !statements.contains_key(query) { let statement = SqliteStatement::prepare(conn, query, false)?; - statements.insert(query.to_owned(), statement); + statements.insert(query, statement); } let statement = statements.get_mut(query).unwrap(); diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index d70c1c8909..d768f14332 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -6,6 +6,7 @@ use futures_util::future; use hashbrown::HashMap; use libsqlite3_sys::sqlite3; +use crate::common::StatementCache; use crate::connection::{Connect, Connection}; use crate::error::Error; use crate::ext::ustr::UStr; @@ -25,7 +26,7 @@ pub struct SqliteConnection { pub(crate) worker: StatementWorker, // cache of semi-persistent statements - pub(crate) statements: HashMap, + pub(crate) statements: StatementCache, // most recent non-persistent statement pub(crate) statement: Option, @@ -60,6 +61,17 @@ impl Connection for SqliteConnection { Box::pin(future::ok(())) } + fn cached_statements_size(&self) -> usize { + self.statements.len() + } + + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + self.statements.clear(); + Ok(()) + }) + } + #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { // For SQLite, FLUSH does effectively nothing diff --git a/sqlx-core/src/sqlite/database.rs b/sqlx-core/src/sqlite/database.rs index 3d8ea27520..fa00660fbe 100644 --- a/sqlx-core/src/sqlite/database.rs +++ b/sqlx-core/src/sqlite/database.rs @@ -1,4 +1,4 @@ -use crate::database::{Database, HasArguments, HasValueRef}; +use crate::database::{Database, HasArguments, HasStatementCache, HasValueRef}; use crate::sqlite::{ SqliteArgumentValue, SqliteArguments, SqliteConnection, SqliteRow, SqliteTransactionManager, SqliteTypeInfo, SqliteValue, SqliteValueRef, @@ -33,3 +33,5 @@ impl<'q> HasArguments<'q> for Sqlite { type ArgumentBuffer = Vec>; } + +impl HasStatementCache for Sqlite {} diff --git a/sqlx-core/src/sqlite/options.rs b/sqlx-core/src/sqlite/options.rs index 5f925b6b39..42c9db977e 100644 --- a/sqlx-core/src/sqlite/options.rs +++ b/sqlx-core/src/sqlite/options.rs @@ -10,6 +10,7 @@ use crate::error::BoxDynError; pub struct SqliteConnectOptions { pub(crate) filename: PathBuf, pub(crate) in_memory: bool, + pub(crate) statement_cache_capacity: usize, } impl Default for SqliteConnectOptions { @@ -23,8 +24,20 @@ impl SqliteConnectOptions { Self { filename: PathBuf::from(":memory:"), in_memory: false, + statement_cache_capacity: 100, } } + + /// Sets the capacity of the connection's statement cache in a number of stored + /// distinct statements. Caching is handled using LRU, meaning when the + /// amount of queries hits the defined limit, the oldest statement will get + /// dropped. + /// + /// The default cache capacity is 100 statements. + pub fn statement_cache_capacity(mut self, capacity: usize) -> Self { + self.statement_cache_capacity = capacity; + self + } } impl FromStr for SqliteConnectOptions { @@ -34,6 +47,7 @@ impl FromStr for SqliteConnectOptions { let mut options = Self { filename: PathBuf::new(), in_memory: false, + statement_cache_capacity: 100, }; // remove scheme diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index dfdf0cf5cc..becede928d 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -177,3 +177,25 @@ SELECT id, text FROM messages; Ok(()) } + +#[sqlx_macros::test] +async fn it_caches_statements() -> anyhow::Result<()> { + let mut conn = new::().await?; + + for i in 0..2 { + let row = sqlx::query("SELECT ? AS val") + .bind(i) + .fetch_one(&mut conn) + .await?; + + let val: u32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(1, conn.cached_statements_size()); + conn.clear_cached_statements().await?; + assert_eq!(0, conn.cached_statements_size()); + + Ok(()) +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 2aeea43dbc..e582b671bf 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -487,3 +487,25 @@ SELECT id, text FROM _sqlx_test_postgres_5112; Ok(()) } + +#[sqlx_macros::test] +async fn it_caches_statements() -> anyhow::Result<()> { + let mut conn = new::().await?; + + for i in 0..2 { + let row = sqlx::query("SELECT $1 AS val") + .bind(i) + .fetch_one(&mut conn) + .await?; + + let val: u32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(1, conn.cached_statements_size()); + conn.clear_cached_statements().await?; + assert_eq!(0, conn.cached_statements_size()); + + Ok(()) +} diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 1630773b0f..35db78dedc 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -269,3 +269,25 @@ SELECT id, text FROM _sqlx_test; Ok(()) } + +#[sqlx_macros::test] +async fn it_caches_statements() -> anyhow::Result<()> { + let mut conn = new::().await?; + + for i in 0..2 { + let row = sqlx::query("SELECT ? AS val") + .bind(i) + .fetch_one(&mut conn) + .await?; + + let val: i32 = row.get("val"); + + assert_eq!(i, val); + } + + assert_eq!(1, conn.cached_statements_size()); + conn.clear_cached_statements().await?; + assert_eq!(0, conn.cached_statements_size()); + + Ok(()) +}