diff --git a/.CHANGELOG.md b/.CHANGELOG.md index badd7ec..a6aad4d 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -23,6 +23,7 @@ - [eorm: 分库分表: 范围查询支持](https://github.com/ecodeclub/eorm/pull/178) - [eorm: 分库分表: 结果集处理--聚合函数(不含GroupBy子句)](https://github.com/ecodeclub/eorm/pull/187) - [eorm: 修复单条查询时连接泄露问题](https://github.com/ecodeclub/eorm/pull/188) +- [eorm: 分库分表: 结果集处理--聚合函数(含GroupBy子句)](https://github.com/ecodeclub/eorm/pull/193) - [eorm: 分库分表: NOT 支持](https://github.com/ecodeclub/eorm/pull/191) ## v0.0.1: diff --git a/go.mod b/go.mod index d229c4a..7a794c1 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/go-sql-driver/mysql v1.6.0 github.com/gotomicro/ekit v0.0.6 github.com/mattn/go-sqlite3 v1.14.15 - github.com/stretchr/testify v1.7.1 + github.com/stretchr/testify v1.8.1 github.com/valyala/bytebufferpool v1.0.0 go.uber.org/multierr v1.9.0 golang.org/x/sync v0.1.0 @@ -15,7 +15,10 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pmezard/go-difflib v1.0.0 // indirect go.uber.org/atomic v1.7.0 // indirect + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e9eb1d2..edf204d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -7,14 +8,24 @@ github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfC github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gotomicro/ekit v0.0.6 h1:Tw3vcx8hltUzFmK7zkp6/5OGlE+ceuq6wha7KxBfpaA= github.com/gotomicro/ekit v0.0.6/go.mod h1:LpstTheKiI/j532rejAlTwPRemwFQXhyqdH6lpzr4wk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 
-github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= @@ -23,8 +34,9 @@ go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/merger/aggregatemerger/aggregator/avg.go b/internal/merger/aggregatemerger/aggregator/avg.go index c1de207..d6aabf0 100644 --- a/internal/merger/aggregatemerger/aggregator/avg.go +++ b/internal/merger/aggregatemerger/aggregator/avg.go @@ -17,19 +17,21 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) // AVG 用于求平均值,通过sum/count求得。 // AVG 我们并不能预期在不同的数据库上,精度会不会损失,以及损失的话会有多少的损失。这很大程度上跟数据库类型,数据库驱动实现都有关 type AVG struct { - sumColumnInfo ColumnInfo - countColumnInfo ColumnInfo + sumColumnInfo merger.ColumnInfo + countColumnInfo merger.ColumnInfo avgName string } // NewAVG sumInfo是sum的信息,countInfo是count的信息,avgName用于Column方法 -func NewAVG(sumInfo ColumnInfo, countInfo ColumnInfo, avgName string) *AVG { +func NewAVG(sumInfo merger.ColumnInfo, countInfo merger.ColumnInfo, avgName string) *AVG { return &AVG{ sumColumnInfo: sumInfo, countColumnInfo: countInfo, diff --git a/internal/merger/aggregatemerger/aggregator/avg_test.go b/internal/merger/aggregatemerger/aggregator/avg_test.go index 23f67eb..9790171 100644 --- a/internal/merger/aggregatemerger/aggregator/avg_test.go +++ b/internal/merger/aggregatemerger/aggregator/avg_test.go @@ -17,6 +17,8 @@ package aggregator import ( "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" "github.com/stretchr/testify/assert" @@ -82,7 +84,7 @@ func TestAvg_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - avg := NewAVG(NewColumnInfo(tc.index[0], "SUM(grade)"), NewColumnInfo(tc.index[1], "COUNT(grade)"), "AVG(grade)") + avg := NewAVG(merger.NewColumnInfo(tc.index[0], "SUM(grade)"), merger.NewColumnInfo(tc.index[1], "COUNT(grade)"), "AVG(grade)") val, err := 
avg.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/count.go b/internal/merger/aggregatemerger/aggregator/count.go index 6d08924..908837e 100644 --- a/internal/merger/aggregatemerger/aggregator/count.go +++ b/internal/merger/aggregatemerger/aggregator/count.go @@ -17,11 +17,13 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) type Count struct { - countInfo ColumnInfo + countInfo merger.ColumnInfo } func (s *Count) Aggregate(cols [][]any) (any, error) { @@ -49,7 +51,7 @@ func (s *Count) ColumnName() string { return s.countInfo.Name } -func NewCount(info ColumnInfo) *Count { +func NewCount(info merger.ColumnInfo) *Count { return &Count{ countInfo: info, } diff --git a/internal/merger/aggregatemerger/aggregator/count_test.go b/internal/merger/aggregatemerger/aggregator/count_test.go index 136d49f..81583c3 100644 --- a/internal/merger/aggregatemerger/aggregator/count_test.go +++ b/internal/merger/aggregatemerger/aggregator/count_test.go @@ -17,6 +17,8 @@ package aggregator import ( "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" "github.com/stretchr/testify/assert" @@ -78,7 +80,7 @@ func TestCount_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - count := NewCount(NewColumnInfo(tc.countIndex, "COUNT(id)")) + count := NewCount(merger.NewColumnInfo(tc.countIndex, "COUNT(id)")) val, err := count.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/max.go b/internal/merger/aggregatemerger/aggregator/max.go index 8f6464e..a8757b1 100644 --- a/internal/merger/aggregatemerger/aggregator/max.go +++ b/internal/merger/aggregatemerger/aggregator/max.go @@ -17,11 +17,13 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) type Max struct { - maxColumnInfo ColumnInfo + maxColumnInfo merger.ColumnInfo } func (m *Max) Aggregate(cols [][]any) (any, error) { @@ -49,7 +51,7 @@ func (m *Max) ColumnName() string { return m.maxColumnInfo.Name } -func NewMax(info ColumnInfo) *Max { +func NewMax(info merger.ColumnInfo) *Max { return &Max{ maxColumnInfo: info, } diff --git a/internal/merger/aggregatemerger/aggregator/max_test.go b/internal/merger/aggregatemerger/aggregator/max_test.go index f74dbb7..54071e4 100644 --- a/internal/merger/aggregatemerger/aggregator/max_test.go +++ b/internal/merger/aggregatemerger/aggregator/max_test.go @@ -17,6 +17,8 @@ package aggregator import ( "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" "github.com/stretchr/testify/assert" ) @@ -77,7 +79,7 @@ func TestMax_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - max := NewMax(NewColumnInfo(tc.maxIndex, "MAX(id)")) + max := NewMax(merger.NewColumnInfo(tc.maxIndex, "MAX(id)")) val, err := max.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/min.go b/internal/merger/aggregatemerger/aggregator/min.go index 61bde7a..321a62f 100644 --- a/internal/merger/aggregatemerger/aggregator/min.go +++ b/internal/merger/aggregatemerger/aggregator/min.go @@ -17,11 +17,13 @@ package aggregator import ( "reflect" + 
"github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) type Min struct { - minColumnInfo ColumnInfo + minColumnInfo merger.ColumnInfo } func (m *Min) Aggregate(cols [][]any) (any, error) { @@ -50,7 +52,7 @@ func (m *Min) ColumnName() string { return m.minColumnInfo.Name } -func NewMin(info ColumnInfo) *Min { +func NewMin(info merger.ColumnInfo) *Min { return &Min{ minColumnInfo: info, } diff --git a/internal/merger/aggregatemerger/aggregator/min_test.go b/internal/merger/aggregatemerger/aggregator/min_test.go index 52f40f6..917c584 100644 --- a/internal/merger/aggregatemerger/aggregator/min_test.go +++ b/internal/merger/aggregatemerger/aggregator/min_test.go @@ -17,6 +17,8 @@ package aggregator import ( "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" "github.com/stretchr/testify/assert" ) @@ -77,7 +79,7 @@ func TestMin_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - min := NewMin(NewColumnInfo(tc.minIndex, "MIN(id)")) + min := NewMin(merger.NewColumnInfo(tc.minIndex, "MIN(id)")) val, err := min.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/sum.go b/internal/merger/aggregatemerger/aggregator/sum.go index 768494c..b67f122 100644 --- a/internal/merger/aggregatemerger/aggregator/sum.go +++ b/internal/merger/aggregatemerger/aggregator/sum.go @@ -17,11 +17,13 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) type Sum struct { - sumColumnInfo ColumnInfo + sumColumnInfo merger.ColumnInfo } func (s *Sum) Aggregate(cols [][]any) (any, error) { @@ -50,7 +52,7 @@ func (s *Sum) ColumnName() string { return s.sumColumnInfo.Name } -func NewSum(info ColumnInfo) *Sum { +func NewSum(info merger.ColumnInfo) *Sum { return &Sum{ sumColumnInfo: info, } diff --git a/internal/merger/aggregatemerger/aggregator/sum_test.go b/internal/merger/aggregatemerger/aggregator/sum_test.go index efa22cf..c9edf81 100644 --- a/internal/merger/aggregatemerger/aggregator/sum_test.go +++ b/internal/merger/aggregatemerger/aggregator/sum_test.go @@ -17,6 +17,8 @@ package aggregator import ( "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" "github.com/stretchr/testify/assert" @@ -79,7 +81,7 @@ func TestSum_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - sum := NewSum(NewColumnInfo(tc.sumIndex, "SUM(id)")) + sum := NewSum(merger.NewColumnInfo(tc.sumIndex, "SUM(id)")) val, err := sum.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/type.go b/internal/merger/aggregatemerger/aggregator/type.go index 6fb58cd..8a29c1e 100644 --- a/internal/merger/aggregatemerger/aggregator/type.go +++ b/internal/merger/aggregatemerger/aggregator/type.go @@ -24,15 +24,3 @@ type Aggregator interface { // ColumnName 聚合函数的别名 ColumnName() string } - -type ColumnInfo struct { - Index int - Name string -} - -func NewColumnInfo(index int, name string) ColumnInfo { - return ColumnInfo{ - Index: index, - Name: name, - } -} diff --git a/internal/merger/aggregatemerger/merger.go b/internal/merger/aggregatemerger/merger.go index b511283..0c3998d 100644 --- a/internal/merger/aggregatemerger/merger.go +++ 
b/internal/merger/aggregatemerger/merger.go @@ -17,19 +17,17 @@ package aggregatemerger import ( "context" "database/sql" - "reflect" "sync" _ "unsafe" + "github.com/ecodeclub/eorm/internal/merger/utils" + "github.com/ecodeclub/eorm/internal/merger" "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" "github.com/ecodeclub/eorm/internal/merger/internal/errs" "go.uber.org/multierr" ) -//go:linkname convertAssign database/sql.convertAssign -func convertAssign(dest, src any) error - // Merger 该实现不支持group by操作,并且聚合函数查询应该只返回一行数据。 type Merger struct { aggregators []aggregator.Aggregator @@ -151,32 +149,14 @@ func (r *Rows) getSqlRowsData() ([][]any, error) { return rowsData, nil } func (r *Rows) getSqlRowData(row *sql.Rows) ([]any, error) { - colsInfo, err := row.ColumnTypes() - if err != nil { - return nil, err - } - // colsData 表示一个sql.Rows的数据 - colsData := make([]any, 0, len(colsInfo)) + + var colsData []any + var err error if row.Next() { - // 拿到sql.Rows字段的类型然后初始化 - for _, colInfo := range colsInfo { - typ := colInfo.ScanType() - // sqlite3的驱动返回的是指针。循环的去除指针 - for typ.Kind() == reflect.Pointer { - typ = typ.Elem() - } - newData := reflect.New(typ).Interface() - colsData = append(colsData, newData) - } - // 通过Scan赋值 - err = row.Scan(colsData...) + colsData, err = utils.Scan(row) if err != nil { return nil, err } - // 去掉reflect.New的指针 - for i := 0; i < len(colsData); i++ { - colsData[i] = reflect.ValueOf(colsData[i]).Elem().Interface() - } } else { // sql.Rows迭代过程中发生报错,返回报错 if row.Err() != nil { @@ -201,7 +181,7 @@ func (r *Rows) Scan(dest ...any) error { return errs.ErrMergerScanNotNext } for i := 0; i < len(dest); i++ { - err := convertAssign(dest[i], r.cur[i]) + err := utils.ConvertAssign(dest[i], r.cur[i]) if err != nil { return err } diff --git a/internal/merger/aggregatemerger/merger_test.go b/internal/merger/aggregatemerger/merger_test.go index 8407056..ea12b12 100644 --- a/internal/merger/aggregatemerger/merger_test.go +++ b/internal/merger/aggregatemerger/merger_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/DATA-DOG/go-sqlmock" "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" "github.com/ecodeclub/eorm/internal/merger/internal/errs" @@ -141,7 +143,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -170,7 +172,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -200,7 +202,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewMax(aggregator.NewColumnInfo(0, "MAX(id)")), + aggregator.NewMax(merger.NewColumnInfo(0, "MAX(id)")), } }, }, @@ -229,7 +231,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewMin(aggregator.NewColumnInfo(0, "MIN(id)")), + aggregator.NewMin(merger.NewColumnInfo(0, "MIN(id)")), } }, }, @@ -258,7 +260,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - 
aggregator.NewCount(aggregator.NewColumnInfo(0, "COUNT(id)")), + aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), } }, }, @@ -289,7 +291,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewAVG(aggregator.NewColumnInfo(0, "SUM(grade)"), aggregator.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), } }, }, @@ -323,11 +325,11 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(0, "COUNT(id)")), - aggregator.NewMax(aggregator.NewColumnInfo(1, "MAX(id)")), - aggregator.NewMin(aggregator.NewColumnInfo(2, "MIN(id)")), - aggregator.NewSum(aggregator.NewColumnInfo(3, "SUM(id)")), - aggregator.NewAVG(aggregator.NewColumnInfo(4, "SUM(grade)"), aggregator.NewColumnInfo(5, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), + aggregator.NewMax(merger.NewColumnInfo(1, "MAX(id)")), + aggregator.NewMin(merger.NewColumnInfo(2, "MIN(id)")), + aggregator.NewSum(merger.NewColumnInfo(3, "SUM(id)")), + aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(5, "COUNT(grade)"), "AVG(grade)"), } }, }, @@ -360,13 +362,13 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewAVG(aggregator.NewColumnInfo(0, "SUM(grade)"), aggregator.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewSum(aggregator.NewColumnInfo(2, "SUM(grade)")), - aggregator.NewAVG(aggregator.NewColumnInfo(4, "SUM(grade)"), aggregator.NewColumnInfo(3, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewMin(aggregator.NewColumnInfo(5, "MIN(id)")), - aggregator.NewMin(aggregator.NewColumnInfo(6, "MIN(userid)")), - aggregator.NewMax(aggregator.NewColumnInfo(7, "MAX(id)")), - aggregator.NewCount(aggregator.NewColumnInfo(8, "COUNT(id)")), + aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(grade)")), + aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(3, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewMin(merger.NewColumnInfo(5, "MIN(id)")), + aggregator.NewMin(merger.NewColumnInfo(6, "MIN(userid)")), + aggregator.NewMax(merger.NewColumnInfo(7, "MAX(id)")), + aggregator.NewCount(merger.NewColumnInfo(8, "COUNT(id)")), } }, }, @@ -401,7 +403,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -433,7 +435,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -465,7 +467,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + 
aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -491,7 +493,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -545,7 +547,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { }, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(0, "COUNT(id)")), + aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), } }(), wantErr: nextMockErr, @@ -597,7 +599,7 @@ func (ms *MergerSuite) TestRows_Close() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2).CloseError(newCloseMockErr("db02"))) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3).CloseError(newCloseMockErr("db03"))) - merger := NewMerger(aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)"))) + merger := NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]*sql.Rows, 0, len(dbs)) for _, db := range dbs { @@ -646,11 +648,11 @@ func (ms *MergerSuite) TestRows_Columns() { ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12)) aggregators := []aggregator.Aggregator{ - aggregator.NewAVG(aggregator.NewColumnInfo(0, "SUM(grade)"), aggregator.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewSum(aggregator.NewColumnInfo(2, "SUM(id)")), - aggregator.NewMin(aggregator.NewColumnInfo(3, "MIN(id)")), - aggregator.NewMax(aggregator.NewColumnInfo(4, "MAX(id)")), - aggregator.NewCount(aggregator.NewColumnInfo(5, "COUNT(id)")), + aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), + aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")), + aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")), + aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")), } merger := NewMerger(aggregators...) 
dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} @@ -691,7 +693,7 @@ func (ms *MergerSuite) TestMerger_Merge() { { name: "超时", merger: func() *Merger { - return NewMerger(aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(context.Background(), 0) @@ -711,7 +713,7 @@ func (ms *MergerSuite) TestMerger_Merge() { { name: "sqlRows列表元素个数为0", merger: func() *Merger { - return NewMerger(aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) @@ -725,7 +727,7 @@ func (ms *MergerSuite) TestMerger_Merge() { { name: "sqlRows列表有nil", merger: func() *Merger { - return NewMerger(aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/merger/groupby_merger/aggregator_merger.go b/internal/merger/groupby_merger/aggregator_merger.go new file mode 100644 index 0000000..1d60e87 --- /dev/null +++ b/internal/merger/groupby_merger/aggregator_merger.go @@ -0,0 +1,304 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package groupby_merger + +import ( + "context" + "database/sql" + "reflect" + "sync" + _ "unsafe" + + "github.com/ecodeclub/eorm/internal/merger/utils" + "go.uber.org/multierr" + + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" + "github.com/gotomicro/ekit/mapx" +) + +type AggregatorMerger struct { + aggregators []aggregator.Aggregator + groupColumns []merger.ColumnInfo + columnsName []string +} + +func NewAggregatorMerger(aggregators []aggregator.Aggregator, groupColumns []merger.ColumnInfo) *AggregatorMerger { + cols := make([]string, 0, len(aggregators)+len(groupColumns)) + for _, groubyCol := range groupColumns { + cols = append(cols, groubyCol.Name) + } + for _, agg := range aggregators { + cols = append(cols, agg.ColumnName()) + } + + return &AggregatorMerger{ + aggregators: aggregators, + groupColumns: groupColumns, + columnsName: cols, + } +} + +// Merge 该实现会全部拿取results里面的数据,由于sql.Rows数据拿完之后会自动关闭,所以这边隐式的关闭了所有的sql.Rows +func (a *AggregatorMerger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + if len(results) == 0 { + return nil, errs.ErrMergerEmptyRows + } + + for _, res := range results { + err := a.checkColumns(res) + if err != nil { + return nil, err + } + } + dataMap, dataIndex, err := a.getCols(results) + if err != nil { + return nil, err + } + + return &AggregatorRows{ + rowsList: results, + aggregators: a.aggregators, + groupColumns: a.groupColumns, + mu: &sync.RWMutex{}, + dataMap: dataMap, + dataIndex: dataIndex, + cur: -1, + cols: a.columnsName, + }, nil + +} +func (a *AggregatorMerger) checkColumns(rows *sql.Rows) error { + if rows == nil { + return errs.ErrMergerRowsIsNull + } + return nil +} + +func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][]any], []Key, error) { + treeMap, err := mapx.NewTreeMap[Key, [][]any](compareKey) + if err != nil { + return nil, nil, err + } + keys := make([]Key, 0, 16) + for _, res := range rowsList { + colsData, err := a.getCol(res) + if err != nil { + return nil, nil, err + } + for _, colData := range colsData { + key := Key{columnValues: make([]any, 0, len(a.groupColumns))} + for _, groupByCol := range a.groupColumns { + key.columnValues = append(key.columnValues, colData[groupByCol.Index]) + } + val, ok := treeMap.Get(key) + + if ok { + val = append(val, colData) + err = treeMap.Set(key, val) + if err != nil { + return nil, nil, err + } + } else { + keys = append(keys, key) + err := treeMap.Put(key, [][]any{colData}) + if err != nil { + return nil, nil, err + } + } + } + } + return treeMap, keys, nil +} + +func (a *AggregatorMerger) getCol(row *sql.Rows) ([][]any, error) { + ans := make([][]any, 0, 16) + for row.Next() { + colsData, err := utils.Scan(row) + if err != nil { + return nil, err + } + ans = append(ans, colsData) + } + if row.Err() != nil { + return nil, row.Err() + } + + return ans, nil + +} + +type AggregatorRows struct { + rowsList []*sql.Rows + aggregators []aggregator.Aggregator + groupColumns []merger.ColumnInfo + dataMap *mapx.TreeMap[Key, [][]any] + cur int + dataIndex []Key + mu *sync.RWMutex + curData []any + closed bool + lastErr error + cols []string +} + +// Next 返回列的顺序先分组信息然后是聚合函数信息 +func (a *AggregatorRows) Next() bool { + a.mu.Lock() + if a.closed { + a.mu.Unlock() + return false + } + a.cur++ + if a.cur >= len(a.dataIndex) { + a.mu.Unlock() + _ = a.Close() + return false + 
} + a.curData = make([]any, 0, len(a.aggregators)+len(a.groupColumns)) + + a.curData = append(a.curData, a.dataIndex[a.cur].columnValues...) + + for _, agg := range a.aggregators { + val, _ := a.dataMap.Get(a.dataIndex[a.cur]) + res, err := agg.Aggregate(val) + if err != nil { + a.lastErr = err + a.mu.Unlock() + _ = a.Close() + return false + } + a.curData = append(a.curData, res) + } + + a.mu.Unlock() + return true +} + +func (a *AggregatorRows) Scan(dest ...any) error { + a.mu.Lock() + defer a.mu.Unlock() + if a.lastErr != nil { + return a.lastErr + } + if a.closed { + return errs.ErrMergerRowsClosed + } + if a.cur == -1 { + return errs.ErrMergerScanNotNext + } + for i := 0; i < len(dest); i++ { + err := utils.ConvertAssign(dest[i], a.curData[i]) + if err != nil { + return err + } + } + return nil +} + +// Close 关闭所有的sql.Rows +func (a *AggregatorRows) Close() error { + a.mu.Lock() + defer a.mu.Unlock() + a.closed = true + errorList := make([]error, 0, len(a.rowsList)) + for i := 0; i < len(a.rowsList); i++ { + row := a.rowsList[i] + err := row.Close() + if err != nil { + errorList = append(errorList, err) + } + } + return multierr.Combine(errorList...) +} + +// Columns 返回列的顺序先分组信息然后是聚合函数信息 +func (a *AggregatorRows) Columns() ([]string, error) { + a.mu.RLock() + defer a.mu.RUnlock() + if a.closed { + return nil, errs.ErrMergerRowsClosed + } + return a.cols, nil +} + +func (a *AggregatorRows) Err() error { + a.mu.RLock() + defer a.mu.RUnlock() + return a.lastErr +} + +type Key struct { + columnValues []any +} + +func compareKey(a, b Key) int { + keyLen := len(a.columnValues) + for i := 0; i < keyLen; i++ { + compareFunc := compareFuncMapping[reflect.TypeOf(a.columnValues[i]).Kind()] + res := compareFunc(a.columnValues[i], b.columnValues[i]) + if res != 0 { + return res + } + } + return 0 +} + +func compare[T Ordered](ii any, jj any) int { + i, j := ii.(T), jj.(T) + if i < j { + return -1 + } else if i > j { + return 1 + } else { + return 0 + } +} + +type Ordered interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | ~string +} + +var compareFuncMapping = map[reflect.Kind]func(any, any) int{ + reflect.Int: compare[int], + reflect.Int8: compare[int8], + reflect.Int16: compare[int16], + reflect.Int32: compare[int32], + reflect.Int64: compare[int64], + reflect.Uint8: compare[uint8], + reflect.Uint16: compare[uint16], + reflect.Uint32: compare[uint32], + reflect.Uint64: compare[uint64], + reflect.Float32: compare[float32], + reflect.Float64: compare[float64], + reflect.String: compare[string], + reflect.Uint: compare[uint], + reflect.Bool: compareBool, +} + +func compareBool(ii, jj any) int { + i, j := ii.(bool), jj.(bool) + if i == j { + return 0 + } + if i && !j { + return 1 + } + return -1 +} diff --git a/internal/merger/groupby_merger/aggregator_merger_test.go b/internal/merger/groupby_merger/aggregator_merger_test.go new file mode 100644 index 0000000..1f63bd3 --- /dev/null +++ b/internal/merger/groupby_merger/aggregator_merger_test.go @@ -0,0 +1,596 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package groupby_merger + +import ( + "context" + "database/sql" + "errors" + "testing" + + "github.com/ecodeclub/eorm/internal/merger" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +var ( + nextMockErr error = errors.New("rows: MockNextErr") + aggregatorErr error = errors.New("aggregator: MockAggregatorErr") +) + +type MergerSuite struct { + suite.Suite + mockDB01 *sql.DB + mock01 sqlmock.Sqlmock + mockDB02 *sql.DB + mock02 sqlmock.Sqlmock + mockDB03 *sql.DB + mock03 sqlmock.Sqlmock + mockDB04 *sql.DB + mock04 sqlmock.Sqlmock +} + +func (ms *MergerSuite) SetupTest() { + t := ms.T() + ms.initMock(t) +} + +func (ms *MergerSuite) TearDownTest() { + _ = ms.mockDB01.Close() + _ = ms.mockDB02.Close() + _ = ms.mockDB03.Close() + _ = ms.mockDB04.Close() +} + +func (ms *MergerSuite) initMock(t *testing.T) { + var err error + ms.mockDB01, ms.mock01, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + ms.mockDB02, ms.mock02, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + ms.mockDB03, ms.mock03, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + ms.mockDB04, ms.mock04, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } +} + +func TestMerger(t *testing.T) { + suite.Run(t, &MergerSuite{}) +} + +func (ms *MergerSuite) TestAggregatorMerger_Merge() { + testcases := []struct { + name string + aggregators []aggregator.Aggregator + rowsList []*sql.Rows + GroupByColumns []merger.ColumnInfo + wantErr error + ctx func() (context.Context, context.CancelFunc) + }{ + { + name: "正常案例", + aggregators: []aggregator.Aggregator{ + aggregator.NewCount(merger.NewColumnInfo(2, "id")), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "county"), + merger.NewColumnInfo(1, "gender"), + }, + rowsList: func() []*sql.Rows { + query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" + cols := []string{"county", "gender", "SUM(id)"} + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }(), + + ctx: func() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return ctx, cancel + }, + }, + { + name: "超时", + aggregators: []aggregator.Aggregator{ + 
aggregator.NewCount(merger.NewColumnInfo(1, "id")), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), + }, + rowsList: func() []*sql.Rows { + query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" + cols := []string{"user_name", "SUM(id)"} + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }(), + ctx: func() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), 0) + return ctx, cancel + }, + wantErr: context.DeadlineExceeded, + }, + { + name: "rowsList为空", + aggregators: []aggregator.Aggregator{ + aggregator.NewCount(merger.NewColumnInfo(1, "id")), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), + }, + rowsList: func() []*sql.Rows { + return []*sql.Rows{} + }(), + ctx: func() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return ctx, cancel + }, + wantErr: errs.ErrMergerEmptyRows, + }, + { + name: "rowsList中有nil", + aggregators: []aggregator.Aggregator{ + aggregator.NewCount(merger.NewColumnInfo(1, "id")), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), + }, + rowsList: func() []*sql.Rows { + return []*sql.Rows{nil} + }(), + ctx: func() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return ctx, cancel + }, + wantErr: errs.ErrMergerRowsIsNull, + }, + { + name: "rowsList中有sql.Rows返回错误", + aggregators: []aggregator.Aggregator{ + aggregator.NewCount(merger.NewColumnInfo(1, "id")), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), + }, + rowsList: func() []*sql.Rows { + query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" + cols := []string{"user_name", "SUM(id)"} + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20).RowError(1, nextMockErr)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }(), + ctx: func() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return ctx, cancel + }, + wantErr: nextMockErr, + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) + ctx, cancel := tc.ctx() + groupByRows, err := merger.Merge(ctx, tc.rowsList) + cancel() + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + require.NotNil(t, groupByRows) + }) + } +} + +func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { + testcases := 
[]struct { + name string + aggregators []aggregator.Aggregator + rowsList []*sql.Rows + wantVal [][]any + gotVal [][]any + GroupByColumns []merger.ColumnInfo + wantErr error + }{ + { + name: "同一组数据在不同的sql.Rows中", + aggregators: []aggregator.Aggregator{ + aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), + }, + rowsList: func() []*sql.Rows { + query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" + cols := []string{"user_name", "SUM(id)"} + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("dm", 20)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("zwl", 20)) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }(), + wantVal: [][]any{ + {"zwl", int64(30)}, + {"dm", int64(40)}, + {"xz", int64(10)}, + }, + gotVal: [][]any{ + {"", int64(0)}, + {"", int64(0)}, + {"", int64(0)}, + }, + }, + { + name: "同一组数据在同一个sql.Rows中", + aggregators: []aggregator.Aggregator{ + aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), + }, + rowsList: func() []*sql.Rows { + query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" + cols := []string{"user_name", "SUM(id)"} + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 10).AddRow("xm", 20)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xz", 10).AddRow("xx", 20)) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("dm", 20)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }(), + wantVal: [][]any{ + {"zwl", int64(10)}, + {"xm", int64(20)}, + {"xz", int64(10)}, + {"xx", int64(20)}, + {"dm", int64(20)}, + }, + gotVal: [][]any{ + {"", int64(0)}, + {"", int64(0)}, + {"", int64(0)}, + {"", int64(0)}, + {"", int64(0)}, + }, + }, + { + name: "多个分组列", + aggregators: []aggregator.Aggregator{ + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "county"), + merger.NewColumnInfo(1, "gender"), + }, + rowsList: func() []*sql.Rows { + query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" + cols := []string{"county", "gender", "SUM(id)"} + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10).AddRow("hangzhou", "female", 20).AddRow("shanghai", "female", 30)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40).AddRow("shanghai", "female", 50).AddRow("hangzhou", "female", 60)) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70).AddRow("shanghai", "female", 80)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs 
{ + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }(), + wantVal: [][]any{ + { + "hangzhou", + "male", + int64(10), + }, + { + "hangzhou", + "female", + int64(80), + }, + { + "shanghai", + "female", + int64(160), + }, + { + "shanghai", + "male", + int64(110), + }, + }, + gotVal: [][]any{ + {"", "", int64(0)}, + {"", "", int64(0)}, + {"", "", int64(0)}, + {"", "", int64(0)}, + }, + }, + { + name: "多个聚合函数", + aggregators: []aggregator.Aggregator{ + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), + aggregator.NewAVG(merger.NewColumnInfo(3, "SUM(age)"), merger.NewColumnInfo(4, "COUNT(age)"), "AVG(age)"), + }, + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "county"), + merger.NewColumnInfo(1, "gender"), + }, + + rowsList: func() []*sql.Rows { + query := "SELECT `county`,`gender`,SUM(`id`),SUM(`age`),COUNT(`age`) FROM `t1` GROUP BY `country`,`gender`" + cols := []string{"county", "gender", "SUM(id)", "SUM(age)", "COUNT(age)"} + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("hangzhou", "male", 10, 100, 2).AddRow("hangzhou", "female", 20, 120, 3).AddRow("shanghai", "female", 30, 90, 3)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 40, 120, 5).AddRow("shanghai", "female", 50, 120, 4).AddRow("hangzhou", "female", 60, 150, 3)) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("shanghai", "male", 70, 100, 5).AddRow("shanghai", "female", 80, 150, 5)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }(), + wantVal: [][]any{ + { + "hangzhou", + "male", + int64(10), + float64(50), + }, + { + "hangzhou", + "female", + int64(80), + float64(45), + }, + { + "shanghai", + "female", + int64(160), + float64(30), + }, + { + "shanghai", + "male", + int64(110), + float64(22), + }, + }, + gotVal: [][]any{ + {"", "", int64(0), float64(0)}, + {"", "", int64(0), float64(0)}, + {"", "", int64(0), float64(0)}, + {"", "", int64(0), float64(0)}, + }, + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) + groupByRows, err := merger.Merge(context.Background(), tc.rowsList) + require.NoError(t, err) + + idx := 0 + for groupByRows.Next() { + if idx >= len(tc.gotVal) { + break + } + tmp := make([]any, 0, len(tc.gotVal[0])) + for i := range tc.gotVal[idx] { + tmp = append(tmp, &tc.gotVal[idx][i]) + } + err := groupByRows.Scan(tmp...) 
+ require.NoError(t, err) + idx++ + } + require.NoError(t, groupByRows.Err()) + assert.Equal(t, tc.wantVal, tc.gotVal) + }) + } +} + +func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() { + ms.T().Run("未调用Next,直接Scan,返回错", func(t *testing.T) { + cols := []string{"userid", "SUM(id)"} + query := "SELECT userid,SUM(id) FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) + r, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(t, err) + rowsList := []*sql.Rows{r} + merger := NewAggregatorMerger([]aggregator.Aggregator{aggregator.NewSum(merger.NewColumnInfo(1, "SUM(id)"))}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) + rows, err := merger.Merge(context.Background(), rowsList) + require.NoError(t, err) + userid := 0 + sumId := 0 + err = rows.Scan(&userid, &sumId) + assert.Equal(t, errs.ErrMergerScanNotNext, err) + }) + ms.T().Run("迭代过程中发现错误,调用Scan,返回迭代中发现的错误", func(t *testing.T) { + cols := []string{"userid", "SUM(id)"} + query := "SELECT userid,SUM(id) FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 10).AddRow(5, 20)) + r, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(t, err) + rowsList := []*sql.Rows{r} + merger := NewAggregatorMerger([]aggregator.Aggregator{&mockAggregate{}}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) + rows, err := merger.Merge(context.Background(), rowsList) + require.NoError(t, err) + userid := 0 + sumId := 0 + rows.Next() + err = rows.Scan(&userid, &sumId) + assert.Equal(t, aggregatorErr, err) + }) + +} + +func (ms *MergerSuite) TestAggregatorRows_NextAndErr() { + testcases := []struct { + name string + rowsList func() []*sql.Rows + wantErr error + aggregators []aggregator.Aggregator + GroupByColumns []merger.ColumnInfo + }{ + { + name: "有一个aggregator返回error", + rowsList: func() []*sql.Rows { + cols := []string{"username", "COUNT(id)"} + query := "SELECT username,COUNT(`id`) FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("zwl", 1)) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("daming", 2)) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("wu", 4)) + ms.mock04.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("ming", 5)) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03, ms.mockDB04} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + return rowsList + }, + aggregators: func() []aggregator.Aggregator { + return []aggregator.Aggregator{ + &mockAggregate{}, + } + }(), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "username"), + }, + wantErr: aggregatorErr, + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + merger := NewAggregatorMerger(tc.aggregators, tc.GroupByColumns) + rows, err := merger.Merge(context.Background(), tc.rowsList()) + require.NoError(t, err) + for rows.Next() { + } + count := int64(0) + name := "" + err = rows.Scan(&name, &count) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantErr, rows.Err()) + }) + } +} + +func (ms *MergerSuite) TestAggregatorRows_Columns() { + cols := []string{"userid", "SUM(grade)", "COUNT(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} + query := "SELECT 
SUM(`grade`),COUNT(`grade`),SUM(`id`),MIN(`id`),MAX(`id`),COUNT(`id`),`userid` FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1, 1, 2, 1, 3, 10, "zwl")) + ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11, "dm")) + ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12, "xm")) + aggregators := []aggregator.Aggregator{ + aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), + aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")), + aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")), + aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")), + } + groupbyColumns := []merger.ColumnInfo{ + merger.NewColumnInfo(6, "userid"), + } + merger := NewAggregatorMerger(aggregators, groupbyColumns) + dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} + rowsList := make([]*sql.Rows, 0, len(dbs)) + for _, db := range dbs { + row, err := db.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + rowsList = append(rowsList, row) + } + + rows, err := merger.Merge(context.Background(), rowsList) + require.NoError(ms.T(), err) + wantCols := []string{"userid", "AVG(grade)", "SUM(id)", "MIN(id)", "MAX(id)", "COUNT(id)"} + ms.T().Run("Next没有迭代完", func(t *testing.T) { + for rows.Next() { + columns, err := rows.Columns() + require.NoError(t, err) + assert.Equal(t, wantCols, columns) + } + require.NoError(t, rows.Err()) + }) + ms.T().Run("Next迭代完", func(t *testing.T) { + require.False(t, rows.Next()) + require.NoError(t, rows.Err()) + _, err := rows.Columns() + assert.Equal(t, errs.ErrMergerRowsClosed, err) + }) +} + +type mockAggregate struct { + cols [][]any +} + +func (m *mockAggregate) Aggregate(cols [][]any) (any, error) { + m.cols = cols + return nil, aggregatorErr +} + +func (m *mockAggregate) ColumnName() string { + return "mockAggregate" +} diff --git a/internal/merger/sortmerger/merger.go b/internal/merger/sortmerger/merger.go index 5e2497f..96b2295 100644 --- a/internal/merger/sortmerger/merger.go +++ b/internal/merger/sortmerger/merger.go @@ -20,7 +20,8 @@ import ( "database/sql" "reflect" "sync" - _ "unsafe" + + "github.com/ecodeclub/eorm/internal/merger/utils" "go.uber.org/multierr" @@ -41,9 +42,6 @@ type Ordered interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | ~string } -//go:linkname convertAssign database/sql.convertAssign -func convertAssign(dest, src any) error - type SortColumn struct { name string order Order @@ -285,7 +283,7 @@ func (r *Rows) Scan(dest ...any) error { } for i := 0; i < len(dest); i++ { - err := convertAssign(dest[i], r.cur.columns[i]) + err := utils.ConvertAssign(dest[i], r.cur.columns[i]) if err != nil { return err } diff --git a/internal/merger/type.go b/internal/merger/type.go index 7c37fe8..32955fb 100644 --- a/internal/merger/type.go +++ b/internal/merger/type.go @@ -33,3 +33,15 @@ type Rows interface { Columns() ([]string, error) Err() error } + +type ColumnInfo struct { + Index int + Name string +} + +func NewColumnInfo(index int, name string) ColumnInfo { + return ColumnInfo{ + Index: index, + Name: name, + } +} diff --git a/internal/merger/utils/convert_Assign.go b/internal/merger/utils/convert_Assign.go new file mode 100644 index 0000000..e4abfc3 --- /dev/null +++ b/internal/merger/utils/convert_Assign.go @@ -0,0 +1,23 @@ +// 
Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + _ "database/sql" + _ "unsafe" +) + +//go:linkname ConvertAssign database/sql.convertAssign +func ConvertAssign(dest, src any) error diff --git a/internal/merger/utils/scan.go b/internal/merger/utils/scan.go new file mode 100644 index 0000000..3412a90 --- /dev/null +++ b/internal/merger/utils/scan.go @@ -0,0 +1,48 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "database/sql" + "reflect" +) + +func Scan(row *sql.Rows) ([]any, error) { + colsInfo, err := row.ColumnTypes() + if err != nil { + return nil, err + } + colsData := make([]any, 0, len(colsInfo)) + // 拿到sql.Rows字段的类型然后初始化 + for _, colInfo := range colsInfo { + typ := colInfo.ScanType() + // sqlite3的驱动返回的是指针。循环的去除指针 + for typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + newData := reflect.New(typ).Interface() + colsData = append(colsData, newData) + } + // 通过Scan赋值 + err = row.Scan(colsData...) + if err != nil { + return nil, err + } + // 去掉reflect.New的指针 + for i := 0; i < len(colsData); i++ { + colsData[i] = reflect.ValueOf(colsData[i]).Elem().Interface() + } + return colsData, nil +} diff --git a/internal/merger/utils/scan_test.go b/internal/merger/utils/scan_test.go new file mode 100644 index 0000000..61fc089 --- /dev/null +++ b/internal/merger/utils/scan_test.go @@ -0,0 +1,249 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package utils + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ScanSuite struct { + suite.Suite + mockDB01 *sql.DB + mock01 sqlmock.Sqlmock + db02 *sql.DB +} + +func (ms *ScanSuite) SetupTest() { + t := ms.T() + ms.initMock(t) +} + +func (ms *ScanSuite) TearDownTest() { + _ = ms.mockDB01.Close() + _ = ms.db02.Close() +} +func (ms *ScanSuite) initMock(t *testing.T) { + var err error + query := "CREATE TABLE t1 (\n id int primary key,\n `int` int,\n `integer` integer,\n `tinyint` TINYINT,\n `smallint` smallint,\n `MEDIUMINT` MEDIUMINT,\n `BIGINT` BIGINT,\n `UNSIGNED_BIG_INT` UNSIGNED BIG INT,\n `INT2` INT2,\n `INT8` INT8,\n `VARCHAR` VARCHAR(20),\n \t\t`CHARACTER` CHARACTER(20),\n `VARYING_CHARACTER` VARYING_CHARACTER(20),\n `NCHAR` NCHAR(23),\n `TEXT` TEXT,\n `CLOB` CLOB,\n `REAL` REAL,\n `DOUBLE` DOUBLE,\n `DOUBLE_PRECISION` DOUBLE PRECISION,\n `FLOAT` FLOAT,\n `DATETIME` DATETIME \n );" + ms.mockDB01, ms.mock01, err = sqlmock.New() + if err != nil { + t.Fatal(err) + } + db02, err := sql.Open("sqlite3", "file:test01.db?cache=shared&mode=memory") + if err != nil { + t.Fatal(err) + } + ms.db02 = db02 + _, err = db02.ExecContext(context.Background(), query) + if err != nil { + t.Fatal(err) + } +} +func (ms *ScanSuite) TestScan() { + testcases := []struct { + name string + rows *sql.Rows + want []any + err error + afterFunc func() + }{ + { + name: "浮点数", + rows: func() *sql.Rows { + cols := []string{"float64"} + query := "SELECT float64 FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(float64(1.1))) + rows, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{float64(1.1)}, + }, + { + name: "int64", + rows: func() *sql.Rows { + cols := []string{"int64"} + query := "SELECT int64 FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(int64(1))) + rows, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{int64(1)}, + }, + { + name: "int32", + rows: func() *sql.Rows { + cols := []string{"int32"} + query := "SELECT int32 FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(int32(1))) + rows, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{int32(1)}, + }, + { + name: "int16", + rows: func() *sql.Rows { + cols := []string{"int16"} + query := "SELECT int16 FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(int16(1))) + rows, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{int16(1)}, + }, + { + name: "int8", + rows: func() *sql.Rows { + cols := []string{"int8"} + query := "SELECT int8 FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(int8(1))) + rows, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{int8(1)}, + }, + { + name: "int", + rows: func() *sql.Rows { + cols := []string{"int"} + query := "SELECT FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) + rows, err := ms.mockDB01.QueryContext(context.Background(), 
query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{1}, + }, + { + name: "string", + rows: func() *sql.Rows { + cols := []string{"string"} + query := "SELECT string FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("xx")) + rows, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{"string"}, + }, + { + name: "bool", + rows: func() *sql.Rows { + cols := []string{"bool"} + query := "SELECT bool FROM `t1`" + ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(true)) + rows, err := ms.mockDB01.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{true}, + }, + { + name: "sqlite3 int类型", + rows: func() *sql.Rows { + _, err := ms.db02.Exec("INSERT INTO `t1` (`int`,`integer`,`tinyint`,`smallint`,`MEDIUMINT`,`BIGINT`,`UNSIGNED_BIG_INT`,`INT2`) VALUES (1,1,1,1,1,1,1,1);") + require.NoError(ms.T(), err) + query := "SELECT `int`,`integer`,`tinyint`,`smallint`,`MEDIUMINT`,`BIGINT`,`UNSIGNED_BIG_INT`,`INT2`,`INT8` FROM `t1`;" + rows, err := ms.db02.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: true, Int64: 1}, sql.NullInt64{Valid: false, Int64: 0}}, + afterFunc: func() { + _, err := ms.db02.Exec("truncate table `t1`") + require.NoError(ms.T(), err) + }, + }, + { + name: "sqlite3 string类型", + rows: func() *sql.Rows { + _, err := ms.db02.Exec("INSERT INTO `t1` (`VARCHAR`,`CHARACTER`,`VARYING_CHARACTER`,`NCHAR`,`TEXT`) VALUES ('zwl','zwl','zwl','zwl','zwl');") + require.NoError(ms.T(), err) + query := "SELECT `VARCHAR`,`CHARACTER`,`VARYING_CHARACTER`,`NCHAR`,`TEXT`,`CLOB` FROM `t1`;" + rows, err := ms.db02.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: true, String: "zwl"}, sql.NullString{Valid: false, String: ""}}, + afterFunc: func() { + _, err := ms.db02.Exec("truncate table `t1`") + require.NoError(ms.T(), err) + }, + }, + { + name: "sqlite3 浮点类型", + rows: func() *sql.Rows { + _, err := ms.db02.Exec("INSERT INTO `t1` (`REAL`,`DOUBLE`,`DOUBLE_PRECISION`) VALUES (1.0,1.0,1.0);") + require.NoError(ms.T(), err) + query := "SELECT `REAL`,`DOUBLE`,`DOUBLE_PRECISION`,`FLOAT` FROM `t1`;" + rows, err := ms.db02.QueryContext(context.Background(), query) + require.NoError(ms.T(), err) + return rows + }(), + want: []any{sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: true, Float64: 1.0}, sql.NullFloat64{Valid: false, Float64: 0}}, + afterFunc: func() { + _, err := ms.db02.Exec("truncate table `t1`") + require.NoError(ms.T(), err) + }, + }, + { + name: "sqlite3时间类型", + rows: func() *sql.Rows { + _, err := ms.db02.Exec("INSERT INTO `t1` (`DATETIME`) VALUES ('2022-01-01 12:00:00');") + require.NoError(ms.T(), err) + query := "SELECT `DATETIME` FROM `t1`;" + rows, err := ms.db02.QueryContext(context.Background(), query) + require.NoError(ms.T(), 
err) + return rows + }(), + want: []any{sql.NullTime{Valid: true, Time: func() time.Time { + t, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local) + require.NoError(ms.T(), err) + return t + + }()}}, + }, + } + for _, tc := range testcases { + ms.T().Run(tc.name, func(t *testing.T) { + rows := tc.rows + require.True(t, rows.Next()) + got, err := Scan(rows) + require.Equal(t, tc.err, err) + if err != nil { + return + } + require.Equal(t, tc.want, got) + if tc.afterFunc != nil { + tc.afterFunc() + } + }) + } +} + +func TestScan(t *testing.T) { + suite.Run(t, &ScanSuite{}) +}
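
Note (illustrative, not part of the patch): ColumnInfo and NewColumnInfo move from the aggregator package into the shared merger package, so the existing aggregate merger and the new group-by merger describe column positions with one type. A minimal before/after sketch for call sites, matching the updates in this diff; the package name example and the helper newCountColumn are hypothetical, and these packages are internal to the eorm module:

```go
package example

import (
	"github.com/ecodeclub/eorm/internal/merger"
	"github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator"
)

// Before this patch a call site read:
//
//	aggregator.NewCount(aggregator.NewColumnInfo(0, "COUNT(id)"))
//
// After this patch the column metadata comes from the merger package instead.
func newCountColumn() *aggregator.Count {
	return aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)"))
}
```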
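
The centrepiece of the patch is groupby_merger.AggregatorMerger: Merge drains every shard's sql.Rows, buckets each row by the values of the GROUP BY columns (a TreeMap keyed on those values), and the returned rows walk the buckets, applying each aggregator to a bucket before handing the result to Scan, with the group-by columns first and the aggregate columns after them. A minimal usage sketch modelled on the tests in this diff; the function mergeByCountyAndGender, the table layout, and the way the per-shard *sql.Rows are obtained are illustrative assumptions, and the packages are internal to the eorm module:

```go
package example

import (
	"context"
	"database/sql"
	"fmt"

	"github.com/ecodeclub/eorm/internal/merger"
	"github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator"
	"github.com/ecodeclub/eorm/internal/merger/groupby_merger"
)

// mergeByCountyAndGender merges the per-shard results of a query shaped like
//
//	SELECT `county`, `gender`, SUM(`id`), SUM(`age`), COUNT(`age`)
//	FROM `t1` GROUP BY `county`, `gender`
//
// The indexes passed to NewColumnInfo refer to column positions in each
// shard's result set; the merged rows come back with the group-by columns
// first, followed by one column per aggregator.
func mergeByCountyAndGender(ctx context.Context, rowsList []*sql.Rows) error {
	m := groupby_merger.NewAggregatorMerger(
		[]aggregator.Aggregator{
			aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")),
			// AVG is recombined from per-shard SUM and COUNT columns.
			aggregator.NewAVG(merger.NewColumnInfo(3, "SUM(age)"), merger.NewColumnInfo(4, "COUNT(age)"), "AVG(age)"),
		},
		[]merger.ColumnInfo{
			merger.NewColumnInfo(0, "county"),
			merger.NewColumnInfo(1, "gender"),
		},
	)
	rows, err := m.Merge(ctx, rowsList)
	if err != nil {
		return err
	}
	for rows.Next() {
		var county, gender string
		var sumID int64
		var avgAge float64
		if err := rows.Scan(&county, &gender, &sumID, &avgAge); err != nil {
			return err
		}
		fmt.Println(county, gender, sumID, avgAge)
	}
	// Next closes the underlying sql.Rows once the buckets are exhausted;
	// Err surfaces any aggregation error hit during iteration.
	return rows.Err()
}
```

Because a correct global AVG cannot be rebuilt from per-shard averages, the sharded query has to ship SUM and COUNT separately, and NewAVG divides them only after the buckets are merged.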
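
The scanning code that previously lived inside aggregatemerger is factored out as utils.Scan: it reads ColumnTypes().ScanType(), strips pointer indirection (the sqlite3 driver reports pointer scan types), allocates one value per column with reflect.New, scans into them, and returns the dereferenced values as []any. utils.ConvertAssign is the go:linkname bridge to database/sql's unexported convertAssign, now shared by the sort merger, the aggregate merger, and the group-by merger's Scan; since it binds to an unexported standard-library symbol, it is worth re-checking on Go upgrades. A small sketch of how the two helpers compose; readAll, firstColumnAsInt64, and the column index are illustrative:

```go
package example

import (
	"database/sql"

	"github.com/ecodeclub/eorm/internal/merger/utils"
)

// readAll drains a *sql.Rows into [][]any the same way the mergers in this
// patch do internally (see getCol in groupby_merger/aggregator_merger.go).
func readAll(rows *sql.Rows) ([][]any, error) {
	var all [][]any
	for rows.Next() {
		cols, err := utils.Scan(rows)
		if err != nil {
			return nil, err
		}
		all = append(all, cols)
	}
	return all, rows.Err()
}

// firstColumnAsInt64 shows ConvertAssign copying one scanned value into a
// typed destination, mirroring what the merged Rows.Scan does per argument.
func firstColumnAsInt64(row []any) (int64, error) {
	var id int64
	err := utils.ConvertAssign(&id, row[0])
	return id, err
}
```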