Skip to content

Commit

Permalink
完成ColumnType重构及测试, 去掉NewColumnInfo方法直接使用字面量
Browse files Browse the repository at this point in the history
  • Loading branch information
longyue0521 committed Jun 10, 2024
1 parent d3350ca commit b9dba8d
Show file tree
Hide file tree
Showing 13 changed files with 480 additions and 248 deletions.
8 changes: 4 additions & 4 deletions internal/merger/factory/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (q QuerySpec) validateGroupBy() error {
return fmt.Errorf("%w: groupby %v", ErrInvalidColumnInfo, c.Name)
}
// 清除ASC
c.ASC = false
c.Order = merger.DESC
if !slice.Contains(q.Select, c) {
return fmt.Errorf("%w: groupby %v", ErrColumnNotFoundInSelectList, c.Name)
}
Expand Down Expand Up @@ -140,7 +140,7 @@ func (q QuerySpec) validateOrderBy() error {
return fmt.Errorf("%w: orderby %v", ErrInvalidColumnInfo, c.Name)
}
// 清除ASC
c.ASC = false
c.Order = merger.DESC
if !slice.Contains(q.Select, c) {
return fmt.Errorf("%w: orderby %v", ErrColumnNotFoundInSelectList, c.Name)
}
Expand Down Expand Up @@ -207,12 +207,12 @@ func newOrderByMerger(origin, target QuerySpec) (merger.Merger, error) {
for i := 0; i < len(target.OrderBy); i++ {
c := target.OrderBy[i]
if i < len(origin.OrderBy) && strings.ToUpper(origin.OrderBy[i].AggregateFunc) == "AVG" {
s := sortmerger.NewSortColumn(origin.OrderBy[i].SelectName(), sortmerger.Order(origin.OrderBy[i].ASC))
s := sortmerger.NewSortColumn(origin.OrderBy[i].SelectName(), sortmerger.Order(origin.OrderBy[i].Order))
columns = append(columns, s)
i++
continue
}
s := sortmerger.NewSortColumn(c.SelectName(), sortmerger.Order(c.ASC))
s := sortmerger.NewSortColumn(c.SelectName(), sortmerger.Order(c.Order))
columns = append(columns, s)
}

Expand Down
40 changes: 20 additions & 20 deletions internal/merger/factory/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func TestNew(t *testing.T) {
{
Index: 0, // 索引排序? amount没有出现在SELECT子句,出现在orderBy子句中
Name: "amount",
ASC: true,
Order: merger.ASC,
},
},
},
Expand Down Expand Up @@ -748,7 +748,7 @@ func (s *factoryTestSuite) TestSELECT() {
{
Index: 0,
Name: "`ctime`",
ASC: true,
Order: merger.ASC,
},
},
},
Expand All @@ -764,7 +764,7 @@ func (s *factoryTestSuite) TestSELECT() {
{
Index: 0,
Name: "`ctime`",
ASC: true,
Order: merger.ASC,
},
},
},
Expand Down Expand Up @@ -803,13 +803,13 @@ func (s *factoryTestSuite) TestSELECT() {
Index: 0,
Name: "`user_id`",
Alias: "`uid`",
ASC: true,
Order: merger.ASC,
},
{
Index: 1,
Name: "`order_id`",
Alias: "`oid`",
ASC: false,
Order: merger.DESC,
},
},
},
Expand All @@ -832,13 +832,13 @@ func (s *factoryTestSuite) TestSELECT() {
Index: 0,
Name: "`user_id`",
Alias: "`uid`",
ASC: true,
Order: merger.ASC,
},
{
Index: 1,
Name: "`order_id`",
Alias: "`oid`",
ASC: false,
Order: merger.DESC,
},
},
},
Expand Down Expand Up @@ -900,7 +900,7 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "AVG",
Alias: "`avg_amt`",
ASC: true,
Order: true,
},
},
},
Expand Down Expand Up @@ -985,7 +985,7 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "COUNT",
Alias: "`cnt_amt`",
ASC: true,
Order: true,
},
},
},
Expand All @@ -1005,7 +1005,7 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "COUNT",
Alias: "`cnt_amt`",
ASC: true,
Order: true,
},
},
},
Expand Down Expand Up @@ -1077,7 +1077,7 @@ func (s *factoryTestSuite) TestSELECT() {
{
Index: 1,
Name: "`ctime`",
ASC: true,
Order: merger.ASC,
},
},
},
Expand All @@ -1093,7 +1093,7 @@ func (s *factoryTestSuite) TestSELECT() {
{
Index: 1,
Name: "`ctime`",
ASC: true,
Order: merger.ASC,
},
},
},
Expand Down Expand Up @@ -1588,13 +1588,13 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "SUM",
Alias: "`total_amt`",
ASC: true,
Order: true,
},
{
Index: 0,
Name: "`user_id`",
Alias: "`uid`",
ASC: false,
Order: merger.DESC,
},
},
},
Expand Down Expand Up @@ -1636,13 +1636,13 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "SUM",
Alias: "`total_amt`",
ASC: true,
Order: true,
},
{
Index: 0,
Name: "`user_id`",
Alias: "`uid`",
ASC: false,
Order: merger.DESC,
},
},
},
Expand Down Expand Up @@ -1842,7 +1842,7 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "SUM",
Alias: "`total_amt`",
ASC: false,
Order: false,
},
},
Limit: 2,
Expand Down Expand Up @@ -1876,7 +1876,7 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "SUM",
Alias: "`total_amt`",
ASC: false,
Order: false,
},
},
Limit: 2,
Expand Down Expand Up @@ -1955,7 +1955,7 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "SUM",
Alias: "`total_amt`",
ASC: true,
Order: true,
},
},
Limit: 6,
Expand Down Expand Up @@ -1999,7 +1999,7 @@ func (s *factoryTestSuite) TestSELECT() {
Name: "`amount`",
AggregateFunc: "SUM",
Alias: "`total_amt`",
ASC: true,
Order: true,
},
},
Limit: 6,
Expand Down
4 changes: 2 additions & 2 deletions internal/merger/internal/aggregatemerger/merger.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ type Rows struct {
}

func (r *Rows) ColumnTypes() ([]*sql.ColumnType, error) {
r.mu.RLock()
defer r.mu.RUnlock()
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return nil, fmt.Errorf("%w", errs.ErrMergerRowsClosed)
}
Expand Down
54 changes: 27 additions & 27 deletions internal/merger/internal/aggregatemerger/merger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
}(),
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}),
}
},
},
Expand Down Expand Up @@ -179,7 +179,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
}(),
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}),
}
},
},
Expand Down Expand Up @@ -209,7 +209,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
}(),
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewMax(merger.NewColumnInfo(0, "MAX(id)")),
aggregator.NewMax(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "MAX"}),
}
},
},
Expand Down Expand Up @@ -238,7 +238,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
}(),
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewMin(merger.NewColumnInfo(0, "MIN(id)")),
aggregator.NewMin(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "MIN"}),
}
},
},
Expand Down Expand Up @@ -267,7 +267,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
}(),
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")),
aggregator.NewCount(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "COUNT"}),
}
},
},
Expand Down Expand Up @@ -336,10 +336,10 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
}(),
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")),
aggregator.NewMax(merger.NewColumnInfo(1, "MAX(id)")),
aggregator.NewMin(merger.NewColumnInfo(2, "MIN(id)")),
aggregator.NewSum(merger.NewColumnInfo(3, "SUM(id)")),
aggregator.NewCount(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "COUNT"}),
aggregator.NewMax(merger.ColumnInfo{Index: 1, Name: "id", AggregateFunc: "MAX"}),
aggregator.NewMin(merger.ColumnInfo{Index: 2, Name: "id", AggregateFunc: "MIN"}),
aggregator.NewSum(merger.ColumnInfo{Index: 3, Name: "id", AggregateFunc: "SUM"}),
aggregator.NewAVG(
merger.ColumnInfo{Index: 4, Name: `grade`, AggregateFunc: "AVG"},
merger.ColumnInfo{Index: 5, Name: `grade`, AggregateFunc: "SUM"},
Expand Down Expand Up @@ -382,16 +382,16 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
merger.ColumnInfo{Index: 1, Name: `grade`, AggregateFunc: "SUM"},
merger.ColumnInfo{Index: 2, Name: `grade`, AggregateFunc: "COUNT"},
),
aggregator.NewSum(merger.NewColumnInfo(3, "SUM(grade)")),
aggregator.NewSum(merger.ColumnInfo{Index: 3, Name: "grade", AggregateFunc: "SUM"}),
aggregator.NewAVG(
merger.ColumnInfo{Index: 4, Name: `grade`, AggregateFunc: "AVG"},
merger.ColumnInfo{Index: 6, Name: `grade`, AggregateFunc: "SUM"},
merger.ColumnInfo{Index: 5, Name: `grade`, AggregateFunc: "COUNT"},
),
aggregator.NewMin(merger.NewColumnInfo(7, "MIN(id)")),
aggregator.NewMin(merger.NewColumnInfo(8, "MIN(userid)")),
aggregator.NewMax(merger.NewColumnInfo(9, "MAX(id)")),
aggregator.NewCount(merger.NewColumnInfo(10, "COUNT(id)")),
aggregator.NewMin(merger.ColumnInfo{Index: 7, Name: "id", AggregateFunc: "MIN"}),
aggregator.NewMin(merger.ColumnInfo{Index: 8, Name: "userid", AggregateFunc: "MIN"}),
aggregator.NewMax(merger.ColumnInfo{Index: 9, Name: "id", AggregateFunc: "MAX"}),
aggregator.NewCount(merger.ColumnInfo{Index: 10, Name: "id", AggregateFunc: "COUNT"}),
}
},
},
Expand Down Expand Up @@ -426,7 +426,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
wantErr: errs.ErrMergerAggregateHasEmptyRows,
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}),
}
},
},
Expand Down Expand Up @@ -458,7 +458,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
wantErr: errs.ErrMergerAggregateHasEmptyRows,
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}),
}
},
},
Expand Down Expand Up @@ -490,7 +490,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
wantErr: errs.ErrMergerAggregateHasEmptyRows,
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}),
}
},
},
Expand All @@ -516,7 +516,7 @@ func (ms *MergerSuite) TestRows_NextAndScan() {
wantErr: errs.ErrMergerAggregateHasEmptyRows,
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")),
aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}),
}
},
},
Expand Down Expand Up @@ -570,7 +570,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() {
},
aggregators: func() []aggregator.Aggregator {
return []aggregator.Aggregator{
aggregator.NewCount(merger.NewColumnInfo(0, "COUNT(id)")),
aggregator.NewCount(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "COUNT"}),
}
}(),
wantErr: nextMockErr,
Expand Down Expand Up @@ -622,7 +622,7 @@ func (ms *MergerSuite) TestRows_Close() {
ms.mock01.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(1))
ms.mock02.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(2).CloseError(newCloseMockErr("db02")))
ms.mock03.ExpectQuery(targetSQL).WillReturnRows(sqlmock.NewRows(cols).AddRow(3).CloseError(newCloseMockErr("db03")))
m := NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")))
m := NewMerger(aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}))
dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
rowsList := make([]rows.Rows, 0, len(dbs))
for _, db := range dbs {
Expand Down Expand Up @@ -676,10 +676,10 @@ func (ms *MergerSuite) TestRows_Columns() {
merger.ColumnInfo{Index: 1, Name: `grade`, AggregateFunc: "SUM"},
merger.ColumnInfo{Index: 2, Name: `grade`, AggregateFunc: "COUNT"},
),
aggregator.NewSum(merger.NewColumnInfo(3, "SUM(id)")),
aggregator.NewMin(merger.NewColumnInfo(4, "MIN(id)")),
aggregator.NewMax(merger.NewColumnInfo(5, "MAX(id)")),
aggregator.NewCount(merger.NewColumnInfo(6, "COUNT(id)")),
aggregator.NewSum(merger.ColumnInfo{Index: 3, Name: "id", AggregateFunc: "SUM"}),
aggregator.NewMin(merger.ColumnInfo{Index: 4, Name: "id", AggregateFunc: "MIN"}),
aggregator.NewMax(merger.ColumnInfo{Index: 5, Name: "id", AggregateFunc: "MAX"}),
aggregator.NewCount(merger.ColumnInfo{Index: 6, Name: "id", AggregateFunc: "COUNT"}),
}
m := NewMerger(aggregators...)
dbs := []*sql.DB{ms.mockDB01, ms.mockDB02, ms.mockDB03}
Expand Down Expand Up @@ -720,7 +720,7 @@ func (ms *MergerSuite) TestMerger_Merge() {
{
name: "超时",
merger: func() *Merger {
return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")))
return NewMerger(aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}))
},
ctx: func() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(context.Background(), 0)
Expand All @@ -740,7 +740,7 @@ func (ms *MergerSuite) TestMerger_Merge() {
{
name: "sqlRows列表元素个数为0",
merger: func() *Merger {
return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")))
return NewMerger(aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}))
},
ctx: func() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -754,7 +754,7 @@ func (ms *MergerSuite) TestMerger_Merge() {
{
name: "sqlRows列表有nil",
merger: func() *Merger {
return NewMerger(aggregator.NewSum(merger.NewColumnInfo(0, "SUM(id)")))
return NewMerger(aggregator.NewSum(merger.ColumnInfo{Index: 0, Name: "id", AggregateFunc: "SUM"}))
},
ctx: func() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
Expand Down
Loading

0 comments on commit b9dba8d

Please sign in to comment.