diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 83b4853..376b037 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -25,6 +25,7 @@ - [eorm: 分库分表: 结果集处理--聚合函数(不含GroupBy子句)](https://github.com/ecodeclub/eorm/pull/187) - [eorm: 分库分表: 范围查询支持](https://github.com/ecodeclub/eorm/pull/178) - [eorm: 修复单条查询时连接泄露问题](https://github.com/ecodeclub/eorm/pull/188) +- [eorm: 分库分表: 结果集处理--聚合函数(含GroupBy子句)](https://github.com/ecodeclub/eorm/pull/193) ## v0.0.1: - [Init Project](https://github.com/ecodeclub/eorm/pull/1) diff --git a/internal/merger/aggregatemerger/aggregator/avg.go b/internal/merger/aggregatemerger/aggregator/avg.go index c1de207..d6aabf0 100644 --- a/internal/merger/aggregatemerger/aggregator/avg.go +++ b/internal/merger/aggregatemerger/aggregator/avg.go @@ -17,19 +17,21 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) // AVG 用于求平均值,通过sum/count求得。 // AVG 我们并不能预期在不同的数据库上,精度会不会损失,以及损失的话会有多少的损失。这很大程度上跟数据库类型,数据库驱动实现都有关 type AVG struct { - sumColumnInfo ColumnInfo - countColumnInfo ColumnInfo + sumColumnInfo merger.ColumnInfo + countColumnInfo merger.ColumnInfo avgName string } // NewAVG sumInfo是sum的信息,countInfo是count的信息,avgName用于Column方法 -func NewAVG(sumInfo ColumnInfo, countInfo ColumnInfo, avgName string) *AVG { +func NewAVG(sumInfo merger.ColumnInfo, countInfo merger.ColumnInfo, avgName string) *AVG { return &AVG{ sumColumnInfo: sumInfo, countColumnInfo: countInfo, diff --git a/internal/merger/aggregatemerger/aggregator/avg_test.go b/internal/merger/aggregatemerger/aggregator/avg_test.go index 23f67eb..1d8b042 100644 --- a/internal/merger/aggregatemerger/aggregator/avg_test.go +++ b/internal/merger/aggregatemerger/aggregator/avg_test.go @@ -15,6 +15,7 @@ package aggregator import ( + "github.com/ecodeclub/eorm/internal/merger" "testing" "github.com/ecodeclub/eorm/internal/merger/internal/errs" @@ -82,7 +83,7 @@ func TestAvg_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - avg := NewAVG(NewColumnInfo(tc.index[0], "SUM(grade)"), NewColumnInfo(tc.index[1], "COUNT(grade)"), "AVG(grade)") + avg := NewAVG(merger.NewColumnInfo(tc.index[0], "SUM(grade)"), merger.NewColumnInfo(tc.index[1], "COUNT(grade)"), "AVG(grade)") val, err := avg.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/count.go b/internal/merger/aggregatemerger/aggregator/count.go index 6d08924..908837e 100644 --- a/internal/merger/aggregatemerger/aggregator/count.go +++ b/internal/merger/aggregatemerger/aggregator/count.go @@ -17,11 +17,13 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) type Count struct { - countInfo ColumnInfo + countInfo merger.ColumnInfo } func (s *Count) Aggregate(cols [][]any) (any, error) { @@ -49,7 +51,7 @@ func (s *Count) ColumnName() string { return s.countInfo.Name } -func NewCount(info ColumnInfo) *Count { +func NewCount(info merger.ColumnInfo) *Count { return &Count{ countInfo: info, } diff --git a/internal/merger/aggregatemerger/aggregator/count_test.go b/internal/merger/aggregatemerger/aggregator/count_test.go index 136d49f..81583c3 100644 --- a/internal/merger/aggregatemerger/aggregator/count_test.go +++ b/internal/merger/aggregatemerger/aggregator/count_test.go @@ -17,6 +17,8 @@ package aggregator import ( "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" "github.com/stretchr/testify/assert" @@ -78,7 +80,7 @@ func TestCount_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - count := NewCount(NewColumnInfo(tc.countIndex, "COUNT(id)")) + count := NewCount(merger.NewColumnInfo(tc.countIndex, "COUNT(id)")) val, err := count.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/max.go b/internal/merger/aggregatemerger/aggregator/max.go index 8f6464e..a8757b1 100644 --- a/internal/merger/aggregatemerger/aggregator/max.go +++ b/internal/merger/aggregatemerger/aggregator/max.go @@ -17,11 +17,13 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) type Max struct { - maxColumnInfo ColumnInfo + maxColumnInfo merger.ColumnInfo } func (m *Max) Aggregate(cols [][]any) (any, error) { @@ -49,7 +51,7 @@ func (m *Max) ColumnName() string { return m.maxColumnInfo.Name } -func NewMax(info ColumnInfo) *Max { +func NewMax(info merger.ColumnInfo) *Max { return &Max{ maxColumnInfo: info, } diff --git a/internal/merger/aggregatemerger/aggregator/max_test.go b/internal/merger/aggregatemerger/aggregator/max_test.go index f74dbb7..54071e4 100644 --- a/internal/merger/aggregatemerger/aggregator/max_test.go +++ b/internal/merger/aggregatemerger/aggregator/max_test.go @@ -17,6 +17,8 @@ package aggregator import ( "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" "github.com/stretchr/testify/assert" ) @@ -77,7 +79,7 @@ func TestMax_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - max := NewMax(NewColumnInfo(tc.maxIndex, "MAX(id)")) + max := NewMax(merger.NewColumnInfo(tc.maxIndex, "MAX(id)")) val, err := max.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/min.go b/internal/merger/aggregatemerger/aggregator/min.go index 61bde7a..321a62f 100644 --- a/internal/merger/aggregatemerger/aggregator/min.go +++ b/internal/merger/aggregatemerger/aggregator/min.go @@ -17,11 +17,13 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) type Min struct { - minColumnInfo ColumnInfo + minColumnInfo merger.ColumnInfo } func (m *Min) Aggregate(cols [][]any) (any, error) { @@ -50,7 +52,7 @@ func (m *Min) ColumnName() string { return m.minColumnInfo.Name } -func NewMin(info ColumnInfo) *Min { +func NewMin(info merger.ColumnInfo) *Min { return &Min{ minColumnInfo: info, } diff --git a/internal/merger/aggregatemerger/aggregator/min_test.go b/internal/merger/aggregatemerger/aggregator/min_test.go index 52f40f6..2b21962 100644 --- a/internal/merger/aggregatemerger/aggregator/min_test.go +++ b/internal/merger/aggregatemerger/aggregator/min_test.go @@ -15,6 +15,7 @@ package aggregator import ( + "github.com/ecodeclub/eorm/internal/merger" "testing" "github.com/ecodeclub/eorm/internal/merger/internal/errs" @@ -77,7 +78,7 @@ func TestMin_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - min := NewMin(NewColumnInfo(tc.minIndex, "MIN(id)")) + min := NewMin(merger.NewColumnInfo(tc.minIndex, "MIN(id)")) val, err := min.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/sum.go b/internal/merger/aggregatemerger/aggregator/sum.go index 768494c..b67f122 100644 --- a/internal/merger/aggregatemerger/aggregator/sum.go +++ b/internal/merger/aggregatemerger/aggregator/sum.go @@ -17,11 +17,13 @@ package aggregator import ( "reflect" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/ecodeclub/eorm/internal/merger/internal/errs" ) type Sum struct { - sumColumnInfo ColumnInfo + sumColumnInfo merger.ColumnInfo } func (s *Sum) Aggregate(cols [][]any) (any, error) { @@ -50,7 +52,7 @@ func (s *Sum) ColumnName() string { return s.sumColumnInfo.Name } -func NewSum(info ColumnInfo) *Sum { +func NewSum(info merger.ColumnInfo) *Sum { return &Sum{ sumColumnInfo: info, } diff --git a/internal/merger/aggregatemerger/aggregator/sum_test.go b/internal/merger/aggregatemerger/aggregator/sum_test.go index efa22cf..d2d0fb7 100644 --- a/internal/merger/aggregatemerger/aggregator/sum_test.go +++ b/internal/merger/aggregatemerger/aggregator/sum_test.go @@ -15,6 +15,7 @@ package aggregator import ( + "github.com/ecodeclub/eorm/internal/merger" "testing" "github.com/ecodeclub/eorm/internal/merger/internal/errs" @@ -79,7 +80,7 @@ func TestSum_Aggregate(t *testing.T) { } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - sum := NewSum(NewColumnInfo(tc.sumIndex, "SUM(id)")) + sum := NewSum(merger.NewColumnInfo(tc.sumIndex, "SUM(id)")) val, err := sum.Aggregate(tc.input) assert.Equal(t, tc.wantErr, err) if err != nil { diff --git a/internal/merger/aggregatemerger/aggregator/type.go b/internal/merger/aggregatemerger/aggregator/type.go index 6fb58cd..8a29c1e 100644 --- a/internal/merger/aggregatemerger/aggregator/type.go +++ b/internal/merger/aggregatemerger/aggregator/type.go @@ -24,15 +24,3 @@ type Aggregator interface { // ColumnName 聚合函数的别名 ColumnName() string } - -type ColumnInfo struct { - Index int - Name string -} - -func NewColumnInfo(index int, name string) ColumnInfo { - return ColumnInfo{ - Index: index, - Name: name, - } -} diff --git a/internal/merger/aggregatemerger/merger.go b/internal/merger/aggregatemerger/merger.go index b511283..e1d6e94 100644 --- a/internal/merger/aggregatemerger/merger.go +++ b/internal/merger/aggregatemerger/merger.go @@ -17,7 +17,7 @@ package aggregatemerger import ( "context" "database/sql" - "reflect" + "github.com/ecodeclub/eorm/internal/merger/utils" "sync" _ "unsafe" @@ -27,9 +27,6 @@ import ( "go.uber.org/multierr" ) -//go:linkname convertAssign database/sql.convertAssign -func convertAssign(dest, src any) error - // Merger 该实现不支持group by操作,并且聚合函数查询应该只返回一行数据。 type Merger struct { aggregators []aggregator.Aggregator @@ -151,32 +148,14 @@ func (r *Rows) getSqlRowsData() ([][]any, error) { return rowsData, nil } func (r *Rows) getSqlRowData(row *sql.Rows) ([]any, error) { - colsInfo, err := row.ColumnTypes() - if err != nil { - return nil, err - } - // colsData 表示一个sql.Rows的数据 - colsData := make([]any, 0, len(colsInfo)) + + var colsData []any + var err error if row.Next() { - // 拿到sql.Rows字段的类型然后初始化 - for _, colInfo := range colsInfo { - typ := colInfo.ScanType() - // sqlite3的驱动返回的是指针。循环的去除指针 - for typ.Kind() == reflect.Pointer { - typ = typ.Elem() - } - newData := reflect.New(typ).Interface() - colsData = append(colsData, newData) - } - // 通过Scan赋值 - err = row.Scan(colsData...) + colsData, err = utils.Scan(row) if err != nil { return nil, err } - // 去掉reflect.New的指针 - for i := 0; i < len(colsData); i++ { - colsData[i] = reflect.ValueOf(colsData[i]).Elem().Interface() - } } else { // sql.Rows迭代过程中发生报错,返回报错 if row.Err() != nil { @@ -201,7 +180,7 @@ func (r *Rows) Scan(dest ...any) error { return errs.ErrMergerScanNotNext } for i := 0; i < len(dest); i++ { - err := convertAssign(dest[i], r.cur[i]) + err := utils.ConvertAssign(dest[i], r.cur[i]) if err != nil { return err } diff --git a/internal/merger/aggregatemerger/merger_test.go b/internal/merger/aggregatemerger/merger_test.go index 8407056..ea12b12 100644 --- a/internal/merger/aggregatemerger/merger_test.go +++ b/internal/merger/aggregatemerger/merger_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/DATA-DOG/go-sqlmock" "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" "github.com/ecodeclub/eorm/internal/merger/internal/errs" @@ -141,7 +143,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -170,7 +172,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -200,7 +202,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewMax(aggregator.NewColumnInfo(0, "MAX(id)")), + aggregator.NewMax(merger.NewColumnInfo(0, "MAX(id)")), } }, }, @@ -229,7 +231,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewMin(aggregator.NewColumnInfo(0, "MIN(id)")), + aggregator.NewMin(merger.NewColumnInfo(0, "MIN(id)")), } }, }, @@ -258,7 +260,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(0, "COUNT(id)")), + aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), } }, }, @@ -289,7 +291,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewAVG(aggregator.NewColumnInfo(0, "SUM(grade)"), aggregator.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), } }, }, @@ -323,11 +325,11 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(0, "COUNT(id)")), - aggregator.NewMax(aggregator.NewColumnInfo(1, "MAX(id)")), - aggregator.NewMin(aggregator.NewColumnInfo(2, "MIN(id)")), - aggregator.NewSum(aggregator.NewColumnInfo(3, "SUM(id)")), - aggregator.NewAVG(aggregator.NewColumnInfo(4, "SUM(grade)"), aggregator.NewColumnInfo(5, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), + aggregator.NewMax(merger.NewColumnInfo(1, "MAX(id)")), + aggregator.NewMin(merger.NewColumnInfo(2, "MIN(id)")), + aggregator.NewSum(merger.NewColumnInfo(3, "SUM(id)")), + aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(5, "COUNT(grade)"), "AVG(grade)"), } }, }, @@ -360,13 +362,13 @@ func (ms *MergerSuite) TestRows_NextAndScan() { }(), aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewAVG(aggregator.NewColumnInfo(0, "SUM(grade)"), aggregator.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewSum(aggregator.NewColumnInfo(2, "SUM(grade)")), - aggregator.NewAVG(aggregator.NewColumnInfo(4, "SUM(grade)"), aggregator.NewColumnInfo(3, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewMin(aggregator.NewColumnInfo(5, "MIN(id)")), - aggregator.NewMin(aggregator.NewColumnInfo(6, "MIN(userid)")), - aggregator.NewMax(aggregator.NewColumnInfo(7, "MAX(id)")), - aggregator.NewCount(aggregator.NewColumnInfo(8, "COUNT(id)")), + aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(grade)")), + aggregator.NewAVG(merger.NewColumnInfo(4, "SUM(grade)"), merger.NewColumnInfo(3, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewMin(merger.NewColumnInfo(5, "MIN(id)")), + aggregator.NewMin(merger.NewColumnInfo(6, "MIN(userid)")), + aggregator.NewMax(merger.NewColumnInfo(7, "MAX(id)")), + aggregator.NewCount(merger.NewColumnInfo(8, "COUNT(id)")), } }, }, @@ -401,7 +403,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -433,7 +435,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -465,7 +467,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -491,7 +493,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() { wantErr: errs.ErrMergerAggregateHasEmptyRows, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")), } }, }, @@ -545,7 +547,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() { }, aggregators: func() []aggregator.Aggregator { return []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(0, "COUNT(id)")), + aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")), } }(), wantErr: nextMockErr, @@ -597,7 +599,7 @@ func (ms *MergerSuite) TestRows_Close() { ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(1)) ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2).CloseError(newCloseMockErr("db02"))) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3).CloseError(newCloseMockErr("db03"))) - merger := NewMerger(aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)"))) + merger := NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} rowsList := make([]*sql.Rows, 0, len(dbs)) for _, db := range dbs { @@ -646,11 +648,11 @@ func (ms *MergerSuite) TestRows_Columns() { ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11)) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12)) aggregators := []aggregator.Aggregator{ - aggregator.NewAVG(aggregator.NewColumnInfo(0, "SUM(grade)"), aggregator.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewSum(aggregator.NewColumnInfo(2, "SUM(id)")), - aggregator.NewMin(aggregator.NewColumnInfo(3, "MIN(id)")), - aggregator.NewMax(aggregator.NewColumnInfo(4, "MAX(id)")), - aggregator.NewCount(aggregator.NewColumnInfo(5, "COUNT(id)")), + aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), + aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")), + aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")), + aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")), } merger := NewMerger(aggregators...) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} @@ -691,7 +693,7 @@ func (ms *MergerSuite) TestMerger_Merge() { { name: "超时", merger: func() *Merger { - return NewMerger(aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithTimeout(context.Background(), 0) @@ -711,7 +713,7 @@ func (ms *MergerSuite) TestMerger_Merge() { { name: "sqlRows列表元素个数为0", merger: func() *Merger { - return NewMerger(aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) @@ -725,7 +727,7 @@ func (ms *MergerSuite) TestMerger_Merge() { { name: "sqlRows列表有nil", merger: func() *Merger { - return NewMerger(aggregator.NewSum(aggregator.NewColumnInfo(0, "SUM(id)"))) + return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)"))) }, ctx: func() (context.Context, context.CancelFunc) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/merger/groupbyMerger/aggregatorMerger.go b/internal/merger/groupby_merger/aggregator_merger.go similarity index 80% rename from internal/merger/groupbyMerger/aggregatorMerger.go rename to internal/merger/groupby_merger/aggregator_merger.go index 2ce7de6..1d60e87 100644 --- a/internal/merger/groupbyMerger/aggregatorMerger.go +++ b/internal/merger/groupby_merger/aggregator_merger.go @@ -1,4 +1,18 @@ -package groupbyMerger +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package groupby_merger import ( "context" @@ -7,34 +21,22 @@ import ( "sync" _ "unsafe" + "github.com/ecodeclub/eorm/internal/merger/utils" + "go.uber.org/multierr" + "github.com/ecodeclub/eorm/internal/merger" "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" "github.com/ecodeclub/eorm/internal/merger/internal/errs" "github.com/gotomicro/ekit/mapx" ) -//go:linkname convertAssign database/sql.convertAssign -func convertAssign(dest, src any) error - -type GroupByColumn struct { - Index int - Name string -} - -func NewGroupByColumn(index int, name string) GroupByColumn { - return GroupByColumn{ - Index: index, - Name: name, - } -} - type AggregatorMerger struct { aggregators []aggregator.Aggregator - groupColumns []GroupByColumn + groupColumns []merger.ColumnInfo columnsName []string } -func NewAggregatorMerger(aggregators []aggregator.Aggregator, groupColumns []GroupByColumn) *AggregatorMerger { +func NewAggregatorMerger(aggregators []aggregator.Aggregator, groupColumns []merger.ColumnInfo) *AggregatorMerger { cols := make([]string, 0, len(aggregators)+len(groupColumns)) for _, groubyCol := range groupColumns { cols = append(cols, groubyCol.Name) @@ -49,6 +51,8 @@ func NewAggregatorMerger(aggregators []aggregator.Aggregator, groupColumns []Gro columnsName: cols, } } + +// Merge 该实现会全部拿取results里面的数据,由于sql.Rows数据拿完之后会自动关闭,所以这边隐式的关闭了所有的sql.Rows func (a *AggregatorMerger) Merge(ctx context.Context, results []*sql.Rows) (merger.Rows, error) { if ctx.Err() != nil { return nil, ctx.Err() @@ -80,7 +84,7 @@ func (a *AggregatorMerger) Merge(ctx context.Context, results []*sql.Rows) (merg }, nil } -func (m *AggregatorMerger) checkColumns(rows *sql.Rows) error { +func (a *AggregatorMerger) checkColumns(rows *sql.Rows) error { if rows == nil { return errs.ErrMergerRowsIsNull } @@ -92,7 +96,7 @@ func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][ if err != nil { return nil, nil, err } - keys := make([]Key, 0) + keys := make([]Key, 0, 16) for _, res := range rowsList { colsData, err := a.getCol(res) if err != nil { @@ -124,33 +128,12 @@ func (a *AggregatorMerger) getCols(rowsList []*sql.Rows) (*mapx.TreeMap[Key, [][ } func (a *AggregatorMerger) getCol(row *sql.Rows) ([][]any, error) { - colsInfo, err := row.ColumnTypes() - if err != nil { - return nil, err - } - // colsData 表示一个sql.Rows的数据 - ans := make([][]any, 0) + ans := make([][]any, 0, 16) for row.Next() { - colsData := make([]any, 0, len(colsInfo)) - // 拿到sql.Rows字段的类型然后初始化 - for _, colInfo := range colsInfo { - typ := colInfo.ScanType() - // sqlite3的驱动返回的是指针。循环的去除指针 - for typ.Kind() == reflect.Pointer { - typ = typ.Elem() - } - newData := reflect.New(typ).Interface() - colsData = append(colsData, newData) - } - // 通过Scan赋值 - err = row.Scan(colsData...) + colsData, err := utils.Scan(row) if err != nil { return nil, err } - // 去掉reflect.New的指针 - for i := 0; i < len(colsData); i++ { - colsData[i] = reflect.ValueOf(colsData[i]).Elem().Interface() - } ans = append(ans, colsData) } if row.Err() != nil { @@ -164,7 +147,7 @@ func (a *AggregatorMerger) getCol(row *sql.Rows) ([][]any, error) { type AggregatorRows struct { rowsList []*sql.Rows aggregators []aggregator.Aggregator - groupColumns []GroupByColumn + groupColumns []merger.ColumnInfo dataMap *mapx.TreeMap[Key, [][]any] cur int dataIndex []Key @@ -221,7 +204,7 @@ func (a *AggregatorRows) Scan(dest ...any) error { return errs.ErrMergerScanNotNext } for i := 0; i < len(dest); i++ { - err := convertAssign(dest[i], a.curData[i]) + err := utils.ConvertAssign(dest[i], a.curData[i]) if err != nil { return err } @@ -234,7 +217,15 @@ func (a *AggregatorRows) Close() error { a.mu.Lock() defer a.mu.Unlock() a.closed = true - return nil + errorList := make([]error, 0, len(a.rowsList)) + for i := 0; i < len(a.rowsList); i++ { + row := a.rowsList[i] + err := row.Close() + if err != nil { + errorList = append(errorList, err) + } + } + return multierr.Combine(errorList...) } // Columns 返回列的顺序先分组信息然后是聚合函数信息 diff --git a/internal/merger/groupbyMerger/aggregatorMerger_test.go b/internal/merger/groupby_merger/aggregator_merger_test.go similarity index 85% rename from internal/merger/groupbyMerger/aggregatorMerger_test.go rename to internal/merger/groupby_merger/aggregator_merger_test.go index 75ffeeb..1f63bd3 100644 --- a/internal/merger/groupbyMerger/aggregatorMerger_test.go +++ b/internal/merger/groupby_merger/aggregator_merger_test.go @@ -1,4 +1,18 @@ -package groupbyMerger +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package groupby_merger import ( "context" @@ -6,6 +20,8 @@ import ( "errors" "testing" + "github.com/ecodeclub/eorm/internal/merger" + "github.com/DATA-DOG/go-sqlmock" "github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator" "github.com/ecodeclub/eorm/internal/merger/internal/errs" @@ -72,18 +88,18 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { name string aggregators []aggregator.Aggregator rowsList []*sql.Rows - GroupByColumns []GroupByColumn + GroupByColumns []merger.ColumnInfo wantErr error ctx func() (context.Context, context.CancelFunc) }{ { name: "正常案例", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(2, "id")), + aggregator.NewCount(merger.NewColumnInfo(2, "id")), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "county"), - NewGroupByColumn(1, "gender"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "county"), + merger.NewColumnInfo(1, "gender"), }, rowsList: func() []*sql.Rows { query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" @@ -109,10 +125,10 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "超时", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(1, "id")), + aggregator.NewCount(merger.NewColumnInfo(1, "id")), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "user_name"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), }, rowsList: func() []*sql.Rows { query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" @@ -138,10 +154,10 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "rowsList为空", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(1, "id")), + aggregator.NewCount(merger.NewColumnInfo(1, "id")), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "user_name"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), }, rowsList: func() []*sql.Rows { return []*sql.Rows{} @@ -155,10 +171,10 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "rowsList中有nil", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(1, "id")), + aggregator.NewCount(merger.NewColumnInfo(1, "id")), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "user_name"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), }, rowsList: func() []*sql.Rows { return []*sql.Rows{nil} @@ -172,10 +188,10 @@ func (ms *MergerSuite) TestAggregatorMerger_Merge() { { name: "rowsList中有sql.Rows返回错误", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(1, "id")), + aggregator.NewCount(merger.NewColumnInfo(1, "id")), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "user_name"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), }, rowsList: func() []*sql.Rows { query := "SELECT `user_name`,SUM(`id`) FROM `t1` GROUP BY `user_name`" @@ -221,16 +237,16 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { rowsList []*sql.Rows wantVal [][]any gotVal [][]any - GroupByColumns []GroupByColumn + GroupByColumns []merger.ColumnInfo wantErr error }{ { name: "同一组数据在不同的sql.Rows中", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(1, "COUNT(id)")), + aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "user_name"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), }, rowsList: func() []*sql.Rows { query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" @@ -261,10 +277,10 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { { name: "同一组数据在同一个sql.Rows中", aggregators: []aggregator.Aggregator{ - aggregator.NewCount(aggregator.NewColumnInfo(1, "COUNT(id)")), + aggregator.NewCount(merger.NewColumnInfo(1, "COUNT(id)")), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "user_name"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "user_name"), }, rowsList: func() []*sql.Rows { query := "SELECT `user_name`,COUNT(`id`) FROM `t1` GROUP BY `user_name`" @@ -299,11 +315,11 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { { name: "多个分组列", aggregators: []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(2, "SUM(id)")), + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "county"), - NewGroupByColumn(1, "gender"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "county"), + merger.NewColumnInfo(1, "gender"), }, rowsList: func() []*sql.Rows { query := "SELECT `county`,`gender`,SUM(`id`) FROM `t1` GROUP BY `country`,`gender`" @@ -352,12 +368,12 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndScan() { { name: "多个聚合函数", aggregators: []aggregator.Aggregator{ - aggregator.NewSum(aggregator.NewColumnInfo(2, "SUM(id)")), - aggregator.NewAVG(aggregator.NewColumnInfo(3, "SUM(age)"), aggregator.NewColumnInfo(4, "COUNT(age)"), "AVG(age)"), + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), + aggregator.NewAVG(merger.NewColumnInfo(3, "SUM(age)"), merger.NewColumnInfo(4, "COUNT(age)"), "AVG(age)"), }, - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "county"), - NewGroupByColumn(1, "gender"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "county"), + merger.NewColumnInfo(1, "gender"), }, rowsList: func() []*sql.Rows { @@ -442,7 +458,7 @@ func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() { r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) rowsList := []*sql.Rows{r} - merger := NewAggregatorMerger([]aggregator.Aggregator{aggregator.NewSum(aggregator.NewColumnInfo(1, "SUM(id)"))}, []GroupByColumn{NewGroupByColumn(0, "userid")}) + merger := NewAggregatorMerger([]aggregator.Aggregator{aggregator.NewSum(merger.NewColumnInfo(1, "SUM(id)"))}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) rows, err := merger.Merge(context.Background(), rowsList) require.NoError(t, err) userid := 0 @@ -457,7 +473,7 @@ func (ms *MergerSuite) TestAggregatorRows_ScanAndErr() { r, err := ms.mockDB01.QueryContext(context.Background(), query) require.NoError(t, err) rowsList := []*sql.Rows{r} - merger := NewAggregatorMerger([]aggregator.Aggregator{&mockAggregate{}}, []GroupByColumn{NewGroupByColumn(0, "userid")}) + merger := NewAggregatorMerger([]aggregator.Aggregator{&mockAggregate{}}, []merger.ColumnInfo{merger.NewColumnInfo(0, "userid")}) rows, err := merger.Merge(context.Background(), rowsList) require.NoError(t, err) userid := 0 @@ -475,7 +491,7 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndErr() { rowsList func() []*sql.Rows wantErr error aggregators []aggregator.Aggregator - GroupByColumns []GroupByColumn + GroupByColumns []merger.ColumnInfo }{ { name: "有一个aggregator返回error", @@ -500,8 +516,8 @@ func (ms *MergerSuite) TestAggregatorRows_NextAndErr() { &mockAggregate{}, } }(), - GroupByColumns: []GroupByColumn{ - NewGroupByColumn(0, "username"), + GroupByColumns: []merger.ColumnInfo{ + merger.NewColumnInfo(0, "username"), }, wantErr: aggregatorErr, }, @@ -529,14 +545,14 @@ func (ms *MergerSuite) TestAggregatorRows_Columns() { ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(2, 1, 3, 2, 4, 11, "dm")) ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow(3, 1, 4, 3, 5, 12, "xm")) aggregators := []aggregator.Aggregator{ - aggregator.NewAVG(aggregator.NewColumnInfo(0, "SUM(grade)"), aggregator.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), - aggregator.NewSum(aggregator.NewColumnInfo(2, "SUM(id)")), - aggregator.NewMin(aggregator.NewColumnInfo(3, "MIN(id)")), - aggregator.NewMax(aggregator.NewColumnInfo(4, "MAX(id)")), - aggregator.NewCount(aggregator.NewColumnInfo(5, "COUNT(id)")), + aggregator.NewAVG(merger.NewColumnInfo(0, "SUM(grade)"), merger.NewColumnInfo(1, "COUNT(grade)"), "AVG(grade)"), + aggregator.NewSum(merger.NewColumnInfo(2, "SUM(id)")), + aggregator.NewMin(merger.NewColumnInfo(3, "MIN(id)")), + aggregator.NewMax(merger.NewColumnInfo(4, "MAX(id)")), + aggregator.NewCount(merger.NewColumnInfo(5, "COUNT(id)")), } - groupbyColumns := []GroupByColumn{ - NewGroupByColumn(6, "userid"), + groupbyColumns := []merger.ColumnInfo{ + merger.NewColumnInfo(6, "userid"), } merger := NewAggregatorMerger(aggregators, groupbyColumns) dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03} diff --git a/internal/merger/sortmerger/merger.go b/internal/merger/sortmerger/merger.go index 5e2497f..96b2295 100644 --- a/internal/merger/sortmerger/merger.go +++ b/internal/merger/sortmerger/merger.go @@ -20,7 +20,8 @@ import ( "database/sql" "reflect" "sync" - _ "unsafe" + + "github.com/ecodeclub/eorm/internal/merger/utils" "go.uber.org/multierr" @@ -41,9 +42,6 @@ type Ordered interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~float32 | ~float64 | ~string } -//go:linkname convertAssign database/sql.convertAssign -func convertAssign(dest, src any) error - type SortColumn struct { name string order Order @@ -285,7 +283,7 @@ func (r *Rows) Scan(dest ...any) error { } for i := 0; i < len(dest); i++ { - err := convertAssign(dest[i], r.cur.columns[i]) + err := utils.ConvertAssign(dest[i], r.cur.columns[i]) if err != nil { return err } diff --git a/internal/merger/type.go b/internal/merger/type.go index 7c37fe8..32955fb 100644 --- a/internal/merger/type.go +++ b/internal/merger/type.go @@ -33,3 +33,15 @@ type Rows interface { Columns() ([]string, error) Err() error } + +type ColumnInfo struct { + Index int + Name string +} + +func NewColumnInfo(index int, name string) ColumnInfo { + return ColumnInfo{ + Index: index, + Name: name, + } +} diff --git a/internal/merger/utils/convert_Assign.go b/internal/merger/utils/convert_Assign.go new file mode 100644 index 0000000..7b4e4ba --- /dev/null +++ b/internal/merger/utils/convert_Assign.go @@ -0,0 +1,13 @@ +package utils + +import ( + _ "database/sql" + _ "unsafe" +) + +//go:linkname convertAssign database/sql.convertAssign +func convertAssign(dest, src any) error + +func ConvertAssign(dest, src any) error { + return convertAssign(dest, src) +} diff --git a/internal/merger/utils/scan.go b/internal/merger/utils/scan.go new file mode 100644 index 0000000..40c7407 --- /dev/null +++ b/internal/merger/utils/scan.go @@ -0,0 +1,34 @@ +package utils + +import ( + "database/sql" + "reflect" +) + +func Scan(row *sql.Rows) ([]any, error) { + colsInfo, err := row.ColumnTypes() + if err != nil { + return nil, err + } + colsData := make([]any, 0, len(colsInfo)) + // 拿到sql.Rows字段的类型然后初始化 + for _, colInfo := range colsInfo { + typ := colInfo.ScanType() + // sqlite3的驱动返回的是指针。循环的去除指针 + for typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + newData := reflect.New(typ).Interface() + colsData = append(colsData, newData) + } + // 通过Scan赋值 + err = row.Scan(colsData...) + if err != nil { + return nil, err + } + // 去掉reflect.New的指针 + for i := 0; i < len(colsData); i++ { + colsData[i] = reflect.ValueOf(colsData[i]).Elem().Interface() + } + return colsData, nil +}