-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature: Allow custom window functions to be registered with the driver #1220
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,8 +159,25 @@ int _sqlite3_create_function( | |
return sqlite3_create_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xFunc, xStep, xFinal); | ||
} | ||
|
||
int _sqlite3_create_window_function( | ||
sqlite3 *db, | ||
const char *zFunctionName, | ||
int nArg, | ||
int eTextRep, | ||
uintptr_t pApp, | ||
void (*xStep)(sqlite3_context*,int,sqlite3_value**), | ||
void (*xFinal)(sqlite3_context*), | ||
void (*xValue)(sqlite3_context*), | ||
void (*xInverse)(sqlite3_context*,int,sqlite3_value**) | ||
) { | ||
return sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xStep, xFinal, xValue, xInverse, 0); | ||
} | ||
|
||
|
||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); | ||
void stepTrampoline(sqlite3_context*, int, sqlite3_value**); | ||
void valueTrampoline(sqlite3_context*); | ||
void inverseTrampoline(sqlite3_context*); | ||
void doneTrampoline(sqlite3_context*); | ||
|
||
int compareTrampoline(void*, int, char*, int, char*); | ||
|
@@ -438,10 +455,18 @@ type aggInfo struct { | |
active map[int64]reflect.Value | ||
next int64 | ||
|
||
nArgs int | ||
|
||
stepArgConverters []callbackArgConverter | ||
stepVariadicConverter callbackArgConverter | ||
|
||
doneRetConverter callbackRetConverter | ||
|
||
// Inverse and Value arg converters are used for window aggregations. | ||
inverseArgConverters []callbackArgConverter | ||
inverseVariadicConverter callbackArgConverter | ||
|
||
valueRetConverter callbackRetConverter | ||
} | ||
|
||
func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) { | ||
|
@@ -461,6 +486,8 @@ func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) { | |
return *aggIdx, ai.active[*aggIdx], nil | ||
} | ||
|
||
// Step Implements the xStep function for both aggregate and window functions | ||
// https://www.sqlite.org/windowfunctions.html#udfwinfunc | ||
func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { | ||
_, agg, err := ai.agg(ctx) | ||
if err != nil { | ||
|
@@ -481,6 +508,8 @@ func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { | |
} | ||
} | ||
|
||
// Done Implements the xFinal function for both aggregate and window functions | ||
// https://www.sqlite.org/windowfunctions.html#udfwinfunc | ||
func (ai *aggInfo) Done(ctx *C.sqlite3_context) { | ||
idx, agg, err := ai.agg(ctx) | ||
if err != nil { | ||
|
@@ -502,6 +531,49 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) { | |
} | ||
} | ||
|
||
// Inverse Implements the xInverse function for window functions | ||
// https://www.sqlite.org/windowfunctions.html#udfwinfunc | ||
func (ai *aggInfo) Inverse(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { | ||
_, agg, err := ai.agg(ctx) | ||
if err != nil { | ||
callbackError(ctx, err) | ||
return | ||
} | ||
|
||
args, err := callbackConvertArgs(argv, ai.inverseArgConverters, ai.inverseVariadicConverter) | ||
if err != nil { | ||
callbackError(ctx, err) | ||
return | ||
} | ||
|
||
ret := agg.MethodByName("Inverse").Call(args) | ||
if len(ret) == 1 && ret[0].Interface() != nil { | ||
callbackError(ctx, ret[0].Interface().(error)) | ||
return | ||
} | ||
} | ||
|
||
// Value Implements the xValue function for window functions | ||
// https://www.sqlite.org/windowfunctions.html#udfwinfunc | ||
func (ai *aggInfo) Value(ctx *C.sqlite3_context) { | ||
_, agg, err := ai.agg(ctx) | ||
if err != nil { | ||
callbackError(ctx, err) | ||
return | ||
} | ||
ret := agg.MethodByName("Value").Call(nil) | ||
if len(ret) == 2 && ret[1].Interface() != nil { | ||
callbackError(ctx, ret[1].Interface().(error)) | ||
return | ||
} | ||
|
||
err = ai.valueRetConverter(ctx, ret[0]) | ||
if err != nil { | ||
callbackError(ctx, err) | ||
return | ||
} | ||
} | ||
|
||
// Commit transaction. | ||
func (tx *SQLiteTx) Commit() error { | ||
_, err := tx.c.exec(context.Background(), "COMMIT", nil) | ||
|
@@ -684,20 +756,28 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe | |
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xFunc), (*[0]byte)(xStep), (*[0]byte)(xFinal)) | ||
} | ||
|
||
// RegisterAggregator makes a Go type available as a SQLite aggregation function. | ||
func sqlite3CreateWindowFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer, xValue unsafe.Pointer, xInverse unsafe.Pointer) C.int { | ||
return C._sqlite3_create_window_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(uintptr(pApp)), (*[0]byte)(xStep), (*[0]byte)(xFinal), (*[0]byte)(xValue), (*[0]byte)(xInverse)) | ||
} | ||
|
||
// RegisterAggregator makes a Go type available as a SQLite aggregation function or window function. | ||
// | ||
// Because aggregation is incremental, it's implemented in Go with a | ||
// type that has 2 methods: func Step(values) accumulates one row of | ||
// data into the accumulator, and func Done() ret finalizes and | ||
// returns the aggregate value. "values" and "ret" may be any type | ||
// supported by RegisterFunc. | ||
// | ||
// To register a window function, the type must also contain implement | ||
// a Value and Inverse function. | ||
// | ||
// RegisterAggregator takes as implementation a constructor function | ||
// that constructs an instance of the aggregator type each time an | ||
// aggregation begins. The constructor must return a pointer to a | ||
// type, or an interface that implements Step() and Done(). | ||
// type, or an interface that implements Step() and Done(), and optionally | ||
// Value() and Inverse() if the aggregator is a window function. | ||
// | ||
// The constructor function and the Step/Done methods may optionally | ||
// The constructor function and the Step/Done/Value/Inverse methods may optionally | ||
// return an error in addition to their other return values. | ||
// | ||
// See _example/go_custom_funcs for a detailed example. | ||
|
@@ -719,93 +799,142 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl any, pure bool) error | |
} | ||
|
||
agg := t.Out(0) | ||
var implReturnsPointer bool | ||
switch agg.Kind() { | ||
case reflect.Ptr, reflect.Interface: | ||
case reflect.Ptr: | ||
implReturnsPointer = true | ||
case reflect.Interface: | ||
implReturnsPointer = false | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right that I can't conclude if it returns a pointer or not in this case. It does contradict the error message in the default case, but it was previously allowed in the initial implementation of user-defined functions #229. Perhaps there is a better name for this variable. |
||
default: | ||
return errors.New("SQlite aggregator constructor must return a pointer object") | ||
return errors.New("SQLite aggregator constructor must return a pointer object") | ||
} | ||
|
||
stepFn, found := agg.MethodByName("Step") | ||
if !found { | ||
return errors.New("SQlite aggregator doesn't have a Step() function") | ||
return errors.New("SQLite aggregator doesn't have a Step() function") | ||
} | ||
err := ai.setupStepInterface(stepFn, &ai.stepArgConverters, &ai.stepVariadicConverter, implReturnsPointer, "Step()") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't entirely understand this call. The penultimate parameter to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could rename the parameter to I maintained the behavior from the original implementation, where we'd only skip the method receiver if the According to the
When running |
||
if err != nil { | ||
return err | ||
} | ||
step := stepFn.Type | ||
if step.NumOut() != 0 && step.NumOut() != 1 { | ||
return errors.New("SQlite aggregator Step() function must return 0 or 1 values") | ||
|
||
doneFn, found := agg.MethodByName("Done") | ||
if !found { | ||
return errors.New("SQLite aggregator doesn't have a Done() function") | ||
} | ||
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { | ||
return errors.New("type of SQlite aggregator Step() return value must be error") | ||
err = ai.setupDoneInterface(doneFn, &ai.doneRetConverter, implReturnsPointer, "Done()") | ||
if err != nil { | ||
return err | ||
} | ||
|
||
stepNArgs := step.NumIn() | ||
valueFn, valueFnFound := agg.MethodByName("Value") | ||
inverseFn, inverseFnFound := agg.MethodByName("Inverse") | ||
if (inverseFnFound && !valueFnFound) || (valueFnFound && !inverseFnFound) { | ||
return errors.New("SQLite window aggregator must implement both Value() and Inverse() functions") | ||
} | ||
isWindowFunction := valueFnFound && inverseFnFound | ||
// Validate window function interface | ||
if isWindowFunction { | ||
if inverseFn.Type.NumIn() != stepFn.Type.NumIn() { | ||
return errors.New("SQLite window aggregator Inverse() function must accept the same number of arguments as Step()") | ||
} | ||
err := ai.setupStepInterface(inverseFn, &ai.inverseArgConverters, &ai.inverseVariadicConverter, implReturnsPointer, "Inverse()") | ||
if err != nil { | ||
return err | ||
} | ||
err = ai.setupDoneInterface(valueFn, &ai.valueRetConverter, implReturnsPointer, "Value()") | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
|
||
ai.active = make(map[int64]reflect.Value) | ||
ai.next = 1 | ||
|
||
// ai must outlast the database connection, or we'll have dangling pointers. | ||
c.aggregators = append(c.aggregators, &ai) | ||
|
||
cname := C.CString(name) | ||
defer C.free(unsafe.Pointer(cname)) | ||
opts := C.SQLITE_UTF8 | ||
if pure { | ||
opts |= C.SQLITE_DETERMINISTIC | ||
} | ||
var rv C.int | ||
if isWindowFunction { | ||
rv = sqlite3CreateWindowFunction(c.db, cname, C.int(ai.nArgs), C.int(opts), newHandle(c, &ai), C.stepTrampoline, C.doneTrampoline, C.valueTrampoline, C.inverseTrampoline) | ||
} else { | ||
rv = sqlite3CreateFunction(c.db, cname, C.int(ai.nArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) | ||
} | ||
if rv != C.SQLITE_OK { | ||
return c.lastError() | ||
} | ||
return nil | ||
} | ||
|
||
func (ai *aggInfo) setupStepInterface(fn reflect.Method, argConverters *[]callbackArgConverter, variadicConverter *callbackArgConverter, isImplPointer bool, name string) error { | ||
t := fn.Type | ||
if t.NumOut() != 0 && t.NumOut() != 1 { | ||
return fmt.Errorf("SQLite aggregator %s function must return 0 or 1 values", name) | ||
} | ||
if t.NumOut() == 1 && !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { | ||
return fmt.Errorf("type of SQLite aggregator %s return value must be error", name) | ||
} | ||
nArgs := t.NumIn() | ||
start := 0 | ||
if agg.Kind() == reflect.Ptr { | ||
if isImplPointer { | ||
// Skip over the method receiver | ||
stepNArgs-- | ||
nArgs-- | ||
start++ | ||
} | ||
if step.IsVariadic() { | ||
stepNArgs-- | ||
if t.IsVariadic() { | ||
nArgs-- | ||
} | ||
for i := start; i < start+stepNArgs; i++ { | ||
conv, err := callbackArg(step.In(i)) | ||
for i := start; i < start+nArgs; i++ { | ||
conv, err := callbackArg(t.In(i)) | ||
if err != nil { | ||
return err | ||
} | ||
ai.stepArgConverters = append(ai.stepArgConverters, conv) | ||
|
||
*argConverters = append(*argConverters, conv) | ||
} | ||
if step.IsVariadic() { | ||
conv, err := callbackArg(step.In(start + stepNArgs).Elem()) | ||
if t.IsVariadic() { | ||
conv, err := callbackArg(t.In(start + nArgs).Elem()) | ||
if err != nil { | ||
return err | ||
} | ||
ai.stepVariadicConverter = conv | ||
*variadicConverter = conv | ||
// Pass -1 to sqlite so that it allows any number of | ||
// arguments. The call helper verifies that the minimum number | ||
// of arguments is present for variadic functions. | ||
stepNArgs = -1 | ||
nArgs = -1 | ||
} | ||
ai.nArgs = nArgs | ||
return nil | ||
} | ||
|
||
doneFn, found := agg.MethodByName("Done") | ||
if !found { | ||
return errors.New("SQlite aggregator doesn't have a Done() function") | ||
} | ||
done := doneFn.Type | ||
doneNArgs := done.NumIn() | ||
if agg.Kind() == reflect.Ptr { | ||
func (ai *aggInfo) setupDoneInterface(fn reflect.Method, retConverter *callbackRetConverter, implReturnsPointer bool, name string) error { | ||
t := fn.Type | ||
nArgs := t.NumIn() | ||
if implReturnsPointer { | ||
// Skip over the method receiver | ||
doneNArgs-- | ||
nArgs-- | ||
} | ||
if doneNArgs != 0 { | ||
return errors.New("SQlite aggregator Done() function must have no arguments") | ||
if nArgs != 0 { | ||
return fmt.Errorf("SQlite aggregator %s function must have no arguments", name) | ||
} | ||
if done.NumOut() != 1 && done.NumOut() != 2 { | ||
return errors.New("SQLite aggregator Done() function must return 1 or 2 values") | ||
if t.NumOut() != 1 && t.NumOut() != 2 { | ||
return fmt.Errorf("SQLite aggregator %s function must return 1 or 2 values", name) | ||
} | ||
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { | ||
return errors.New("second return value of SQLite aggregator Done() function must be error") | ||
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { | ||
return fmt.Errorf("second return value of SQLite aggregator %s function must be error", name) | ||
} | ||
|
||
conv, err := callbackRet(done.Out(0)) | ||
conv, err := callbackRet(t.Out(0)) | ||
if err != nil { | ||
return err | ||
} | ||
ai.doneRetConverter = conv | ||
ai.active = make(map[int64]reflect.Value) | ||
ai.next = 1 | ||
|
||
// ai must outlast the database connection, or we'll have dangling pointers. | ||
c.aggregators = append(c.aggregators, &ai) | ||
|
||
cname := C.CString(name) | ||
defer C.free(unsafe.Pointer(cname)) | ||
opts := C.SQLITE_UTF8 | ||
if pure { | ||
opts |= C.SQLITE_DETERMINISTIC | ||
} | ||
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) | ||
if rv != C.SQLITE_OK { | ||
return c.lastError() | ||
} | ||
*retConverter = conv | ||
return nil | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain what this is doing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can try my best. For the
step
andinverse
interfaces, SQLite passes usargc
, the number of arguments to consume, andargv
a double pointer to underlyingsqlite3_value
s. A simple explanation of this code is that it initializes a slice ofC.sqlite3_value
s with the length ofargc
.In the initial implementation, we used the value
1 << 30
to represent the maximum possible length of this array.#238 identified that this caused an overflow error. It was changed to use the maximum int32 size, divided by the size of nil
*C.sqlite3_value
, to limit its length without overflowing.The crux seems to be that we cannot dynamically initialize an array with a length equal to
argc
. We must specify a constant size at compile time.Once this array is initialized, we slice it down to the correct length as determined by
argc
.I'm not familiar enough with Go / C to be sure of the exact performance implications of this, but
if you examine the pointer list prior to slicing, you can see that it contains
268,435,455
((math.MaxInt31 - 1) / 8
) elements. it seems like it would be much better if there were a way to only initialize it toargc
.