Skip to content

Commit 1422b77

Browse files
committed
fix(pgdialect): postgres syntax errors for slices of pointers and json arrays #877
1 parent dbae5e6 commit 1422b77

File tree

7 files changed

+219
-12
lines changed

7 files changed

+219
-12
lines changed

dialect/pgdialect/array.go

+2
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
149149
if typ.Elem().Kind() == reflect.Uint8 {
150150
return appendBytesElemValue
151151
}
152+
case reflect.Ptr:
153+
return schema.PtrAppender(d.arrayElemAppender(typ.Elem()))
152154
}
153155
return schema.Appender(d, typ)
154156
}

dialect/pgdialect/array_parser.go

+33-12
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@ type arrayParser struct {
1111

1212
elem []byte
1313
err error
14+
15+
isJson bool
1416
}
1517

1618
func newArrayParser(b []byte) *arrayParser {
1719
p := new(arrayParser)
1820

19-
if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
21+
if len(b) < 2 || (b[0] != '{' && b[0] != '[') || (b[len(b)-1] != '}' && b[len(b)-1] != ']') {
2022
p.err = fmt.Errorf("pgdialect: can't parse array: %q", b)
2123
return p
2224
}
25+
p.isJson = b[0] == '['
2326

2427
p.p.Reset(b[1 : len(b)-1])
2528
return p
@@ -51,7 +54,7 @@ func (p *arrayParser) readNext() error {
5154
}
5255

5356
switch ch {
54-
case '}':
57+
case '}', ']':
5558
return io.EOF
5659
case '"':
5760
b, err := p.p.ReadSubstring(ch)
@@ -78,16 +81,34 @@ func (p *arrayParser) readNext() error {
7881
p.elem = rng
7982
return nil
8083
default:
81-
lit := p.p.ReadLiteral(ch)
82-
if bytes.Equal(lit, []byte("NULL")) {
83-
lit = nil
84+
if ch == '{' && p.isJson {
85+
json, err := p.p.ReadJSON()
86+
if err != nil {
87+
return err
88+
}
89+
90+
for {
91+
if p.p.Peek() == ',' || p.p.Peek() == ' ' {
92+
p.p.Advance()
93+
} else {
94+
break
95+
}
96+
}
97+
98+
p.elem = json
99+
return nil
100+
} else {
101+
lit := p.p.ReadLiteral(ch)
102+
if bytes.Equal(lit, []byte("NULL")) {
103+
lit = nil
104+
}
105+
106+
if p.p.Peek() == ',' {
107+
p.p.Advance()
108+
}
109+
110+
p.elem = lit
111+
return nil
84112
}
85-
86-
if p.p.Peek() == ',' {
87-
p.p.Advance()
88-
}
89-
90-
p.elem = lit
91-
return nil
92113
}
93114
}

dialect/pgdialect/array_parser_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ func TestArrayParser(t *testing.T) {
2424
{`{"1","2"}`, []string{"1", "2"}},
2525
{`{"{1}","{2}"}`, []string{"{1}", "{2}"}},
2626
{`{[1,2),[3,4)}`, []string{"[1,2)", "[3,4)"}},
27+
28+
{`[]`, []string{}},
29+
{`[{"'\"[]"}]`, []string{`{"'\"[]"}`}},
30+
{`[{"id": 1}, {"id":2, "name":"bob"}]`, []string{"{\"id\": 1}", "{\"id\":2, \"name\":\"bob\"}"}},
2731
}
2832

2933
for i, test := range tests {

dialect/pgdialect/array_test.go

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package pgdialect
2+
3+
import (
4+
"testing"
5+
6+
"github.com/uptrace/bun/schema"
7+
)
8+
9+
func ptr[T any](v T) *T {
10+
return &v
11+
}
12+
13+
func TestArrayAppend(t *testing.T) {
14+
tcases := []struct {
15+
input interface{}
16+
out string
17+
}{
18+
{
19+
input: []byte{1, 2},
20+
out: `'{1,2}'`,
21+
},
22+
{
23+
input: []*byte{ptr(byte(1)), ptr(byte(2))},
24+
out: `'{1,2}'`,
25+
},
26+
{
27+
input: []int{1, 2},
28+
out: `'{1,2}'`,
29+
},
30+
{
31+
input: []*int{ptr(1), ptr(2)},
32+
out: `'{1,2}'`,
33+
},
34+
{
35+
input: []string{"foo", "bar"},
36+
out: `'{"foo","bar"}'`,
37+
},
38+
{
39+
input: []*string{ptr("foo"), ptr("bar")},
40+
out: `'{"foo","bar"}'`,
41+
},
42+
}
43+
44+
for _, tcase := range tcases {
45+
out, err := Array(tcase.input).AppendQuery(schema.NewFormatter(New()), []byte{})
46+
if err != nil {
47+
t.Fatal(err)
48+
}
49+
50+
if string(out) != tcase.out {
51+
t.Errorf("expected output to be %s, was %s", tcase.out, string(out))
52+
}
53+
}
54+
}

dialect/pgdialect/parser.go

+36
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,39 @@ func (p *pgparser) ReadRange(ch byte) ([]byte, error) {
105105

106106
return p.buf, nil
107107
}
108+
109+
func (p *pgparser) ReadJSON() ([]byte, error) {
110+
p.Unread()
111+
112+
c, err := p.ReadByte()
113+
if err != nil {
114+
return nil, err
115+
}
116+
117+
p.buf = p.buf[:0]
118+
119+
depth := 0
120+
for {
121+
switch c {
122+
case '{':
123+
depth++
124+
case '}':
125+
depth--
126+
}
127+
128+
p.buf = append(p.buf, c)
129+
130+
if depth == 0 {
131+
break
132+
}
133+
134+
next, err := p.ReadByte()
135+
if err != nil {
136+
return nil, err
137+
}
138+
139+
c = next
140+
}
141+
142+
return p.buf, nil
143+
}

dialect/pgdialect/sqltype.go

+4
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ func fieldSQLType(field *schema.Field) string {
8686
}
8787

8888
func sqlType(typ reflect.Type) string {
89+
if typ.Kind() == reflect.Ptr {
90+
typ = typ.Elem()
91+
}
92+
8993
switch typ {
9094
case nullStringType: // typ.Kind() == reflect.Struct, test for exact match
9195
return sqltype.VarChar

internal/dbtest/pg_test.go

+86
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/uptrace/bun"
1818
"github.com/uptrace/bun/dialect/pgdialect"
1919
"github.com/uptrace/bun/driver/pgdriver"
20+
"github.com/uptrace/bun/schema"
2021
)
2122

2223
func TestPostgresArray(t *testing.T) {
@@ -25,16 +26,20 @@ func TestPostgresArray(t *testing.T) {
2526
Array1 []string `bun:",array"`
2627
Array2 *[]string `bun:",array"`
2728
Array3 *[]string `bun:",array"`
29+
Array4 []*string `bun:",array"`
2830
}
2931

3032
db := pg(t)
3133
t.Cleanup(func() { db.Close() })
3234
mustResetModel(t, ctx, db, (*Model)(nil))
3335

36+
str1 := "hello"
37+
str2 := "world"
3438
model1 := &Model{
3539
ID: 123,
3640
Array1: []string{"one", "two", "three"},
3741
Array2: &[]string{"hello", "world"},
42+
Array4: []*string{&str1, &str2},
3843
}
3944
_, err := db.NewInsert().Model(model1).Exec(ctx)
4045
require.NoError(t, err)
@@ -56,6 +61,12 @@ func TestPostgresArray(t *testing.T) {
5661
Scan(ctx, pgdialect.Array(&strs))
5762
require.NoError(t, err)
5863
require.Nil(t, strs)
64+
65+
err = db.NewSelect().Model((*Model)(nil)).
66+
Column("array4").
67+
Scan(ctx, pgdialect.Array(&strs))
68+
require.NoError(t, err)
69+
require.Equal(t, []string{"hello", "world"}, strs)
5970
}
6071

6172
func TestPostgresArrayQuote(t *testing.T) {
@@ -877,3 +888,78 @@ func TestPostgresMultiRange(t *testing.T) {
877888
err = db.NewSelect().Model(out).Scan(ctx)
878889
require.NoError(t, err)
879890
}
891+
892+
type UserID struct {
893+
ID string
894+
}
895+
896+
func (u UserID) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
897+
v := []byte(`"` + u.ID + `"`)
898+
return append(b, v...), nil
899+
}
900+
901+
var _ schema.QueryAppender = (*UserID)(nil)
902+
903+
func (r *UserID) Scan(anySrc any) (err error) {
904+
src, ok := anySrc.([]byte)
905+
if !ok {
906+
return fmt.Errorf("pgdialect: Range can't scan %T", anySrc)
907+
}
908+
909+
r.ID = string(src)
910+
return nil
911+
}
912+
913+
var _ sql.Scanner = (*UserID)(nil)
914+
915+
func TestPostgresJSONB(t *testing.T) {
916+
type Item struct {
917+
Name string `json:"name"`
918+
}
919+
type Model struct {
920+
ID int64 `bun:",pk,autoincrement"`
921+
Item Item `bun:",type:jsonb"`
922+
ItemPtr *Item `bun:",type:jsonb"`
923+
Items []Item `bun:",type:jsonb"`
924+
ItemsP []*Item `bun:",type:jsonb"`
925+
TextItemA []UserID `bun:"type:text[]"`
926+
}
927+
928+
db := pg(t)
929+
t.Cleanup(func() { db.Close() })
930+
mustResetModel(t, ctx, db, (*Model)(nil))
931+
932+
item1 := Item{Name: "one"}
933+
item2 := Item{Name: "two"}
934+
uid1 := UserID{ID: "1"}
935+
uid2 := UserID{ID: "2"}
936+
model1 := &Model{
937+
ID: 123,
938+
Item: item1,
939+
ItemPtr: &item2,
940+
Items: []Item{item1, item2},
941+
ItemsP: []*Item{&item1, &item2},
942+
TextItemA: []UserID{uid1, uid2},
943+
}
944+
_, err := db.NewInsert().Model(model1).Exec(ctx)
945+
require.NoError(t, err)
946+
947+
model2 := new(Model)
948+
err = db.NewSelect().Model(model2).Scan(ctx)
949+
require.NoError(t, err)
950+
require.Equal(t, model1, model2)
951+
952+
var items []Item
953+
err = db.NewSelect().Model((*Model)(nil)).
954+
Column("items").
955+
Scan(ctx, pgdialect.Array(&items))
956+
require.NoError(t, err)
957+
require.Equal(t, []Item{item1, item2}, items)
958+
959+
err = db.NewSelect().Model((*Model)(nil)).
960+
Column("itemsp").
961+
Scan(ctx, pgdialect.Array(&items))
962+
require.NoError(t, err)
963+
require.Equal(t, []Item{item1, item2}, items)
964+
965+
}

0 commit comments

Comments
 (0)