From be36c9d3b016309be6ad447d48d2e98d4dfc5313 Mon Sep 17 00:00:00 2001 From: Deng Ming Date: Mon, 16 May 2022 19:25:36 +0800 Subject: [PATCH] =?UTF-8?q?internal/value:=20=E6=8A=BD=E8=B1=A1=20Value=20?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E4=B8=8E=E5=9F=BA=E4=BA=8E=E5=8F=8D=E5=B0=84?= =?UTF-8?q?=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将元数据抽象挪到 internal/model 包 - 将方言抽象挪到 intenal/dialect 包 - 抽象了 Value 接口并且提供了基于反射的实现 --- .github/workflows/go.yml | 2 +- README.md | 4 +- builder.go | 26 +++--- db.go | 14 +-- db_test.go | 10 ++- delete.go | 2 +- go.mod | 2 +- insert.go | 21 ++--- dialect.go => internal/dialect/dialect.go | 14 +-- internal/{ => error}/error.go | 9 +- model.go => internal/model/model.go | 74 ++++++++++------ model_test.go => internal/model/model_test.go | 85 ++++++++++-------- internal/value/reflect/value.go | 44 ++++++++++ internal/value/reflect/value_test.go | 88 +++++++++++++++++++ internal/{common_fun.go => value/value.go} | 29 ++---- select.go | 32 +++---- select_test.go | 12 +-- update.go | 28 +++--- update_test.go | 8 +- 19 files changed, 332 insertions(+), 172 deletions(-) rename dialect.go => internal/dialect/dialect.go (87%) rename internal/{ => error}/error.go (87%) rename model.go => internal/model/model.go (70%) rename model_test.go => internal/model/model_test.go (58%) create mode 100644 internal/value/reflect/value.go create mode 100644 internal/value/reflect/value_test.go rename internal/{common_fun.go => value/value.go} (55%) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index eb427566..a9a550ab 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -28,7 +28,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.17 + go-version: 1.18 - name: Build run: go build -v ./... diff --git a/README.md b/README.md index 37caba9a..379e76fc 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ 简单的 ORM 框架。 -> 注意:这是一个全中文的仓库。这意味着注释、文档和错误信息,都会是中文的。介意的用户可以选择 GORM,这也是一个杰出的 ORM 仓库 +请使用 Go 1.18 以上版本。 + +> 注意:这是一个全中文的仓库。这意味着注释、文档和错误信息,都会是中文。介意的用户可以选择别的 ORM 仓库,但是不必来反馈说希望提供英文版本,我们是不会提供的。我们缺乏足够的英文水平良好的开发者,这也是为何我选择将这个项目做成全中文的原因。 ## SQL 2003 标准 理论上来说,我们计划支持 [SQL 2003 standard](https://ronsavage.github.io/SQL/sql-2003-2.bnf.html#query%20specification). 不过据我们所知,并不是所有的数据库都支持全部的 SQL 2003 标准,所以用户还是需要进一步检查目标数据库的语法。 diff --git a/builder.go b/builder.go index e9d64803..9e994567 100644 --- a/builder.go +++ b/builder.go @@ -16,7 +16,9 @@ package eorm import ( "errors" - "github.com/gotomicro/eorm/internal" + dialect2 "github.com/gotomicro/eorm/internal/dialect" + error2 "github.com/gotomicro/eorm/internal/error" + "github.com/gotomicro/eorm/internal/model" "github.com/valyala/bytebufferpool" ) @@ -32,21 +34,21 @@ type Query struct { } type builder struct { - registry MetaRegistry - dialect Dialect + registry model.MetaRegistry + dialect dialect2.Dialect // Use bytebufferpool to reduce memory allocation. // After using buffer, it must be put back in bytebufferpool. // Call bytebufferpool.Get() to get a buffer, call bytebufferpool.Put() to put buffer back to bytebufferpool. buffer *bytebufferpool.ByteBuffer - meta *TableMeta + meta *model.TableMeta args []interface{} aliases map[string]struct{} } func (b *builder) quote(val string) { - _ = b.buffer.WriteByte(b.dialect.quote) + _ = b.buffer.WriteByte(b.dialect.Quote) _, _ = b.buffer.WriteString(val) - _ = b.buffer.WriteByte(b.dialect.quote) + _ = b.buffer.WriteByte(b.dialect.Quote) } func (b *builder) space() { @@ -81,11 +83,11 @@ func (b *builder) buildExpr(expr Expr) error { b.quote(e.name) return nil } - cm, ok := b.meta.fieldMap[e.name] + cm, ok := b.meta.FieldMap[e.name] if !ok { - return internal.NewInvalidColumnError(e.name) + return error2.NewInvalidColumnError(e.name) } - b.quote(cm.columnName) + b.quote(cm.ColumnName) } case Aggregate: if err := b.buildHavingAggregate(e); err != nil { @@ -123,11 +125,11 @@ func (b *builder) buildPredicates(predicates []Predicate) error { func (b *builder) buildHavingAggregate(aggregate Aggregate) error { _, _ = b.buffer.WriteString(aggregate.fn) _ = b.buffer.WriteByte('(') - cMeta, ok := b.meta.fieldMap[aggregate.arg] + cMeta, ok := b.meta.FieldMap[aggregate.arg] if !ok { - return internal.NewInvalidColumnError(aggregate.arg) + return error2.NewInvalidColumnError(aggregate.arg) } - b.quote(cMeta.columnName) + b.quote(cMeta.ColumnName) _ = b.buffer.WriteByte(')') return nil } diff --git a/db.go b/db.go index b4056537..293efe36 100644 --- a/db.go +++ b/db.go @@ -15,6 +15,8 @@ package eorm import ( + "github.com/gotomicro/eorm/internal/dialect" + "github.com/gotomicro/eorm/internal/model" "github.com/valyala/bytebufferpool" ) @@ -23,16 +25,16 @@ type DBOption func(db *DB) // DB represents a database type DB struct { - metaRegistry MetaRegistry - dialect Dialect + metaRegistry model.MetaRegistry + dialect dialect.Dialect } // NewDB returns DB. // By default, it will create an instance of MetaRegistry and use MySQL as the dialect func NewDB(opts ...DBOption) *DB { db := &DB{ - metaRegistry: &tagMetaRegistry{}, - dialect: MySQL, + metaRegistry: model.NewMetaRegistry(), + dialect: dialect.MySQL, } for _, o := range opts { o(db) @@ -41,14 +43,14 @@ func NewDB(opts ...DBOption) *DB { } // DBWithMetaRegistry specify the MetaRegistry -func DBWithMetaRegistry(registry MetaRegistry) DBOption { +func DBWithMetaRegistry(registry model.MetaRegistry) DBOption { return func(db *DB) { db.metaRegistry = registry } } // DBWithDialect specify dialect -func DBWithDialect(dialect Dialect) DBOption { +func DBWithDialect(dialect dialect.Dialect) DBOption { return func(db *DB) { db.dialect = dialect } diff --git a/db_test.go b/db_test.go index ffe2050d..d068517b 100644 --- a/db_test.go +++ b/db_test.go @@ -16,19 +16,21 @@ package eorm import ( "fmt" + "github.com/gotomicro/eorm/internal/dialect" + "github.com/gotomicro/eorm/internal/model" ) func ExampleNew() { // case1 without DBOption db := NewDB() - fmt.Printf("case1 dialect: %s\n", db.dialect.name) + fmt.Printf("case1 dialect: %s\n", db.dialect.Name) // case2 use DBOption - db = NewDB(DBWithDialect(SQLite)) - fmt.Printf("case2 dialect: %s\n", db.dialect.name) + db = NewDB(DBWithDialect(dialect.SQLite)) + fmt.Printf("case2 dialect: %s\n", db.dialect.Name) // case3 share registry among DB - registry := NewTagMetaRegistry() + registry := model.NewTagMetaRegistry() db1 := NewDB(DBWithMetaRegistry(registry)) db2 := NewDB(DBWithMetaRegistry(registry)) fmt.Printf("case3 same registry: %v", db1.metaRegistry == db2.metaRegistry) diff --git a/delete.go b/delete.go index 87fd113e..3e3bafab 100644 --- a/delete.go +++ b/delete.go @@ -32,7 +32,7 @@ func (d *Deleter) Build() (*Query, error) { return nil, err } - d.quote(d.meta.tableName) + d.quote(d.meta.TableName) if len(d.where) > 0 { _, err = d.buffer.WriteString(" WHERE ") if err != nil { diff --git a/go.mod b/go.mod index a4131239..fb61635f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/gotomicro/eorm -go 1.17 +go 1.18 require ( github.com/stretchr/testify v1.7.0 diff --git a/insert.go b/insert.go index ea2abff8..2e37b405 100644 --- a/insert.go +++ b/insert.go @@ -17,6 +17,7 @@ package eorm import ( "errors" "fmt" + "github.com/gotomicro/eorm/internal/model" "reflect" ) @@ -42,7 +43,7 @@ func (i *Inserter) Build() (*Query, error) { if err != nil { return &Query{}, err } - i.quote(i.meta.tableName) + i.quote(i.meta.TableName) i.buffer.WriteString("(") fields, err := i.buildColumns() if err != nil { @@ -54,9 +55,9 @@ func (i *Inserter) Build() (*Query, error) { i.buffer.WriteString("(") refVal := reflect.ValueOf(value).Elem() for j, v := range fields { - field := refVal.FieldByName(v.fieldName) + field := refVal.FieldByName(v.FieldName) if !field.IsValid() { - return &Query{}, fmt.Errorf("invalid column %s", v.fieldName) + return &Query{}, fmt.Errorf("invalid column %s", v.FieldName) } val := field.Interface() i.parameter(val) @@ -101,24 +102,24 @@ func (i *Inserter) OnConflict(cs ...string) *PgSQLUpserter { panic("implement me") } -func (i *Inserter) buildColumns() ([]*ColumnMeta, error) { - cs := i.meta.columns +func (i *Inserter) buildColumns() ([]*model.ColumnMeta, error) { + cs := i.meta.Columns if len(i.columns) != 0 { - cs = make([]*ColumnMeta, 0, len(i.columns)) + cs = make([]*model.ColumnMeta, 0, len(i.columns)) for index, value := range i.columns { - v, isOk := i.meta.fieldMap[value] + v, isOk := i.meta.FieldMap[value] if !isOk { return cs, fmt.Errorf("invalid column %s", value) } - i.quote(v.columnName) + i.quote(v.ColumnName) if index != len(i.columns)-1 { i.comma() } cs = append(cs, v) } } else { - for index, value := range i.meta.columns { - i.quote(value.columnName) + for index, value := range i.meta.Columns { + i.quote(value.ColumnName) if index != len(cs)-1 { i.comma() } diff --git a/dialect.go b/internal/dialect/dialect.go similarity index 87% rename from dialect.go rename to internal/dialect/dialect.go index 9678f3e5..6959abb1 100644 --- a/dialect.go +++ b/internal/dialect/dialect.go @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package eorm +package dialect // Dialect specify config or behavior of special SQL dialects type Dialect struct { - name string + Name string // in MYSQL, it's "`" - quote byte + Quote byte } var ( MySQL = Dialect{ - name: "MySQL", - quote: '`', + Name: "MySQL", + Quote: '`', } SQLite = Dialect{ - name: "SQLite", - quote: '`', + Name: "SQLite", + Quote: '`', } ) diff --git a/internal/error.go b/internal/error/error.go similarity index 87% rename from internal/error.go rename to internal/error/error.go index 7cea7695..e9fdd1db 100644 --- a/internal/error.go +++ b/internal/error/error.go @@ -12,19 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package error import ( "errors" "fmt" ) -var errValueNotSet = errors.New("value unset") +var ( + errValueNotSet = errors.New("value unset") +) + // NewInvalidColumnError returns an error represents invalid field name // TODO(do we need errors pkg?) func NewInvalidColumnError(field string) error { - return fmt.Errorf("eql: invalid column name %s, " + + return fmt.Errorf("eorm: invalid column name %s, " + "it must be a valid field name of structure", field) } diff --git a/model.go b/internal/model/model.go similarity index 70% rename from model.go rename to internal/model/model.go index da6b792a..ac90c433 100644 --- a/model.go +++ b/internal/model/model.go @@ -12,31 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -package eorm +package model import ( "reflect" "strings" "sync" - - "github.com/gotomicro/eorm/internal" + "unicode" ) // TableMeta represents data model, or a table type TableMeta struct { - tableName string - columns []*ColumnMeta - fieldMap map[string]*ColumnMeta - typ reflect.Type + TableName string + Columns []*ColumnMeta + FieldMap map[string]*ColumnMeta + Typ reflect.Type } // ColumnMeta represents model's field, or column type ColumnMeta struct { - columnName string - fieldName string - typ reflect.Type - isPrimaryKey bool - isAutoIncrement bool + ColumnName string + FieldName string + Typ reflect.Type + IsPrimaryKey bool + IsAutoIncrement bool } // TableMetaOption represents options of TableMeta, this options will cover default cover. @@ -48,6 +47,10 @@ type MetaRegistry interface { Register(table interface{}, opts ...TableMetaOption) (*TableMeta, error) } +func NewMetaRegistry() MetaRegistry { + return &tagMetaRegistry{} +} + // tagMetaRegistry is the default implementation based on tag eql type tagMetaRegistry struct { metas sync.Map @@ -94,20 +97,20 @@ func (t *tagMetaRegistry) Register(table interface{}, opts ...TableMetaOption) ( continue } columnMeta := &ColumnMeta{ - columnName: internal.UnderscoreName(structField.Name), - fieldName: structField.Name, - typ: structField.Type, - isAutoIncrement: isAuto, - isPrimaryKey: isKey, + ColumnName: underscoreName(structField.Name), + FieldName: structField.Name, + Typ: structField.Type, + IsAutoIncrement: isAuto, + IsPrimaryKey: isKey, } columnMetas = append(columnMetas, columnMeta) - fieldMap[columnMeta.fieldName] = columnMeta + fieldMap[columnMeta.FieldName] = columnMeta } tableMeta := &TableMeta{ - columns: columnMetas, - tableName: internal.UnderscoreName(v.Name()), - typ: rtype, - fieldMap: fieldMap, + Columns: columnMetas, + TableName: underscoreName(v.Name()), + Typ: rtype, + FieldMap: fieldMap, } for _, o := range opts { o(tableMeta) @@ -121,17 +124,34 @@ func IgnoreFieldsOption(fieldNames ...string) TableMetaOption { return func(meta *TableMeta) { for _, field := range fieldNames { // has field in the TableMeta - if _, ok := meta.fieldMap[field]; ok { + if _, ok := meta.FieldMap[field]; ok { // delete field in columns slice - for index, column := range meta.columns { - if column.fieldName == field { - meta.columns = append(meta.columns[:index], meta.columns[index+1:]...) + for index, column := range meta.Columns { + if column.FieldName == field { + meta.Columns = append(meta.Columns[:index], meta.Columns[index+1:]...) break } } // delete field in fieldMap - delete(meta.fieldMap, field) + delete(meta.FieldMap, field) } } } } + +// underscoreName function mainly converts upper case to lower case and adds an underscore in between +func underscoreName(tableName string) string { + var buf []byte + for i, v := range tableName { + if unicode.IsUpper(v) { + if i != 0 { + buf = append(buf, '_') + } + buf = append(buf, byte(unicode.ToLower(v))) + } else { + buf = append(buf, byte(v)) + } + + } + return string(buf) +} \ No newline at end of file diff --git a/model_test.go b/internal/model/model_test.go similarity index 58% rename from model_test.go rename to internal/model/model_test.go index c0d30924..4143f2c7 100644 --- a/model_test.go +++ b/internal/model/model_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package eorm +package model import ( "fmt" @@ -29,32 +29,32 @@ func TestTagMetaRegistry(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, 4, len(meta.columns)) - assert.Equal(t, 4, len(meta.fieldMap)) - assert.Equal(t, reflect.TypeOf(tm), meta.typ) - assert.Equal(t, "test_model", meta.tableName) - - idMeta := meta.fieldMap["Id"] - assert.Equal(t, "id", idMeta.columnName) - assert.Equal(t, "Id", idMeta.fieldName) - assert.Equal(t, reflect.TypeOf(int64(0)), idMeta.typ) - assert.True(t, idMeta.isAutoIncrement) - assert.True(t, idMeta.isPrimaryKey) - - idMetaFistName := meta.fieldMap["FirstName"] - assert.Equal(t, "first_name", idMetaFistName.columnName) - assert.Equal(t, "FirstName", idMetaFistName.fieldName) - assert.Equal(t, reflect.TypeOf(string("")), idMetaFistName.typ) - - idMetaLastName := meta.fieldMap["LastName"] - assert.Equal(t, "last_name", idMetaLastName.columnName) - assert.Equal(t, "LastName", idMetaLastName.fieldName) - assert.Equal(t, reflect.TypeOf((*string)(nil)), idMetaLastName.typ) - - idMetaLastAge := meta.fieldMap["Age"] - assert.Equal(t, "age", idMetaLastAge.columnName) - assert.Equal(t, "Age", idMetaLastAge.fieldName) - assert.Equal(t, reflect.TypeOf(int8(0)), idMetaLastAge.typ) + assert.Equal(t, 4, len(meta.Columns)) + assert.Equal(t, 4, len(meta.FieldMap)) + assert.Equal(t, reflect.TypeOf(tm), meta.Typ) + assert.Equal(t, "test_model", meta.TableName) + + idMeta := meta.FieldMap["Id"] + assert.Equal(t, "id", idMeta.ColumnName) + assert.Equal(t, "Id", idMeta.FieldName) + assert.Equal(t, reflect.TypeOf(int64(0)), idMeta.Typ) + assert.True(t, idMeta.IsAutoIncrement) + assert.True(t, idMeta.IsPrimaryKey) + + idMetaFistName := meta.FieldMap["FirstName"] + assert.Equal(t, "first_name", idMetaFistName.ColumnName) + assert.Equal(t, "FirstName", idMetaFistName.FieldName) + assert.Equal(t, reflect.TypeOf(string("")), idMetaFistName.Typ) + + idMetaLastName := meta.FieldMap["LastName"] + assert.Equal(t, "last_name", idMetaLastName.ColumnName) + assert.Equal(t, "LastName", idMetaLastName.FieldName) + assert.Equal(t, reflect.TypeOf((*string)(nil)), idMetaLastName.Typ) + + idMetaLastAge := meta.FieldMap["Age"] + assert.Equal(t, "age", idMetaLastAge.ColumnName) + assert.Equal(t, "Age", idMetaLastAge.FieldName) + assert.Equal(t, reflect.TypeOf(int8(0)), idMetaLastAge.Typ) } @@ -65,21 +65,21 @@ func TestIgnoreFieldsOption(t *testing.T) { if err != nil { t.Fatal(err) } - assert.Equal(t, 1, len(meta.columns)) - assert.Equal(t, 1, len(meta.fieldMap)) - assert.Equal(t, reflect.TypeOf(tm), meta.typ) - assert.Equal(t, "test_ignore_model", meta.tableName) + assert.Equal(t, 1, len(meta.Columns)) + assert.Equal(t, 1, len(meta.FieldMap)) + assert.Equal(t, reflect.TypeOf(tm), meta.Typ) + assert.Equal(t, "test_ignore_model", meta.TableName) - _, hasId := meta.fieldMap["Id"] + _, hasId := meta.FieldMap["Id"] assert.False(t, hasId) - _, hasFirstName := meta.fieldMap["FirstName"] + _, hasFirstName := meta.FieldMap["FirstName"] assert.False(t, hasFirstName) - _, hasAge := meta.fieldMap["Age"] + _, hasAge := meta.FieldMap["Age"] assert.False(t, hasAge) - _, hasLastName := meta.fieldMap["LastName"] + _, hasLastName := meta.FieldMap["LastName"] assert.True(t, hasLastName) } @@ -94,7 +94,7 @@ func ExampleMetaRegistry_Get() { tm := &TestModel{} registry := &tagMetaRegistry{} meta, _ := registry.Get(tm) - fmt.Printf("table name: %v\n", meta.tableName) + fmt.Printf("table name: %v\n", meta.TableName) // Output: // table name: test_model @@ -109,7 +109,7 @@ func ExampleMetaRegistry_Register() { case1: table name:%s column names:%s,%s,%s,%s -`, meta.tableName, meta.columns[0].columnName, meta.columns[1].columnName, meta.columns[2].columnName, meta.columns[3].columnName) +`, meta.TableName, meta.Columns[0].ColumnName, meta.Columns[1].ColumnName, meta.Columns[2].ColumnName, meta.Columns[3].ColumnName) // case2 use Tag to ignore field tim := &TestIgnoreModel{} @@ -119,7 +119,7 @@ case1: case2: table name:%s column names:%s,%s -`, meta.tableName, meta.columns[0].columnName, meta.columns[1].columnName) +`, meta.TableName, meta.Columns[0].ColumnName, meta.Columns[1].ColumnName) // case3 use IgnoreFieldOption to ignore field tim = &TestIgnoreModel{} @@ -129,7 +129,7 @@ case2: case3: table name:%s column names:%s -`, meta.tableName, meta.columns[0].columnName) +`, meta.TableName, meta.Columns[0].ColumnName) // Output: // case1: @@ -144,3 +144,10 @@ case3: // table name:test_ignore_model // column names:last_name } + +type TestModel struct { + Id int64 `eql:"auto_increment,primary_key"` + FirstName string + Age int8 + LastName *string +} \ No newline at end of file diff --git a/internal/value/reflect/value.go b/internal/value/reflect/value.go new file mode 100644 index 00000000..96c8a706 --- /dev/null +++ b/internal/value/reflect/value.go @@ -0,0 +1,44 @@ +// Copyright 2022 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reflect + +import ( + "fmt" + "github.com/gotomicro/eorm/internal/value" + "reflect" +) + + +// reflectValue 基于反射的 Value +type reflectValue struct { + val reflect.Value +} + +// NewValue 返回一个封装好的,基于反射实现的 Value +// 输入 val 必须是一个指向结构体实例的指针,而不能是任何其它类型 +func NewValue(val interface{}) value.Value { + return reflectValue{ + val: reflect.ValueOf(val).Elem(), + } +} + +// Field 返回字段值 +func (r reflectValue) Field(name string) (interface{}, error) { + res := r.val.FieldByName(name) + if res == (reflect.Value{}) { + return nil, fmt.Errorf("eorm: 找不到字段 %s", name) + } + return res.Interface(), nil +} diff --git a/internal/value/reflect/value_test.go b/internal/value/reflect/value_test.go new file mode 100644 index 00000000..8fc841ec --- /dev/null +++ b/internal/value/reflect/value_test.go @@ -0,0 +1,88 @@ +// Copyright 2022 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reflect + +import ( + "errors" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestReflectValue_Field(t *testing.T) { + ins := NewValue(&TestModel{ + Id: 13, + }) + testCases := []struct{ + name string + field string + wantVal interface{} + wantError error + } { + { + name: "正常值", + field: "Id", + wantVal: int64(13), + }, + { + name: "不存在字段", + field: "InvalidField", + wantError: errors.New("eorm: 找不到字段 InvalidField"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + val, err := ins.Field(tc.field) + assert.Equal(t, tc.wantError, err) + if tc.wantError != nil { + return + } + assert.Equal(t, tc.wantVal, val) + }) + } +} + +func FuzzReflectValue_Field(f *testing.F) { + f.Fuzz(func(t *testing.T, wantId int64) { + val := NewValue(&TestModel{ + Id: wantId, + }) + id, err := val.Field("Id") + assert.Nil(t, err) + assert.Equal(t, wantId, id) + }) +} + +func BenchmarkReflectValue_Field(b *testing.B) { + ins := NewValue(&TestModel{ + Id: 13, + }) + for i := 0; i < b.N; i++ { + val, err := ins.Field("Id") + assert.Nil(b, err) + assert.Equal(b, int64(13), val) + } +} + +// TODO +// 添加更多的字段,覆盖以下类型 +// uint 家族 +// int 家族 +// float 家族 +// []byte +// string +type TestModel struct { + Id int64 +} \ No newline at end of file diff --git a/internal/common_fun.go b/internal/value/value.go similarity index 55% rename from internal/common_fun.go rename to internal/value/value.go index 892921bb..8f20d044 100644 --- a/internal/common_fun.go +++ b/internal/value/value.go @@ -1,4 +1,4 @@ -// Copyright 2021 gotomicro +// Copyright 2022 gotomicro // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,25 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package internal +package value + +// Value 是对结构体实例的内部抽象 +type Value interface { + // Field 访问结构体字段 + Field(name string) (interface{}, error) +} + -import ( - "unicode" -) -// UnderscoreName function mainly converts upper case to lower case and adds an underscore in between -func UnderscoreName(tableName string) string { - buf := []byte{} - for i, v := range tableName { - if unicode.IsUpper(v) { - if i != 0 { - buf = append(buf, '_') - } - buf = append(buf, byte(unicode.ToLower(v))) - } else { - buf = append(buf, byte(v)) - } - } - return string(buf) -} diff --git a/select.go b/select.go index dd46d158..ac072261 100644 --- a/select.go +++ b/select.go @@ -15,7 +15,7 @@ package eorm import ( - "github.com/gotomicro/eorm/internal" + error2 "github.com/gotomicro/eorm/internal/error" "github.com/valyala/bytebufferpool" ) @@ -57,7 +57,7 @@ func (s *Selector) Build() (*Query, error) { } } _, _ = s.buffer.WriteString(" FROM ") - s.quote(s.meta.tableName) + s.quote(s.meta.TableName) if len(s.where) > 0 { _, _ = s.buffer.WriteString(" WHERE ") err = s.buildPredicates(s.where) @@ -111,11 +111,11 @@ func (s *Selector) buildOrderBy() error { s.comma() } for _, c := range ob.fields { - cMeta, ok := s.meta.fieldMap[c] + cMeta, ok := s.meta.FieldMap[c] if !ok { - return internal.NewInvalidColumnError(c) + return error2.NewInvalidColumnError(c) } - s.quote(cMeta.columnName) + s.quote(cMeta.ColumnName) } s.space() _, _ = s.buffer.WriteString(ob.order) @@ -126,25 +126,25 @@ func (s *Selector) buildOrderBy() error { func (s *Selector) buildGroupBy() error { _, _ = s.buffer.WriteString(" GROUP BY ") for i, gb := range s.groupBy { - cMeta, ok := s.meta.fieldMap[gb] + cMeta, ok := s.meta.FieldMap[gb] if !ok { - return internal.NewInvalidColumnError(gb) + return error2.NewInvalidColumnError(gb) } if i > 0 { s.comma() } - s.quote(cMeta.columnName) + s.quote(cMeta.ColumnName) } return nil } func (s *Selector) buildAllColumns() { - for i, cMeta := range s.meta.columns { + for i, cMeta := range s.meta.Columns { if i > 0 { s.comma() } // it should never return error, we can safely ignore it - _ = s.buildColumn(cMeta.fieldName, "") + _ = s.buildColumn(cMeta.FieldName, "") } } @@ -185,12 +185,12 @@ func (s *Selector) buildSelectedList() error { func (s *Selector) selectAggregate(aggregate Aggregate) error { _, _ = s.buffer.WriteString(aggregate.fn) _ = s.buffer.WriteByte('(') - cMeta, ok := s.meta.fieldMap[aggregate.arg] + cMeta, ok := s.meta.FieldMap[aggregate.arg] s.aliases[aggregate.alias] = struct{}{} if !ok { - return internal.NewInvalidColumnError(aggregate.arg) + return error2.NewInvalidColumnError(aggregate.arg) } - s.quote(cMeta.columnName) + s.quote(cMeta.ColumnName) _ = s.buffer.WriteByte(')') if aggregate.alias != "" { if _, ok := s.aliases[aggregate.alias]; ok { @@ -202,11 +202,11 @@ func (s *Selector) selectAggregate(aggregate Aggregate) error { } func (s *Selector) buildColumn(field, alias string) error { - cMeta, ok := s.meta.fieldMap[field] + cMeta, ok := s.meta.FieldMap[field] if !ok { - return internal.NewInvalidColumnError(field) + return error2.NewInvalidColumnError(field) } - s.quote(cMeta.columnName) + s.quote(cMeta.ColumnName) if alias != "" { s.aliases[alias] = struct{}{} _, _ = s.buffer.WriteString(" AS ") diff --git a/select_test.go b/select_test.go index 53ea32d2..4256990d 100644 --- a/select_test.go +++ b/select_test.go @@ -16,7 +16,7 @@ package eorm import ( "fmt" - "github.com/gotomicro/eorm/internal" + error2 "github.com/gotomicro/eorm/internal/error" "github.com/stretchr/testify/assert" "testing" ) @@ -52,7 +52,7 @@ func TestSelectable(t *testing.T) { { name: "invalid columns", builder: NewSelector(db).Select(Columns("Invalid"), Raw("AVG(DISTINCT `age`)")).From(&TestModel{}), - wantErr: internal.NewInvalidColumnError("Invalid"), + wantErr: error2.NewInvalidColumnError("Invalid"), }, { name: "order by", @@ -62,7 +62,7 @@ func TestSelectable(t *testing.T) { { name: "order by invalid column", builder: NewSelector(db).From(&TestModel{}).OrderBy(ASC("Invalid"), DESC("Id")), - wantErr: internal.NewInvalidColumnError("Invalid"), + wantErr: error2.NewInvalidColumnError("Invalid"), }, { name: "group by", @@ -72,7 +72,7 @@ func TestSelectable(t *testing.T) { { name: "group by invalid column", builder: NewSelector(db).From(&TestModel{}).GroupBy("Invalid", "Id"), - wantErr: internal.NewInvalidColumnError("Invalid"), + wantErr: error2.NewInvalidColumnError("Invalid"), }, { name: "offset", @@ -117,7 +117,7 @@ func TestSelectable(t *testing.T) { { name: "invalid alias in having", builder: NewSelector(db).Select(Columns("Id"), Columns("FirstName"), Avg("Age").As("avg_age")).From(&TestModel{}).GroupBy("FirstName").Having(C("Invalid").LT(20)), - wantErr: internal.NewInvalidColumnError("Invalid"), + wantErr: error2.NewInvalidColumnError("Invalid"), }, } @@ -172,6 +172,6 @@ func ExampleSelector_Having() { // SQL: SELECT `id`,`first_name`,AVG(`age`) AS `avg_age` FROM `test_model` GROUP BY `first_name` HAVING `avg_age`