Skip to content

feat(builder): add slice argument support #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

import (
"fmt"
"io"
"reflect"
"strings"
)

// Builder is a raw SQL query builder.
// The zero value is ready to use.
// Do not copy a non-zero Builder.
// Do not reuse a single Builder for multiple queries.
type Builder struct {
query strings.Builder
args []any
Expand All @@ -22,8 +23,8 @@
// In addition, Appendf supports %?, %$, and %@ verbs, which are automatically expanded to the query placeholders ?, $N, and @pN,
// where N is the auto-incrementing counter.
// The corresponding arguments can then be accessed with the [Builder.Args] method.
//
// IMPORTANT: to avoid SQL injections, make sure to pass arguments from user input with placeholder verbs.
// Always test your queries.
//
// Placeholder verbs map to the following database placeholders:
// - MySQL, SQLite: %? -> ?
Expand Down Expand Up @@ -53,7 +54,7 @@
return query
}

// Args returns the argument slice.
// Args returns the query arguments.
func (b *Builder) Args() []any { return b.args }

type argument struct {
Expand All @@ -65,26 +66,59 @@
func (a argument) Format(s fmt.State, verb rune) {
switch verb {
case '?', '$', '@':
a.builder.args = append(a.builder.args, a.value)
if a.builder.placeholder == 0 {
a.builder.placeholder = verb
}
if a.builder.placeholder != verb {
a.builder.placeholder = -1
}
default:
format := fmt.FormatString(s, verb)
fmt.Fprintf(s, format, a.value)
return
}

if s.Flag('+') {
a.writeSlice(s, verb)
} else {
a.writePlaceholder(s, verb)
a.builder.args = append(a.builder.args, a.value)
}
}

func (a argument) writePlaceholder(w io.Writer, verb rune) {
switch verb {
case '?': // MySQL, SQLite
fmt.Fprint(s, "?")
fmt.Fprint(w, "?")
case '$': // PostgreSQL
a.builder.counter++
fmt.Fprintf(s, "$%d", a.builder.counter)
fmt.Fprintf(w, "$%d", a.builder.counter)
case '@': // MSSQL
a.builder.counter++
fmt.Fprintf(s, "@p%d", a.builder.counter)
default:
format := fmt.FormatString(s, verb)
fmt.Fprintf(s, format, a.value)
fmt.Fprintf(w, "@p%d", a.builder.counter)
}
}

func (a argument) writeSlice(w io.Writer, verb rune) {
slice := reflect.ValueOf(a.value)
if slice.Kind() != reflect.Slice {
panic("queries: %+ argument must be a slice")

Check warning on line 105 in builder.go

View check run for this annotation

Codecov / codecov/patch

builder.go#L105

Added line #L105 was not covered by tests
}

if slice.Len() == 0 {
fmt.Fprint(w, "NULL") // "IN (NULL)" is valid SQL.
return

Check warning on line 110 in builder.go

View check run for this annotation

Codecov / codecov/patch

builder.go#L109-L110

Added lines #L109 - L110 were not covered by tests
}

args := reflect.ValueOf(a.builder.args)

for i := range slice.Len() {
if i > 0 {
fmt.Fprint(w, ", ")
}
a.writePlaceholder(w, verb)
args = reflect.Append(args, slice.Index(i))
}

a.builder.args = args.Interface().([]any)
}
22 changes: 15 additions & 7 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestBuilder(t *testing.T) {
assert.Equal[E](t, qb.Args(), []any{42, "test", false})
}

func TestBuilder_placeholders(t *testing.T) {
func TestBuilder_dialects(t *testing.T) {
tests := map[string]struct {
format string
query string
Expand Down Expand Up @@ -50,31 +50,39 @@ func TestBuilder_placeholders(t *testing.T) {
}
}

func TestBuilder_slice(t *testing.T) {
var qb queries.Builder
qb.Appendf("SELECT * FROM tbl WHERE foo IN (%+$)", []int{1, 2, 3})

assert.Equal[E](t, qb.Query(), "SELECT * FROM tbl WHERE foo IN ($1, $2, $3)")
assert.Equal[E](t, qb.Args(), []any{1, 2, 3})
}

func TestBuilder_badQuery(t *testing.T) {
tests := map[string]struct {
appends func(*queries.Builder)
appendf func(*queries.Builder)
panicMsg string
}{
"bad verb": {
appends: func(qb *queries.Builder) {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %d FROM tbl", "foo")
},
panicMsg: "queries: bad query: SELECT %!d(string=foo) FROM tbl",
},
"too few arguments": {
appends: func(qb *queries.Builder) {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %s FROM tbl")
},
panicMsg: "queries: bad query: SELECT %!s(MISSING) FROM tbl",
},
"too many arguments": {
appends: func(qb *queries.Builder) {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT %s FROM tbl", "foo", "bar")
},
panicMsg: "queries: bad query: SELECT foo FROM tbl%!(EXTRA queries.argument=bar)",
},
"different placeholders": {
appends: func(qb *queries.Builder) {
appendf: func(qb *queries.Builder) {
qb.Appendf("SELECT * FROM tbl WHERE foo = %? AND bar = %$ AND baz = %@", 1, 2, 3)
},
panicMsg: "queries: different placeholders used",
Expand All @@ -84,7 +92,7 @@ func TestBuilder_badQuery(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
var qb queries.Builder
tt.appends(&qb)
tt.appendf(&qb)
assert.Panics[E](t, func() { _ = qb.Query() }, tt.panicMsg)
})
}
Expand Down
Loading