Skip to content

Implement support for calling Go functions from SQLite #229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 16, 2015
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implement support for variadic functions.
Currently, the variadic part must all be the same type, because there's
no "generic" arg converter.
  • Loading branch information
danderson committed Aug 21, 2015
commit 566f63a43a314f8dcd758dba8c40dc11edc27a5e
52 changes: 42 additions & 10 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ type SQLiteRows struct {
}

type functionInfo struct {
f reflect.Value
argConverters []callbackArgConverter
retConverter callbackRetConverter
f reflect.Value
argConverters []callbackArgConverter
variadicConverter callbackArgConverter
retConverter callbackRetConverter
}

func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
Expand All @@ -178,7 +179,12 @@ func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {

func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
var args []reflect.Value
for i, arg := range argv {

if len(argv) < len(fi.argConverters) {
fi.error(ctx, fmt.Errorf("function requires at least %d arguments", len(fi.argConverters)))
}

for i, arg := range argv[:len(fi.argConverters)] {
v, err := fi.argConverters[i](arg)
if err != nil {
fi.error(ctx, err)
Expand All @@ -187,6 +193,17 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
args = append(args, v)
}

if fi.variadicConverter != nil {
for _, arg := range argv[len(fi.argConverters):] {
v, err := fi.variadicConverter(arg)
if err != nil {
fi.error(ctx, err)
return
}
args = append(args, v)
}
}

ret := fi.f.Call(args)

if len(ret) == 2 && ret[1].Interface() != nil {
Expand Down Expand Up @@ -218,7 +235,8 @@ func (tx *SQLiteTx) Rollback() error {
// The function can accept arguments of any real numeric type
// (i.e. not complex), as well as []byte and string. It must return a
// value of one of those types, and optionally an error as a second
// value.
// value. Variadic functions are allowed, if the variadic argument is
// one of the allowed types.
//
// If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive
Expand All @@ -230,24 +248,38 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
if t.Kind() != reflect.Func {
return errors.New("Non-function passed to RegisterFunc")
}
if t.IsVariadic() {
return errors.New("Variadic SQLite functions are not supported")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite functions must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}

for i := 0; i < t.NumIn(); i++ {
numArgs := t.NumIn()
if t.IsVariadic() {
numArgs--
}

for i := 0; i < numArgs; i++ {
conv, err := callbackArg(t.In(i))
if err != nil {
return err
}
fi.argConverters = append(fi.argConverters, conv)
}

if t.IsVariadic() {
conv, err := callbackArg(t.In(numArgs).Elem())
if err != nil {
return err
}
fi.variadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
numArgs = -1
}

conv, err := callbackRet(t.Out(0))
if err != nil {
return err
Expand All @@ -263,7 +295,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
rv := C.sqlite3_create_function_v2(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
}
Expand Down
13 changes: 13 additions & 0 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,13 @@ func TestFunctionRegistration(t *testing.T) {
regex := func(re, s string) (bool, error) {
return regexp.MatchString(re, s)
}
variadic := func(a, b int64, c ...int64) int64 {
ret := a + b
for _, d := range c {
ret += d
}
return ret
}

sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
Expand Down Expand Up @@ -1098,6 +1105,9 @@ func TestFunctionRegistration(t *testing.T) {
if err := conn.RegisterFunc("regex", regex, true); err != nil {
return err
}
if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
return err
}
return nil
},
})
Expand All @@ -1121,6 +1131,9 @@ func TestFunctionRegistration(t *testing.T) {
{"SELECT not(0)", true},
{`SELECT regex("^foo.*", "foobar")`, true},
{`SELECT regex("^foo.*", "barfoobar")`, false},
{"SELECT variadic(1,2)", int64(3)},
{"SELECT variadic(1,2,3,4)", int64(10)},
{"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
}

for _, op := range ops {
Expand Down