Skip to content

Commit 77eb2da

Browse files
authored
Escape query (go-rel#23)
* use shared config struct * escape query for insert update and delete * escape field on query builder * wip aggregate function * fix count on postgres * test for query * escape function argument and cache escaped field * combine limit and offset * added test for aggregate * fix tests
1 parent f1bb7a1 commit 77eb2da

16 files changed

+700
-337
lines changed

adapter.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ package grimoire
22

33
// Adapter interface
44
type Adapter interface {
5-
Count(Query, ...Logger) (int, error)
65
All(Query, interface{}, ...Logger) (int, error)
7-
Delete(Query, ...Logger) error
6+
Aggregate(Query, interface{}, ...Logger) error
87
Insert(Query, map[string]interface{}, ...Logger) (interface{}, error)
98
InsertAll(Query, []string, []map[string]interface{}, ...Logger) ([]interface{}, error)
109
Update(Query, map[string]interface{}, ...Logger) error
10+
Delete(Query, ...Logger) error
1111

1212
Begin() (Adapter, error)
1313
Commit() error

adapter/mysql/mysql.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,16 @@ var _ grimoire.Adapter = (*Adapter)(nil)
3434
func Open(dsn string) (*Adapter, error) {
3535
var err error
3636

37-
adapter := &Adapter{sql.New(errorFunc, incrementFunc, sql.Placeholder("?"))}
37+
adapter := &Adapter{
38+
Adapter: &sql.Adapter{
39+
Config: &sql.Config{
40+
Placeholder: "?",
41+
EscapeChar: "`",
42+
IncrementFunc: incrementFunc,
43+
ErrorFunc: errorFunc,
44+
},
45+
},
46+
}
3847
adapter.DB, err = db.Open("mysql", dsn)
3948

4049
return adapter, err

adapter/mysql/mysql_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func dsn() string {
6161
return "root@(127.0.0.1:3306)/grimoire_test?charset=utf8&parseTime=True&loc=Local"
6262
}
6363

64-
func TestAdapter__specs(t *testing.T) {
64+
func TestAdapter_specs(t *testing.T) {
6565
adapter, err := Open(dsn())
6666
paranoid.Panic(err, "failed to open database connection")
6767
defer adapter.Close()
@@ -76,8 +76,8 @@ func TestAdapter__specs(t *testing.T) {
7676
// Preload specs
7777
specs.Preload(t, repo)
7878

79-
// Count Specs
80-
specs.Count(t, repo)
79+
// Aggregate Specs
80+
specs.Aggregate(t, repo)
8181

8282
// Insert Specs
8383
specs.Insert(t, repo)

adapter/postgres/postgres.go

+14-13
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,16 @@ var _ grimoire.Adapter = (*Adapter)(nil)
3333
func Open(dsn string) (*Adapter, error) {
3434
var err error
3535

36-
adapter := &Adapter{sql.New(errorFunc, nil,
37-
sql.Placeholder("$"),
38-
sql.Ordinal(true),
39-
sql.InsertDefaultValues(true)),
36+
adapter := &Adapter{
37+
Adapter: &sql.Adapter{
38+
Config: &sql.Config{
39+
Placeholder: "$",
40+
EscapeChar: "\"",
41+
Ordinal: true,
42+
InsertDefaultValues: true,
43+
ErrorFunc: errorFunc,
44+
},
45+
},
4046
}
4147
adapter.DB, err = db.Open("postgres", dsn)
4248

@@ -45,9 +51,7 @@ func Open(dsn string) (*Adapter, error) {
4551

4652
// Insert inserts a record to database and returns its id.
4753
func (adapter *Adapter) Insert(query grimoire.Query, changes map[string]interface{}, loggers ...grimoire.Logger) (interface{}, error) {
48-
statement, args := sql.NewBuilder(adapter.Placeholder, adapter.Ordinal, adapter.InsertDefaultValues).
49-
Returning("id").
50-
Insert(query.Collection, changes)
54+
statement, args := sql.NewBuilder(adapter.Config).Returning("id").Insert(query.Collection, changes)
5155

5256
var result struct {
5357
ID int64
@@ -59,7 +63,7 @@ func (adapter *Adapter) Insert(query grimoire.Query, changes map[string]interfac
5963

6064
// InsertAll inserts multiple records to database and returns its ids.
6165
func (adapter *Adapter) InsertAll(query grimoire.Query, fields []string, allchanges []map[string]interface{}, loggers ...grimoire.Logger) ([]interface{}, error) {
62-
statement, args := sql.NewBuilder(adapter.Placeholder, adapter.Ordinal, adapter.InsertDefaultValues).Returning("id").InsertAll(query.Collection, fields, allchanges)
66+
statement, args := sql.NewBuilder(adapter.Config).Returning("id").InsertAll(query.Collection, fields, allchanges)
6367

6468
var result []struct {
6569
ID int64
@@ -81,11 +85,8 @@ func (adapter *Adapter) Begin() (grimoire.Adapter, error) {
8185

8286
return &Adapter{
8387
&sql.Adapter{
84-
Placeholder: adapter.Placeholder,
85-
Ordinal: adapter.Ordinal,
86-
IncrementFunc: adapter.IncrementFunc,
87-
ErrorFunc: adapter.ErrorFunc,
88-
Tx: Tx,
88+
Config: adapter.Config,
89+
Tx: Tx,
8990
},
9091
}, err
9192
}

adapter/postgres/postgres_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func dsn() string {
6161
return "postgres://postgres@localhost/grimoire_test?sslmode=disable"
6262
}
6363

64-
func TestAdapter__specs(t *testing.T) {
64+
func TestAdapter_specs(t *testing.T) {
6565
adapter, err := Open(dsn())
6666
paranoid.Panic(err, "failed to open database connection")
6767
defer adapter.Close()
@@ -76,8 +76,8 @@ func TestAdapter__specs(t *testing.T) {
7676
// Preload specs
7777
specs.Preload(t, repo)
7878

79-
// Count Specs
80-
specs.Count(t, repo)
79+
// Aggregate Specs
80+
specs.Aggregate(t, repo)
8181

8282
// Insert Specs
8383
specs.Insert(t, repo)

adapter/specs/count.go renamed to adapter/specs/aggregate.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import (
88
"github.com/stretchr/testify/assert"
99
)
1010

11-
// Count tests count specifications.
12-
func Count(t *testing.T, repo grimoire.Repo) {
11+
// Aggregate tests count specifications.
12+
func Aggregate(t *testing.T, repo grimoire.Repo) {
1313
// preparte tests data
1414
user := User{Name: "name1", Gender: "male", Age: 10}
1515
repo.From(users).MustSave(&user)
@@ -36,12 +36,24 @@ func Count(t *testing.T, repo grimoire.Repo) {
3636
repo.From(users).Where(c.NotLike(name, "noname%")),
3737
repo.From(users).Where(c.Fragment("id > 0")),
3838
repo.From(users).Where(c.Not(c.Eq(id, 1), c.Eq(name, "name1"), c.Eq(age, 10))),
39+
repo.From(users).Group("gender"),
40+
repo.From(users).Group("age").Having(c.Gt(age, 10)),
3941
}
4042

4143
for _, query := range tests {
42-
statement, _ := builder.Find(query.Select("COUNT(*) AS count"))
43-
t.Run("Count|"+statement, func(t *testing.T) {
44-
_, err := query.Count()
44+
field := "*"
45+
if len(query.GroupFields) != 0 {
46+
field = query.GroupFields[0]
47+
}
48+
49+
statement, _ := builder.Find(query.Select(field, "count("+field+") AS sum"))
50+
t.Run("Aggregate|"+statement, func(t *testing.T) {
51+
var out []struct {
52+
Count int
53+
}
54+
55+
err := query.Aggregate("count", field, &out)
56+
assert.True(t, len(out) > 0)
4557
assert.Nil(t, err)
4658
})
4759
}

adapter/specs/specs.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ const (
5555
address = c.I("address")
5656
)
5757

58-
var builder = sql.NewBuilder("?", false, false)
58+
var builder = sql.NewBuilder(&sql.Config{
59+
Placeholder: "?",
60+
EscapeChar: "`",
61+
})
5962

6063
func assertConstraint(t *testing.T, err error, kind errors.Kind, field string) {
6164
assert.NotNil(t, err)

0 commit comments

Comments
 (0)