Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/sqlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- uses: Swatinem/rust-cache@v1
with:
key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }}

- uses: actions-rs/cargo@v1
with:
command: check
Expand Down Expand Up @@ -144,6 +144,8 @@ jobs:
steps:
- uses: actions/checkout@v2

- run: mkdir /tmp/sqlite3-lib && wget -O /tmp/sqlite3-lib/ipaddr.so https://github.com/nalgeon/sqlean/releases/download/0.15.2/ipaddr.so

- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -164,6 +166,8 @@ jobs:
--test-threads=1
env:
DATABASE_URL: sqlite://tests/sqlite/sqlite.db
RUSTFLAGS: --cfg sqlite_ipaddr
LD_LIBRARY_PATH: /tmp/sqlite3-lib

postgres:
name: Postgres
Expand Down
128 changes: 125 additions & 3 deletions sqlx-core/src/sqlite/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,46 @@ use crate::error::Error;
use crate::sqlite::connection::handle::ConnectionHandle;
use crate::sqlite::connection::{ConnectionState, Statements};
use crate::sqlite::{SqliteConnectOptions, SqliteError};
use indexmap::IndexMap;
use libc::c_void;
use libsqlite3_sys::{
sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK,
sqlite3, sqlite3_busy_timeout, sqlite3_db_config, sqlite3_extended_result_codes, sqlite3_free,
sqlite3_load_extension, sqlite3_open_v2, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_OK,
SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX,
SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE,
};
use std::ffi::CString;
use std::ffi::{CStr, CString};
use std::io;
use std::ptr::{null, null_mut};
use std::os::raw::c_int;
use std::ptr::{addr_of_mut, null, null_mut};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

static THREAD_ID: AtomicU64 = AtomicU64::new(0);

enum SqliteLoadExtensionMode {
/// Enables only the C-API, leaving the SQL function disabled.
Enable,
/// Disables both the C-API and the SQL function.
DisableAll,
}

impl SqliteLoadExtensionMode {
fn as_int(self) -> c_int {
match self {
SqliteLoadExtensionMode::Enable => 1,
SqliteLoadExtensionMode::DisableAll => 0,
}
}
}

pub struct EstablishParams {
filename: CString,
open_flags: i32,
busy_timeout: Duration,
statement_cache_capacity: usize,
log_settings: LogSettings,
extensions: IndexMap<CString, Option<CString>>,
pub(crate) thread_name: String,
pub(crate) command_channel_size: usize,
}
Expand Down Expand Up @@ -89,17 +110,67 @@ impl EstablishParams {
)
})?;

let extensions = options
.extensions
.iter()
.map(|(name, entry)| {
let entry = entry
.as_ref()
.map(|e| {
CString::new(e.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension entrypoint names passed to SQLite must not contain nul bytes"
)
})
})
.transpose()?;
Ok((
CString::new(name.as_bytes()).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"extension names passed to SQLite must not contain nul bytes",
)
})?,
entry,
))
})
.collect::<Result<IndexMap<CString, Option<CString>>, io::Error>>()?;

Ok(Self {
filename,
open_flags: flags,
busy_timeout: options.busy_timeout,
statement_cache_capacity: options.statement_cache_capacity,
log_settings: options.log_settings.clone(),
extensions,
thread_name: (options.thread_name)(THREAD_ID.fetch_add(1, Ordering::AcqRel)),
command_channel_size: options.command_channel_size,
})
}

// Enable extension loading via the db_config function, as recommended by the docs rather
// than the more obvious `sqlite3_enable_load_extension`
// https://www.sqlite.org/c3ref/db_config.html
// https://www.sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension
unsafe fn sqlite3_set_load_extension(
db: *mut sqlite3,
mode: SqliteLoadExtensionMode,
) -> Result<(), Error> {
let status = sqlite3_db_config(
db,
SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
mode.as_int(),
null::<i32>(),
);

if status != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(db))));
}

Ok(())
}

pub(crate) fn establish(&self) -> Result<ConnectionState, Error> {
let mut handle = null_mut();

Expand Down Expand Up @@ -131,6 +202,57 @@ impl EstablishParams {
sqlite3_extended_result_codes(handle.as_ptr(), 1);
}

if !self.extensions.is_empty() {
// Enable loading extensions
unsafe {
Self::sqlite3_set_load_extension(handle.as_ptr(), SqliteLoadExtensionMode::Enable)?;
}

for ext in self.extensions.iter() {
// `sqlite3_load_extension` is unusual as it returns its errors via an out-pointer
// rather than by calling `sqlite3_errmsg`
let mut error = null_mut();
status = unsafe {
sqlite3_load_extension(
handle.as_ptr(),
ext.0.as_ptr(),
ext.1.as_ref().map_or(null(), |e| e.as_ptr()),
addr_of_mut!(error),
)
};

if status != SQLITE_OK {
// SAFETY: We become responsible for any memory allocation at `&error`, so test
// for null and take an RAII version for returns
let err_msg = if !error.is_null() {
unsafe {
let e = CStr::from_ptr(error).into();
sqlite3_free(error as *mut c_void);
e
}
} else {
CString::new("Unknown error when loading extension")
.expect("text should be representable as a CString")
};
return Err(Error::Database(Box::new(SqliteError::extension(
handle.as_ptr(),
&err_msg,
))));
}
}

// Preempt any hypothetical security issues arising from leaving ENABLE_LOAD_EXTENSION
// on by disabling the flag again once we've loaded all the requested modules.
// Fail-fast (via `?`) if disabling the extension loader didn't work for some reason,
// avoids an unexpected state going undetected.
unsafe {
Self::sqlite3_set_load_extension(
handle.as_ptr(),
SqliteLoadExtensionMode::DisableAll,
)?;
}
}

// Configure a busy timeout
// This causes SQLite to automatically sleep in increasing intervals until the time
// when there is something locked during [sqlite3_step].
Expand Down
7 changes: 7 additions & 0 deletions sqlx-core/src/sqlite/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ impl SqliteError {
message: message.to_owned(),
}
}

/// For errors during extension load, the error message is supplied via a separate pointer
pub(crate) fn extension(handle: *mut sqlite3, error_msg: &CStr) -> Self {
let mut err = Self::new(handle);
err.message = unsafe { from_utf8_unchecked(error_msg.to_bytes()).to_owned() };
err
}
}

impl Display for SqliteError {
Expand Down
44 changes: 44 additions & 0 deletions sqlx-core/src/sqlite/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ pub struct SqliteConnectOptions {
pub(crate) vfs: Option<Cow<'static, str>>,

pub(crate) pragmas: IndexMap<Cow<'static, str>, Option<Cow<'static, str>>>,
/// Extensions are specified as a pair of <Extension Name : Optional Entry Point>, the majority
/// of SQLite extensions will use the default entry points specified in the docs, these should
/// be added to the map with a `None` value.
/// <https://www.sqlite.org/loadext.html#loading_an_extension>
pub(crate) extensions: IndexMap<Cow<'static, str>, Option<Cow<'static, str>>>,

pub(crate) command_channel_size: usize,
pub(crate) row_channel_size: usize,
Expand Down Expand Up @@ -174,6 +179,7 @@ impl SqliteConnectOptions {
immutable: false,
vfs: None,
pragmas,
extensions: Default::default(),
collations: Default::default(),
serialized: false,
thread_name: Arc::new(DebugFn(|id| format!("sqlx-sqlite-worker-{}", id))),
Expand Down Expand Up @@ -414,4 +420,42 @@ impl SqliteConnectOptions {
self.vfs = Some(vfs_name.into());
self
}

/// Load an [extension](https://www.sqlite.org/loadext.html) at run-time when the database connection
/// is established, using the default entry point.
///
/// Most common SQLite extensions can be loaded using this method, for extensions where you need
/// to specify the entry point, use [`extension_with_entrypoint`][`Self::extension_with_entrypoint`] instead.
///
/// Multiple extensions can be loaded by calling the method repeatedly on the options struct, they
/// will be loaded in the order they are added.
/// ```rust,no_run
/// # use sqlx_core::error::Error;
/// use std::str::FromStr;
/// use sqlx::sqlite::SqliteConnectOptions;
/// # fn options() -> Result<SqliteConnectOptions, Error> {
/// let options = SqliteConnectOptions::from_str("sqlite://data.db")?
/// .extension("vsv")
/// .extension("mod_spatialite");
/// # Ok(options)
/// # }
/// ```
pub fn extension(mut self, extension_name: impl Into<Cow<'static, str>>) -> Self {
self.extensions.insert(extension_name.into(), None);
self
}

/// Load an extension with a specified entry point.
///
/// Useful when using non-standard extensions, or when developing your own, the second argument
/// specifies where SQLite should expect to find the extension init routine.
pub fn extension_with_entrypoint(
mut self,
extension_name: impl Into<Cow<'static, str>>,
entry_point: impl Into<Cow<'static, str>>,
) -> Self {
self.extensions
.insert(extension_name.into(), Some(entry_point.into()));
self
}
}
15 changes: 15 additions & 0 deletions tests/sqlite/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,21 @@ async fn it_executes_with_pool() -> anyhow::Result<()> {
Ok(())
}

#[cfg(sqlite_ipaddr)]
#[sqlx_macros::test]
async fn it_opens_with_extension() -> anyhow::Result<()> {
use std::str::FromStr;

let opts = SqliteConnectOptions::from_str(&dotenvy::var("DATABASE_URL")?)?.extension("ipaddr");

let mut conn = SqliteConnection::connect_with(&opts).await?;
conn.execute("SELECT ipmasklen('192.168.16.12/24');")
.await?;
conn.close().await?;

Ok(())
}

#[sqlx_macros::test]
async fn it_opens_in_memory() -> anyhow::Result<()> {
// If the filename is ":memory:", then a private, temporary in-memory database
Expand Down
39 changes: 39 additions & 0 deletions tests/x.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sys
import time
import argparse
import platform
import urllib.request
from glob import glob
from docker import start_database

Expand All @@ -23,6 +25,36 @@
dir_tests = os.path.join(dir_workspace, "tests")


def maybe_fetch_sqlite_extension():
"""
For supported platforms, if we're testing SQLite and the file isn't
already present, grab a simple extension for testing.

Returns the extension name if it was downloaded successfully or `None` if not.
"""
BASE_URL = "https://github.com/nalgeon/sqlean/releases/download/0.15.2/"
if platform.system() == "Darwin":
if platform.machine() == "arm64":
download_url = BASE_URL + "/ipaddr.arm64.dylib"
filename = "ipaddr.dylib"
else:
download_url = BASE_URL + "/ipaddr.dylib"
filename = "ipaddr.dylib"
elif platform.system() == "Linux":
download_url = BASE_URL + "/ipaddr.so"
filename = "ipaddr.so"
else:
# Unsupported OS
return None

if not os.path.exists(filename):
content = urllib.request.urlopen(download_url).read()
with open(filename, "wb") as fd:
fd.write(content)

return filename.split(".")[0]


def run(command, comment=None, env=None, service=None, tag=None, args=None, database_url_args=None):
if argv.list_targets:
if tag:
Expand All @@ -41,6 +73,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data

environ = env or {}

if service == "sqlite":
if maybe_fetch_sqlite_extension() is not None:
if environ.get("RUSTFLAGS"):
environ["RUSTFLAGS"] += " --cfg sqlite_ipaddr"
else:
environ["RUSTFLAGS"] = "--cfg sqlite_ipaddr"

if service is not None:
database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" else "sqlx", cwd=dir_tests)

Expand Down