Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions libsql/examples/udf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use std::sync::Arc;

use libsql::{Builder, ScalarFunctionDef};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let db = Builder::new_local(":memory:").build().await?.connect()?;

db.create_scalar_function(ScalarFunctionDef {
name: "log".to_string(),
num_args: 1,
deterministic: false,
innocuous: true,
direct_only: false,
callback: Arc::new(|args| {
println!("Log from SQL: {:?}", args.first().unwrap());
Ok(libsql::Value::Null)
}),
})?;

db.query("select log('hello world')", ()).await?;

Ok(())
}
10 changes: 10 additions & 0 deletions libsql/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::params::{IntoParams, Params};
use crate::rows::Rows;
use crate::statement::Statement;
use crate::transaction::Transaction;
use crate::udf::ScalarFunctionDef;
use crate::{Result, TransactionBehavior};

pub type AuthHook = Arc<dyn Fn(&AuthContext) -> Authorization>;
Expand Down Expand Up @@ -58,6 +59,10 @@ pub(crate) trait Conn {
fn authorizer(&self, _hook: Option<AuthHook>) -> Result<()> {
Err(crate::Error::AuthorizerNotSupported)
}

fn create_scalar_function(&self, _def: ScalarFunctionDef) -> Result<()> {
Err(crate::Error::UserDefinedFunctionsNotSupported)
}
}

/// A set of rows returned from `execute_batch`/`execute_transactional_batch`. It is essentially
Expand Down Expand Up @@ -285,6 +290,11 @@ impl Connection {
pub fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
self.conn.authorizer(hook)
}

/// Create a user-defined scalar function that can be called from SQL.
pub fn create_scalar_function(&self, def: ScalarFunctionDef) -> Result<()> {
self.conn.create_scalar_function(def)
}
}

impl fmt::Debug for Connection {
Expand Down
2 changes: 2 additions & 0 deletions libsql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub enum Error {
LoadExtensionNotSupported, // Not in rusqlite
#[error("Authorizer is only supported in local databases.")]
AuthorizerNotSupported, // Not in rusqlite
#[error("User defined functions are only supported in local databases.")]
UserDefinedFunctionsNotSupported, // Not in rusqlite
#[error("Column not found: {0}")]
ColumnNotFound(i32), // Not in rusqlite
#[error("Hrana: `{0}`")]
Expand Down
2 changes: 2 additions & 0 deletions libsql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ mod auth;
mod connection;
mod database;
mod load_extension_guard;
mod udf;

cfg_parser! {
mod parser;
Expand Down Expand Up @@ -186,6 +187,7 @@ pub use self::{
rows::{Column, Row, Rows},
statement::Statement,
transaction::{Transaction, TransactionBehavior},
udf::ScalarFunctionDef,
};

/// Convenient alias for `Result` using the `libsql::Error` type.
Expand Down
77 changes: 75 additions & 2 deletions libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use crate::auth::{AuthAction, AuthContext, Authorization};
use crate::connection::AuthHook;
use crate::local::rows::BatchedRows;
use crate::params::Params;
use crate::udf::{ScalarFunctionCallback, ScalarFunctionDef};
use crate::{connection::BatchRows, errors};
use crate::{TransactionBehavior, Value};
use std::ffi::CString;
use std::time::Duration;

use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction};

use crate::TransactionBehavior;

use libsql_sys::ffi;
use parking_lot::RwLock;
use std::{ffi::c_int, fmt, path::Path, sync::Arc};
Expand Down Expand Up @@ -494,6 +495,28 @@ impl Connection {
Ok(())
}

pub(crate) fn create_scalar_function(&self, def: ScalarFunctionDef) -> Result<()> {
let userdata = Box::into_raw(Box::new(Arc::into_raw(def.callback)));
let userdata_c = userdata as *mut ::std::os::raw::c_void;

let name = CString::new(def.name).unwrap();
unsafe {
ffi::sqlite3_create_function_v2(
self.raw,
name.as_ptr(),
def.num_args,
ffi::SQLITE_UTF8,
userdata_c,
Some(scalar_function_callback),
None,
None,
Some(drop_scalar_function_callback),
);
}

Ok(())
}

pub(crate) fn wal_checkpoint(&self, truncate: bool) -> Result<()> {
let mut pn_log = 0i32;
let mut pn_ckpt = 0i32;
Expand Down Expand Up @@ -666,6 +689,56 @@ impl Connection {
}
}

unsafe extern "C" fn scalar_function_callback(
context: *mut ffi::sqlite3_context,
argc: i32,
args: *mut *mut ffi::sqlite3_value,
) {
let callback = Box::from_raw(ffi::sqlite3_user_data(context) as *mut ScalarFunctionCallback);

let values = (0..argc)
.map(|i| {
let arg_ptr = *args.add(i as usize);
Value::from(libsql_sys::Value { raw_value: arg_ptr })
})
.collect::<Vec<_>>();

let result = (callback)(values);
std::mem::forget(callback);

match result {
Ok(value) => match value {
Value::Null => ffi::sqlite3_result_null(context),
Value::Integer(i) => ffi::sqlite3_result_int64(context, i),
Value::Real(d) => ffi::sqlite3_result_double(context, d),
Value::Text(t) => {
ffi::sqlite3_result_text(
context,
t.as_ptr() as *const i8,
t.len() as i32,
ffi::SQLITE_TRANSIENT(),
);
}
Value::Blob(b) => {
ffi::sqlite3_result_blob(
context,
b.as_ptr() as *const ::std::os::raw::c_void,
b.len() as i32,
ffi::SQLITE_TRANSIENT(),
);
}
},
Err(e) => {
let e_msg = e.to_string();
ffi::sqlite3_result_error(context, e_msg.as_ptr() as *const i8, e_msg.len() as i32);
}
}
}

unsafe extern "C" fn drop_scalar_function_callback(userdata: *mut ::std::os::raw::c_void) {
drop(Box::from_raw(userdata as *mut ScalarFunctionCallback));
}

unsafe extern "C" fn authorizer_callback(
user_data: *mut ::std::os::raw::c_void,
code: ::std::os::raw::c_int,
Expand Down
10 changes: 7 additions & 3 deletions libsql/src/local/impls.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::Arc;
use std::{fmt, path::Path};
use std::time::Duration;
use std::{fmt, path::Path};

use crate::connection::BatchRows;
use crate::{
Expand All @@ -9,8 +9,8 @@ use crate::{
rows::{ColumnsInner, RowInner, RowsInner},
statement::Stmt,
transaction::Tx,
Column, Connection, Result, Row, Rows, Statement, Transaction, TransactionBehavior, Value,
ValueType,
Column, Connection, Result, Row, Rows, ScalarFunctionDef, Statement, Transaction,
TransactionBehavior, Value, ValueType,
};

#[derive(Clone)]
Expand Down Expand Up @@ -100,6 +100,10 @@ impl Conn for LibsqlConnection {
fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
self.conn.authorizer(hook)
}

fn create_scalar_function(&self, def: ScalarFunctionDef) -> Result<()> {
self.conn.create_scalar_function(def)
}
}

impl Drop for LibsqlConnection {
Expand Down
55 changes: 55 additions & 0 deletions libsql/src/udf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use std::sync::Arc;

use crate::Value;

/// A Rust callback implementing a user-defined scalar SQL function.
pub type ScalarFunctionCallback = Arc<dyn Fn(Vec<Value>) -> anyhow::Result<Value>>;

/// A scalar user-defined SQL function definition.
pub struct ScalarFunctionDef {
/// The name of the SQL function to be created or redefined. The length of the name is limited
/// to 255 bytes. Note that the name length limit is in UTF-8 bytes, not characters. Any attempt
/// to create a function with a longer name will result in a SQLite misuse error.
pub name: String,
/// The number of arguments that the SQL function or aggregate takes. If this parameter is -1,
/// then the SQL function or aggregate may take any number of arguments between 0 and the limit
/// set by sqlite3_limit(SQLITE_LIMIT_FUNCTION_ARG). If the third parameter is less than -1 or
/// greater than 127 then the behavior is undefined.
pub num_args: i32,
/// Set to true to signal that the function will always return the same result given the same
/// inputs within a single SQL statement. Most SQL functions are deterministic. The built-in
/// random() SQL function is an example of a function that is not deterministic. The SQLite query
/// planner is able to perform additional optimizations on deterministic functions, so use of the
/// deterministic flag is recommended where possible.
pub deterministic: bool,
/// The `innocuous` flag means that the function is unlikely to cause problems even if misused.
/// An innocuous function should have no side effects and should not depend on any values other
/// than its input parameters. The `abs()` function is an example of an innocuous function. The
/// load_extension() SQL function is not innocuous because of its side effects.
///
/// `innocuous` is similar to `deterministic`, but is not exactly the same. The random()
/// function is an example of a function that is innocuous but not deterministic.
///
/// Some heightened security settings (SQLITE_DBCONFIG_TRUSTED_SCHEMA and PRAGMA
/// trusted_schema=OFF) disable the use of SQL functions inside views and triggers and in schema
/// structures such as CHECK constraints, DEFAULT clauses, expression indexes, partial indexes,
/// and generated columns unless the function is tagged with `innocuous`. Most built-in
/// functions are innocuous. Developers are advised to avoid using the `innocuous` flag for
/// application-defined functions unless the function has been carefully audited and found to be
/// free of potentially security-adverse side-effects and information-leaks.
pub innocuous: bool,
/// When set, prevents the function from being invoked from within VIEWs, TRIGGERs, CHECK
/// constraints, generated column expressions, index expressions, or the WHERE clause of partial
/// indexes.
///
/// For best security, the `direct_only` flag is recommended for all application-defined SQL
/// functions that do not need to be used inside of triggers, views, CHECK constraints, or other
/// elements of the database schema. This flag is especially recommended for SQL functions that
/// have side effects or reveal internal application state. Without this flag, an attacker might
/// be able to modify the schema of a database file to include invocations of the function with
/// parameters chosen by the attacker, which the application will then execute when the database
/// file is opened and read.
pub direct_only: bool,
/// The Rust callback that will be called to implement the function.
pub callback: ScalarFunctionCallback,
}
Loading