Skip to content

Commit 0cc94c9

Browse files
committed
feat: expose Rust API for creating user-defined scalar functions for local databases in libsql.
1 parent d4026bc commit 0cc94c9

File tree

7 files changed

+175
-5
lines changed

7 files changed

+175
-5
lines changed

libsql/examples/udf.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use std::sync::Arc;
2+
3+
use libsql::{Builder, ScalarFunctionDef};
4+
5+
#[tokio::main]
6+
async fn main() -> anyhow::Result<()> {
7+
let db = Builder::new_local(":memory:").build().await?.connect()?;
8+
9+
db.create_scalar_function(ScalarFunctionDef {
10+
name: "log".to_string(),
11+
num_args: 1,
12+
deterministic: false,
13+
innocuous: true,
14+
direct_only: false,
15+
callback: Arc::new(|args| {
16+
println!("Log from SQL: {:?}", args.first().unwrap());
17+
Ok(libsql::Value::Null)
18+
}),
19+
})?;
20+
21+
db.query("select log('hello world')", ()).await?;
22+
23+
Ok(())
24+
}

libsql/src/connection.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::params::{IntoParams, Params};
99
use crate::rows::Rows;
1010
use crate::statement::Statement;
1111
use crate::transaction::Transaction;
12+
use crate::udf::ScalarFunctionDef;
1213
use crate::{Result, TransactionBehavior};
1314

1415
pub type AuthHook = Arc<dyn Fn(&AuthContext) -> Authorization>;
@@ -58,6 +59,10 @@ pub(crate) trait Conn {
5859
fn authorizer(&self, _hook: Option<AuthHook>) -> Result<()> {
5960
Err(crate::Error::AuthorizerNotSupported)
6061
}
62+
63+
fn create_scalar_function(&self, _def: ScalarFunctionDef) -> Result<()> {
64+
Err(crate::Error::UserDefinedFunctionsNotSupported)
65+
}
6166
}
6267

6368
/// A set of rows returned from `execute_batch`/`execute_transactional_batch`. It is essentially
@@ -285,6 +290,11 @@ impl Connection {
285290
pub fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
286291
self.conn.authorizer(hook)
287292
}
293+
294+
/// Create a user-defined scalar function that can be called from SQL.
295+
pub fn create_scalar_function(&self, def: ScalarFunctionDef) -> Result<()> {
296+
self.conn.create_scalar_function(def)
297+
}
288298
}
289299

290300
impl fmt::Debug for Connection {

libsql/src/errors.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ pub enum Error {
2323
LoadExtensionNotSupported, // Not in rusqlite
2424
#[error("Authorizer is only supported in local databases.")]
2525
AuthorizerNotSupported, // Not in rusqlite
26+
#[error("User defined functions are only supported in local databases.")]
27+
UserDefinedFunctionsNotSupported, // Not in rusqlite
2628
#[error("Column not found: {0}")]
2729
ColumnNotFound(i32), // Not in rusqlite
2830
#[error("Hrana: `{0}`")]

libsql/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ mod auth;
159159
mod connection;
160160
mod database;
161161
mod load_extension_guard;
162+
mod udf;
162163

163164
cfg_parser! {
164165
mod parser;
@@ -186,6 +187,7 @@ pub use self::{
186187
rows::{Column, Row, Rows},
187188
statement::Statement,
188189
transaction::{Transaction, TransactionBehavior},
190+
udf::ScalarFunctionDef,
189191
};
190192

191193
/// Convenient alias for `Result` using the `libsql::Error` type.

libsql/src/local/connection.rs

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ use crate::auth::{AuthAction, AuthContext, Authorization};
44
use crate::connection::AuthHook;
55
use crate::local::rows::BatchedRows;
66
use crate::params::Params;
7+
use crate::udf::{ScalarFunctionCallback, ScalarFunctionDef};
78
use crate::{connection::BatchRows, errors};
9+
use crate::{TransactionBehavior, Value};
10+
use std::ffi::CString;
811
use std::time::Duration;
912

1013
use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction};
1114

12-
use crate::TransactionBehavior;
13-
1415
use libsql_sys::ffi;
1516
use parking_lot::RwLock;
1617
use std::{ffi::c_int, fmt, path::Path, sync::Arc};
@@ -494,6 +495,28 @@ impl Connection {
494495
Ok(())
495496
}
496497

498+
pub(crate) fn create_scalar_function(&self, def: ScalarFunctionDef) -> Result<()> {
499+
let userdata = Box::into_raw(Box::new(Arc::into_raw(def.callback)));
500+
let userdata_c = userdata as *mut ::std::os::raw::c_void;
501+
502+
let name = CString::new(def.name).unwrap();
503+
unsafe {
504+
ffi::sqlite3_create_function_v2(
505+
self.raw,
506+
name.as_ptr(),
507+
def.num_args,
508+
ffi::SQLITE_UTF8,
509+
userdata_c,
510+
Some(scalar_function_callback),
511+
None,
512+
None,
513+
Some(drop_scalar_function_callback),
514+
);
515+
}
516+
517+
Ok(())
518+
}
519+
497520
pub(crate) fn wal_checkpoint(&self, truncate: bool) -> Result<()> {
498521
let mut pn_log = 0i32;
499522
let mut pn_ckpt = 0i32;
@@ -666,6 +689,56 @@ impl Connection {
666689
}
667690
}
668691

692+
unsafe extern "C" fn scalar_function_callback(
693+
context: *mut ffi::sqlite3_context,
694+
argc: i32,
695+
args: *mut *mut ffi::sqlite3_value,
696+
) {
697+
let callback = Box::from_raw(ffi::sqlite3_user_data(context) as *mut ScalarFunctionCallback);
698+
699+
let values = (0..argc)
700+
.map(|i| {
701+
let arg_ptr = *args.add(i as usize);
702+
Value::from(libsql_sys::Value { raw_value: arg_ptr })
703+
})
704+
.collect::<Vec<_>>();
705+
706+
let result = (callback)(values);
707+
std::mem::forget(callback);
708+
709+
match result {
710+
Ok(value) => match value {
711+
Value::Null => ffi::sqlite3_result_null(context),
712+
Value::Integer(i) => ffi::sqlite3_result_int64(context, i),
713+
Value::Real(d) => ffi::sqlite3_result_double(context, d),
714+
Value::Text(t) => {
715+
ffi::sqlite3_result_text(
716+
context,
717+
t.as_ptr() as *const i8,
718+
t.len() as i32,
719+
ffi::SQLITE_TRANSIENT(),
720+
);
721+
}
722+
Value::Blob(b) => {
723+
ffi::sqlite3_result_blob(
724+
context,
725+
b.as_ptr() as *const ::std::os::raw::c_void,
726+
b.len() as i32,
727+
ffi::SQLITE_TRANSIENT(),
728+
);
729+
}
730+
},
731+
Err(e) => {
732+
let e_msg = e.to_string();
733+
ffi::sqlite3_result_error(context, e_msg.as_ptr() as *const i8, e_msg.len() as i32);
734+
}
735+
}
736+
}
737+
738+
unsafe extern "C" fn drop_scalar_function_callback(userdata: *mut ::std::os::raw::c_void) {
739+
drop(Box::from_raw(userdata as *mut ScalarFunctionCallback));
740+
}
741+
669742
unsafe extern "C" fn authorizer_callback(
670743
user_data: *mut ::std::os::raw::c_void,
671744
code: ::std::os::raw::c_int,

libsql/src/local/impls.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::Arc;
2-
use std::{fmt, path::Path};
32
use std::time::Duration;
3+
use std::{fmt, path::Path};
44

55
use crate::connection::BatchRows;
66
use crate::{
@@ -9,8 +9,8 @@ use crate::{
99
rows::{ColumnsInner, RowInner, RowsInner},
1010
statement::Stmt,
1111
transaction::Tx,
12-
Column, Connection, Result, Row, Rows, Statement, Transaction, TransactionBehavior, Value,
13-
ValueType,
12+
Column, Connection, Result, Row, Rows, ScalarFunctionDef, Statement, Transaction,
13+
TransactionBehavior, Value, ValueType,
1414
};
1515

1616
#[derive(Clone)]
@@ -100,6 +100,10 @@ impl Conn for LibsqlConnection {
100100
fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
101101
self.conn.authorizer(hook)
102102
}
103+
104+
fn create_scalar_function(&self, def: ScalarFunctionDef) -> Result<()> {
105+
self.conn.create_scalar_function(def)
106+
}
103107
}
104108

105109
impl Drop for LibsqlConnection {

libsql/src/udf.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
use std::sync::Arc;
2+
3+
use crate::Value;
4+
5+
/// A Rust callback implementing a user-defined scalar SQL function.
6+
pub type ScalarFunctionCallback = Arc<dyn Fn(Vec<Value>) -> anyhow::Result<Value>>;
7+
8+
/// A scalar user-defined SQL function definition.
9+
pub struct ScalarFunctionDef {
10+
/// The name of the SQL function to be created or redefined. The length of the name is limited
11+
/// to 255 bytes. Note that the name length limit is in UTF-8 bytes, not characters. Any attempt
12+
/// to create a function with a longer name will result in a SQLite misuse error.
13+
pub name: String,
14+
/// The number of arguments that the SQL function or aggregate takes. If this parameter is -1,
15+
/// then the SQL function or aggregate may take any number of arguments between 0 and the limit
16+
/// set by sqlite3_limit(SQLITE_LIMIT_FUNCTION_ARG). If the third parameter is less than -1 or
17+
/// greater than 127 then the behavior is undefined.
18+
pub num_args: i32,
19+
/// Set to true to signal that the function will always return the same result given the same
20+
/// inputs within a single SQL statement. Most SQL functions are deterministic. The built-in
21+
/// random() SQL function is an example of a function that is not deterministic. The SQLite query
22+
/// planner is able to perform additional optimizations on deterministic functions, so use of the
23+
/// deterministic flag is recommended where possible.
24+
pub deterministic: bool,
25+
/// The `innocuous` flag means that the function is unlikely to cause problems even if misused.
26+
/// An innocuous function should have no side effects and should not depend on any values other
27+
/// than its input parameters. The `abs()` function is an example of an innocuous function. The
28+
/// load_extension() SQL function is not innocuous because of its side effects.
29+
///
30+
/// `innocuous` is similar to `deterministic`, but is not exactly the same. The random()
31+
/// function is an example of a function that is innocuous but not deterministic.
32+
///
33+
/// Some heightened security settings (SQLITE_DBCONFIG_TRUSTED_SCHEMA and PRAGMA
34+
/// trusted_schema=OFF) disable the use of SQL functions inside views and triggers and in schema
35+
/// structures such as CHECK constraints, DEFAULT clauses, expression indexes, partial indexes,
36+
/// and generated columns unless the function is tagged with `innocuous`. Most built-in
37+
/// functions are innocuous. Developers are advised to avoid using the `innocuous` flag for
38+
/// application-defined functions unless the function has been carefully audited and found to be
39+
/// free of potentially security-adverse side-effects and information-leaks.
40+
pub innocuous: bool,
41+
/// When set, prevents the function from being invoked from within VIEWs, TRIGGERs, CHECK
42+
/// constraints, generated column expressions, index expressions, or the WHERE clause of partial
43+
/// indexes.
44+
///
45+
/// For best security, the `direct_only` flag is recommended for all application-defined SQL
46+
/// functions that do not need to be used inside of triggers, views, CHECK constraints, or other
47+
/// elements of the database schema. This flag is especially recommended for SQL functions that
48+
/// have side effects or reveal internal application state. Without this flag, an attacker might
49+
/// be able to modify the schema of a database file to include invocations of the function with
50+
/// parameters chosen by the attacker, which the application will then execute when the database
51+
/// file is opened and read.
52+
pub direct_only: bool,
53+
/// The Rust callback that will be called to implement the function.
54+
pub callback: ScalarFunctionCallback,
55+
}

0 commit comments

Comments
 (0)