From db41227ee61f318fee214f794606a716cbc1b0a6 Mon Sep 17 00:00:00 2001 From: Deng Ming Date: Sun, 3 Oct 2021 22:53:18 +0800 Subject: [PATCH] support updater --- .CHANGELOG.md | 3 +- assignment.go | 29 ++++++---- builder.go | 63 ++++++++++++++++++++- column.go | 25 +++++++-- db.go | 94 +++++++++++++++++++++++++++---- db_test.go | 51 +++++++++++++++++ expression.go | 36 ++++++++---- predicate.go | 15 ++++- predicate_test.go | 4 +- update.go | 137 +++++++++++++++++++++++++++++++++++++++++++--- update_test.go | 52 +++++++++++++----- 11 files changed, 444 insertions(+), 65 deletions(-) create mode 100644 db_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 1d9a98f..1e7ec04 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -6,4 +6,5 @@ - [Metadata API](https://github.com/gotomicro/eql/pull/16) - [tagMetaRegistry: default implementation of MetaRegistry](https://github.com/gotomicro/eql/pull/25) - [Refactor: move Insert function into db.file](https://github.com/gotomicro/eql/pull/28) -- [Selector implementation, except WHERE and HAVING](https://github.com/gotomicro/eql/pull/32) +- [Selector implementation, excluding WHERE and HAVING clauses](https://github.com/gotomicro/eql/pull/32) +- [Updater implementation, excluding WHERE clause](https://github.com/gotomicro/eql/pull/36) diff --git a/assignment.go b/assignment.go index bd71ac6..8ec8289 100644 --- a/assignment.go +++ b/assignment.go @@ -18,22 +18,27 @@ type Assignable interface { assign() } -type Assignment struct { - Column string - Value ValueExpr -} +type Assignment binaryExpr -// 实现注意: -// 1. value 是 ValueExpr -// 2. value 不是 ValueExpr,这时候看做是一个普通的值 -func Assign(column string, value interface{}) *Assignment { - panic("implement me") +func Assign(column string, value interface{}) Assignment { + var expr Expr + switch v := value.(type) { + case Expr: + expr = v + default: + expr = valueExpr{val: v} + } + return Assignment{left: C(column), op: opEQ, right: expr} } -func (a *Assignment) assign() { +func (a Assignment) assign() { panic("implement me") } -type ValueExpr interface { - value() +type valueExpr struct { + val interface{} } + +func (valueExpr) expr() (string, error){ + return "", nil +} \ No newline at end of file diff --git a/builder.go b/builder.go index 3ff3d87..e050217 100644 --- a/builder.go +++ b/builder.go @@ -14,7 +14,11 @@ package eql -import "strings" +import ( + "errors" + "github.com/gotomicro/eql/internal" + "strings" +) // QueryBuilder is used to build a query type QueryBuilder interface { @@ -61,4 +65,61 @@ func (b *builder) parameter(arg interface{}) { } b.buffer.WriteByte('?') b.args = append(b.args, arg) +} + +func (b *builder) buildExpr(expr Expr) error { + switch e := expr.(type) { + case RawExpr: + b.buffer.WriteString(string(e)) + case Column: + cm, ok := b.meta.fieldMap[e.name] + if !ok { + return internal.NewInvalidColumnError(e.name) + } + b.quote(cm.columnName) + case valueExpr: + b.parameter(e.val) + case MathExpr: + if err := b.buildBinaryExpr(binaryExpr(e)); err != nil { + return err + } + case binaryExpr: + if err := b.buildBinaryExpr(e); err != nil { + return err + } + default: + return errors.New("unsupported expr") + } + return nil +} + +func (b *builder) buildBinaryExpr(e binaryExpr) error { + err := b.buildBinarySubExpr(e.left) + if err != nil { + return err + } + b.buffer.WriteString(string(e.op)) + return b.buildBinarySubExpr(e.right) +} + +func (b *builder) buildBinarySubExpr(subExpr Expr) error { + switch r := subExpr.(type) { + case MathExpr: + b.buffer.WriteByte('(') + if err := b.buildBinaryExpr(binaryExpr(r)); err != nil { + return err + } + b.buffer.WriteByte(')') + case binaryExpr: + b.buffer.WriteByte('(') + if err := b.buildBinaryExpr(r); err != nil { + return err + } + b.buffer.WriteByte(')') + default: + if err := b.buildExpr(r); err != nil { + return err + } + } + return nil } \ No newline at end of file diff --git a/column.go b/column.go index 39e8a35..4504faf 100644 --- a/column.go +++ b/column.go @@ -34,12 +34,20 @@ func (c Column) As(alias string) Selectable { } } -func (Column) Inc(val interface{}) MathExpr { - panic("implement me") +func (c Column) Add(val interface{}) MathExpr { + return MathExpr{ + left: c, + op: opAdd, + right: valueOf(val), + } } -func (Column) Times(val interface{}) MathExpr { - panic("implement me") +func (c Column) Multi(val interface{}) MathExpr { + return MathExpr{ + left: c, + op: opMulti, + right: valueOf(val), + } } func (Column) assign() { @@ -73,3 +81,12 @@ func Columns(cs...string) columns { } } +func valueOf(val interface{}) Expr { + switch v := val.(type) { + case Expr: + return v + default: + return valueExpr{val: val} + } +} + diff --git a/db.go b/db.go index 67ec5c3..082d76a 100644 --- a/db.go +++ b/db.go @@ -14,7 +14,10 @@ package eql -import "strings" +import ( + "reflect" + "strings" +) // DBOption configure DB type DBOption func(db *DB) @@ -22,14 +25,16 @@ type DBOption func(db *DB) // DB represents a database type DB struct { metaRegistry MetaRegistry - dialect Dialect + dialect Dialect + nullAssertFunc NullAssertFunc } // New returns DB. It's the entry of EQL func New(opts ...DBOption) *DB { db := &DB{ - metaRegistry: defaultMetaRegistry, - dialect: mysql, + metaRegistry: defaultMetaRegistry, + dialect: mysql, + nullAssertFunc: NilAsNullFunc, } for _, o := range opts { o(db) @@ -40,11 +45,7 @@ func New(opts ...DBOption) *DB { // Select starts a select query. If columns are empty, all columns will be fetched func (db *DB) Select(columns ...Selectable) *Selector { return &Selector{ - builder: builder{ - registry: db.metaRegistry, - dialect: db.dialect, - buffer: &strings.Builder{}, - }, + builder: db.builder(), columns: columns, } } @@ -54,11 +55,82 @@ func (*DB) Delete() *Deleter { panic("implement me") } -func (*DB) Update(table interface{}) *Updater { - panic("implement me") +func (db *DB) Update(table interface{}) *Updater { + return &Updater{ + builder: db.builder(), + table: table, + nullAssertFunc: db.nullAssertFunc, + } } // Insert generate Inserter to builder insert query func (db *DB) Insert() *Inserter { return &Inserter{} } + +func (db *DB) builder() builder { + return builder{ + registry: db.metaRegistry, + dialect: db.dialect, + buffer: &strings.Builder{}, + } +} + +func WithNullAssertFunc(nullable NullAssertFunc) DBOption { + return func(db *DB) { + db.nullAssertFunc = nullable + } +} + +// NullAssertFunc determined if the value is NULL. +// As we know, there is a gap between NULL and nil +// There are two kinds of nullAssertFunc +// 1. nil = NULL, see NilAsNullFunc +// 2. zero value = NULL, see ZeroAsNullFunc +type NullAssertFunc func(val interface{}) bool + +// NilAsNullFunc use the strict definition of "nullAssertFunc" +// if and only if the val is nil, indicates value is null +func NilAsNullFunc(val interface{}) bool { + return val == nil +} + +// ZeroAsNullFunc means "zero value = null" +func ZeroAsNullFunc(val interface{}) bool { + if val == nil{ + return true + } + switch v := val.(type) { + case int: + return v == 0 + case int8: + return v == 0 + case int16: + return v == 0 + case int32: + return v == 0 + case int64: + return v == 0 + case uint: + return v == 0 + case uint8: + return v == 0 + case uint16: + return v == 0 + case uint32: + return v == 0 + case uint64: + return v == 0 + case float32: + return v == 0 + case float64: + return v == 0 + case bool: + return v + case string: + return v == "" + default: + valRef := reflect.ValueOf(val) + return valRef.IsZero() + } +} diff --git a/db_test.go b/db_test.go new file mode 100644 index 0000000..604319b --- /dev/null +++ b/db_test.go @@ -0,0 +1,51 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package eql + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestStrictNullableFunc(t *testing.T) { + str := "Hello" + assert.False(t, NilAsNullFunc(str)) + str = "" + assert.False(t, NilAsNullFunc(str)) + var err error + assert.True(t, NilAsNullFunc(err)) + + var i int + assert.False(t, NilAsNullFunc(i)) +} + +func TestZeroAsNullableFunc(t *testing.T) { + assert.True(t, ZeroAsNullFunc(0)) + assert.True(t, ZeroAsNullFunc(int8(0))) + assert.True(t, ZeroAsNullFunc(int16(0))) + assert.True(t, ZeroAsNullFunc(int32(0))) + assert.True(t, ZeroAsNullFunc(int64(0))) + assert.True(t, ZeroAsNullFunc(uint(0))) + assert.True(t, ZeroAsNullFunc(uint8(0))) + assert.True(t, ZeroAsNullFunc(uint16(0))) + assert.True(t, ZeroAsNullFunc(uint32(0))) + assert.True(t, ZeroAsNullFunc(uint64(0))) + assert.True(t, ZeroAsNullFunc(float32(0))) + assert.True(t, ZeroAsNullFunc(float64(0))) + assert.True(t, ZeroAsNullFunc("")) + var err error + assert.True(t, ZeroAsNullFunc(err)) +} + diff --git a/expression.go b/expression.go index c6d7111..4d5fac7 100644 --- a/expression.go +++ b/expression.go @@ -19,15 +19,6 @@ type Expr interface { expr() (string, error) } -type funcCall struct { - fn string - args []Expr -} - -func (*funcCall) expr() (string, error) { - panic("implement me") -} - // RawExpr uses string as Expr type RawExpr string @@ -36,11 +27,12 @@ func Raw(expr string) RawExpr { return RawExpr(expr) } + func (r RawExpr) expr() (string, error) { return string(r), nil } -func (r RawExpr) selected() {} +func (RawExpr) selected() {} type binaryExpr struct { left Expr @@ -48,6 +40,28 @@ type binaryExpr struct { right Expr } +func (binaryExpr) expr() (string, error) { + return "", nil +} + type MathExpr binaryExpr -func (MathExpr) assign() {} \ No newline at end of file +func (m MathExpr) Add(val interface{}) Expr { + return MathExpr{ + left: m, + op: opAdd, + right: valueOf(val), + } +} + +func (m MathExpr) Multi(val interface{}) MathExpr { + return MathExpr{ + left: m, + op: opMulti, + right: valueOf(val), + } +} + +func (MathExpr) expr() (string, error) { + return "", nil +} \ No newline at end of file diff --git a/predicate.go b/predicate.go index 518465e..da71ead 100644 --- a/predicate.go +++ b/predicate.go @@ -14,9 +14,20 @@ package eql -type op struct { +type op string -} +const ( + opLT = op("<") + opLTEQ = op("<=") + opGT = op(">") + opGTEQ= op(">=") + opEQ = op("=") + opNEQ = op("!=") + opAdd = op("+") + opMinus = op("-") + opMulti = op("*") + opDiv = op("/") +) // Predicate will be used in Where Or Having type Predicate binaryExpr diff --git a/predicate_test.go b/predicate_test.go index b11988e..72587b9 100644 --- a/predicate_test.go +++ b/predicate_test.go @@ -78,14 +78,14 @@ func TestPredicate_P(t *testing.T) { { name: "cross columns mathematical", builder: New().Select(Columns("Id")).From(&TestModel{Id: 10}). - Where(P("Age").GT(C("Id").Inc(40))), + Where(P("Age").GT(C("Id").Add(40))), wantSql: "SELECT `id` FROM test_model WHERE `age`>`id`+?", wantArgs: []interface{}{40}, }, { name: "cross columns mathematical", builder: New().Select(Columns("Id")).From(&TestModel{Id: 10}). - Where(P("Age").GT(C("Id").Times(C("Age").Inc(66)))), + Where(P("Age").GT(C("Id").Multi(C("Age").Add(66)))), wantSql: "SELECT `id` FROM test_model WHERE `age`>`id`*(`age`+?)", wantArgs: []interface{}{66}, }, diff --git a/update.go b/update.go index ab21880..a9bb479 100644 --- a/update.go +++ b/update.go @@ -14,22 +14,143 @@ package eql -type Updater struct { +import ( + "errors" + "fmt" + "github.com/gotomicro/eql/internal" + "reflect" +) +// Updater is the builder responsible for building UPDATE query +type Updater struct { + builder + table interface{} + tableEle reflect.Value + where []Predicate + assigns []Assignable + nullAssertFunc NullAssertFunc } +// Build returns UPDATE query func (u *Updater) Build() (*Query, error) { - panic("implement me") + var err error + u.meta, err = u.registry.Get(u.table) + if err != nil { + return nil, err + } + + u.tableEle = reflect.ValueOf(u.table).Elem() + u.args = make([]interface{}, 0, len(u.meta.columns)) + + u.buffer.WriteString("UPDATE ") + u.quote(u.meta.tableName) + u.buffer.WriteString(" SET ") + if len(u.assigns) == 0 { + err = u.buildDefaultColumns() + } else { + err = u.buildAssigns() + } + if err != nil { + return nil, err + } + + // TODO WHERE + + u.end() + return &Query{ + SQL: u.buffer.String(), + Args: u.args, + }, nil +} + +func (u *Updater) buildAssigns() error { + has := false + for _, assign := range u.assigns { + if has { + u.comma() + } + switch a := assign.(type) { + case Column: + set, err := u.buildColumn(a.name) + if err != nil { + return err + } + has = has || set + case columns: + for _, c := range a.cs { + if has { + u.comma() + } + set, err := u.buildColumn(c) + if err != nil { + return err + } + has = has || set + } + case Assignment: + if err := u.buildExpr(binaryExpr(a)); err != nil { + return err + } + has = true + default: + return fmt.Errorf("eql: unsupported assignment %v", a) + } + } + if !has { + return errors.New("eql: value unset") + } + return nil +} + +func (u *Updater) buildColumn(field string) (bool, error) { + c, ok := u.meta.fieldMap[field] + if !ok { + return false, internal.NewInvalidColumnError(field) + } + return u.setColumn(c), nil +} + +func (u *Updater) setColumn(c *ColumnMeta) bool { + val := u.tableEle.FieldByName(c.fieldName).Interface() + isNull := u.nullAssertFunc(val) + if !isNull { + u.quote(c.columnName) + u.buffer.WriteByte('=') + u.parameter(val) + return true + } + return false +} + +func (u *Updater) buildDefaultColumns() error { + has := false + for _, c := range u.meta.columns { + if has { + u.buffer.WriteByte(',') + } + val := u.tableEle.FieldByName(c.fieldName).Interface() + isNull := u.nullAssertFunc(val) + if !isNull { + u.quote(c.columnName) + u.buffer.WriteByte('=') + u.parameter(val) + has = true + } + } + if !has { + return errors.New("value unset") + } + return nil } -// Set: -// 1. 更新字段,值从 entity 里面读,也就是从 db.Update(table) 的 table 里面读 -// 2. 有特定的指 Set("id", "123") -// 更新多个字段,都是从entity里面读数据,那么我需要 Set(Assign("id", fromEntity), Assign("id", fromEntity)) +// Set represents SET clause func (u *Updater) Set(assigns...Assignable) *Updater { - panic("implement me") + u.assigns = assigns + return u } +// Where represents WHERE clause func (u *Updater) Where(predicates...Predicate) *Updater { - panic("implement me") + u.where = predicates + return u } diff --git a/update_test.go b/update_test.go index 3f7ec15..f6bda8b 100644 --- a/update_test.go +++ b/update_test.go @@ -15,6 +15,7 @@ package eql import ( + "github.com/gotomicro/eql/internal" "github.com/stretchr/testify/assert" "testing" ) @@ -30,46 +31,68 @@ func TestUpdater_Set(t *testing.T) { { name: "no set", builder: New().Update(tm), - wantSql: "UPDATE `test_model` SET `id`=?, `first_name`=?, `age`=?, `last_name`=?;", + wantSql: "UPDATE `test_model` SET `id`=?,`first_name`=?,`age`=?,`last_name`=?;", wantArgs: []interface{}{int64(12), "Tom", int8(18), "Jerry"}, }, { name: "set columns", builder: New().Update(tm).Set(Columns("FirstName", "Age")), - wantSql: "UPDATE `test_model` SET first_name`=?, `age`=?;", + wantSql: "UPDATE `test_model` SET `first_name`=?,`age`=?;", wantArgs: []interface{}{"Tom", int8(18)}, }, + { + name: "set invalid columns", + builder: New().Update(tm).Set(Columns("FirstNameInvalid", "Age")), + wantErr: internal.NewInvalidColumnError("FirstNameInvalid"), + }, { name: "set c2", builder: New().Update(tm).Set(C("FirstName"), C("Age")), - wantSql: "UPDATE `test_model` SET first_name`=?, `age`=?;", + wantSql: "UPDATE `test_model` SET `first_name`=?,`age`=?;", wantArgs: []interface{}{"Tom", int8(18)}, }, { - name: "set c2", + name: "set invalid c2", + builder: New().Update(tm).Set(C("FirstNameInvalid"), C("Age")), + wantErr: internal.NewInvalidColumnError("FirstNameInvalid"), + }, + + { + name: "set assignment", builder: New().Update(tm).Set(C("FirstName"), Assign("Age", 30)), - wantSql: "UPDATE `test_model` SET first_name`=?, `age`=?;", + wantSql: "UPDATE `test_model` SET `first_name`=?,`age`=?;", wantArgs: []interface{}{"Tom", 30}, }, + { + name: "set invalid assignment", + builder: New().Update(tm).Set(C("FirstName"), Assign("InvalidAge", 30)), + wantErr: internal.NewInvalidColumnError("InvalidAge"), + }, { name: "set age+1", - builder: New().Update(tm).Set(C("FirstName"), Assign("Age", C("Age").Inc(1))), - wantSql: "UPDATE `test_model` SET first_name`=?, `age`=`age`+?;", + builder: New().Update(tm).Set(C("FirstName"), Assign("Age", C("Age").Add(1))), + wantSql: "UPDATE `test_model` SET `first_name`=?,`age`=(`age`+?);", wantArgs: []interface{}{"Tom", 1}, }, { name: "set age=id+1", - builder: New().Update(tm).Set(C("FirstName"), Assign("Age", C("Id").Inc(10))), - wantSql: "UPDATE `test_model` SET first_name`=?, `age`=`id`+?;", + builder: New().Update(tm).Set(C("FirstName"), Assign("Age", C("Id").Add(10))), + wantSql: "UPDATE `test_model` SET `first_name`=?,`age`=(`id`+?);", wantArgs: []interface{}{"Tom", 10}, }, { name: "set age=id+(age*100)", - builder: New().Update(tm).Set(C("FirstName"), Assign("Age", C("Id").Inc(C("Age").Times(100)))), - wantSql: "UPDATE `test_model` SET first_name`=?, `age`=`id`+(`age`*?);", + builder: New().Update(tm).Set(C("FirstName"), Assign("Age", C("Id").Add(C("Age").Multi(100)))), + wantSql: "UPDATE `test_model` SET `first_name`=?,`age`=(`id`+(`age`*?));", wantArgs: []interface{}{"Tom", 100}, }, + { + name: "set age=(id+(age*100))*110", + builder: New().Update(tm).Set(C("FirstName"), Assign("Age", C("Id").Add(C("Age").Multi(100)).Multi(110))), + wantSql: "UPDATE `test_model` SET `first_name`=?,`age`=((`id`+(`age`*?))*?);", + wantArgs: []interface{}{"Tom", 100, 110}, + }, } for _, tc := range testCases { @@ -77,8 +100,11 @@ func TestUpdater_Set(t *testing.T) { t.Run(c.name, func(t *testing.T) { query, err := tc.builder.Build() assert.Equal(t, err, c.wantErr) - assert.Equal(t, query.SQL, c.wantSql) - assert.Equal(t, query.Args, c.wantArgs) + if err != nil { + return + } + assert.Equal(t, c.wantSql, query.SQL) + assert.Equal(t, c.wantArgs, query.Args) }) } }