Skip to content

Commit

Permalink
sortmerge去除传入类型信息
Browse files Browse the repository at this point in the history
  • Loading branch information
juniaoshaonian committed Apr 15, 2023
1 parent 7b1fd2b commit aae597a
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 68 deletions.
28 changes: 14 additions & 14 deletions internal/merger/pagedmerger/merger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (ms *MergerSuite) TestMerger_New() {
}
for _, tc := range testcases {
ms.T().Run(tc.name, func(t *testing.T) {
m, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
m, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
require.NoError(t, err)
limitMerger, err := NewMerger(m, tc.offset, tc.limit)
assert.Equal(t, tc.wantErr, err)
Expand All @@ -141,7 +141,7 @@ func (ms *MergerSuite) TestMerger_Merge() {
{
name: "limitMerger里的Merger的Merge出错",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
return []*sql.Rows{}
Expand All @@ -156,7 +156,7 @@ func (ms *MergerSuite) TestMerger_Merge() {
{
name: "初始化游标出错",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
cols := []string{"id", "name", "address"}
Expand All @@ -183,7 +183,7 @@ func (ms *MergerSuite) TestMerger_Merge() {
{
name: "offset的值超过返回的数据行数",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
cols := []string{"id", "name", "address"}
Expand All @@ -209,7 +209,7 @@ func (ms *MergerSuite) TestMerger_Merge() {
{
name: "超时",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
cols := []string{"id", "name", "address"}
Expand Down Expand Up @@ -266,7 +266,7 @@ func (ms *MergerSuite) TestMerger_NextAndScan() {
{
name: "limit的行数超过了返回的总行数,",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
cols := []string{"id", "name", "address"}
Expand Down Expand Up @@ -316,7 +316,7 @@ func (ms *MergerSuite) TestMerger_NextAndScan() {
{
name: "limit 行数小于返回的总行数",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
cols := []string{"id", "name", "address"}
Expand Down Expand Up @@ -351,7 +351,7 @@ func (ms *MergerSuite) TestMerger_NextAndScan() {
{
name: "offset超过sqlRows列表返回的总行数",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
cols := []string{"id", "name", "address"}
Expand All @@ -375,7 +375,7 @@ func (ms *MergerSuite) TestMerger_NextAndScan() {
{
name: "offset 的值为0",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
cols := []string{"id", "name", "address"}
Expand Down Expand Up @@ -462,7 +462,7 @@ func (ms *MergerSuite) TestRows_NextAndErr() {
{
name: "有sql.Rows返回错误",
getMerger: func() (merger.Merger, error) {
return sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
return sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
},
GetRowsList: func() []*sql.Rows {
cols := []string{"id", "name", "address"}
Expand Down Expand Up @@ -508,7 +508,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() {
r, err := ms.mockDB01.QueryContext(context.Background(), query)
require.NoError(t, err)
rowsList := []*sql.Rows{r}
merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
require.NoError(t, err)
limitMerger, err := NewMerger(merger, 0, 1)
require.NoError(t, err)
Expand All @@ -525,7 +525,7 @@ func (ms *MergerSuite) TestRows_ScanAndErr() {
r, err := ms.mockDB01.QueryContext(context.Background(), query)
require.NoError(t, err)
rowsList := []*sql.Rows{r}
merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
require.NoError(t, err)
limitMerger, err := NewMerger(merger, 0, 1)
require.NoError(t, err)
Expand All @@ -545,7 +545,7 @@ func (ms *MergerSuite) TestRows_Close() {
ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1"))
ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2").AddRow("5").CloseError(newCloseMockErr("db02")))
ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4").CloseError(newCloseMockErr("db03")))
merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
require.NoError(ms.T(), err)
limitMerger, err := NewMerger(merger, 1, 6)
require.NoError(ms.T(), err)
Expand Down Expand Up @@ -596,7 +596,7 @@ func (ms *MergerSuite) TestRows_Columns() {
ms.mock01.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("1"))
ms.mock02.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("2"))
ms.mock03.ExpectQuery("SELECT *").WillReturnRows(sqlmock.NewRows(cols).AddRow("3").AddRow("4"))
merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn[int]("id", sortmerger.ASC))
merger, err := sortmerger.NewMerger(sortmerger.NewSortColumn("id", sortmerger.ASC))
require.NoError(ms.T(), err)
limitMerger, err := NewMerger(merger, 0, 10)
require.NoError(ms.T(), err)
Expand Down
16 changes: 8 additions & 8 deletions internal/merger/sortmerger/heap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestHeap(t *testing.T) {
})
},
sortCols: func() sortColumns {
sortCols, err := newSortColumns(NewSortColumn[int]("id", ASC))
sortCols, err := newSortColumns(NewSortColumn("id", ASC))
require.NoError(t, err)
return sortCols
},
Expand All @@ -98,7 +98,7 @@ func TestHeap(t *testing.T) {
})
},
sortCols: func() sortColumns {
sortCols, err := newSortColumns(NewSortColumn[int]("id", DESC))
sortCols, err := newSortColumns(NewSortColumn("id", DESC))
require.NoError(t, err)
return sortCols
},
Expand Down Expand Up @@ -144,7 +144,7 @@ func TestHeap(t *testing.T) {
})
},
sortCols: func() sortColumns {
sortCols, err := newSortColumns(NewSortColumn[int]("id", ASC), NewSortColumn[string]("name", DESC), NewSortColumn[int]("age", ASC))
sortCols, err := newSortColumns(NewSortColumn("id", ASC), NewSortColumn("name", DESC), NewSortColumn("age", ASC))
require.NoError(t, err)
return sortCols
},
Expand Down Expand Up @@ -190,7 +190,7 @@ func TestHeap(t *testing.T) {
})
},
sortCols: func() sortColumns {
sortCols, err := newSortColumns(NewSortColumn[int]("id", DESC), NewSortColumn[string]("name", ASC), NewSortColumn[int]("age", DESC))
sortCols, err := newSortColumns(NewSortColumn("id", DESC), NewSortColumn("name", ASC), NewSortColumn("age", DESC))
require.NoError(t, err)
return sortCols
},
Expand Down Expand Up @@ -236,7 +236,7 @@ func TestHeap(t *testing.T) {
})
},
sortCols: func() sortColumns {
sortCols, err := newSortColumns(NewSortColumn[int]("id", ASC), NewSortColumn[string]("name", ASC), NewSortColumn[int]("age", DESC))
sortCols, err := newSortColumns(NewSortColumn("id", ASC), NewSortColumn("name", ASC), NewSortColumn("age", DESC))
require.NoError(t, err)
return sortCols
},
Expand Down Expand Up @@ -282,7 +282,7 @@ func TestHeap(t *testing.T) {
})
},
sortCols: func() sortColumns {
sortCols, err := newSortColumns(NewSortColumn[int]("id", DESC), NewSortColumn[string]("name", DESC), NewSortColumn[int]("age", ASC))
sortCols, err := newSortColumns(NewSortColumn("id", DESC), NewSortColumn("name", DESC), NewSortColumn("age", ASC))
require.NoError(t, err)
return sortCols
},
Expand Down Expand Up @@ -328,7 +328,7 @@ func TestHeap(t *testing.T) {
})
},
sortCols: func() sortColumns {
sortCols, err := newSortColumns(NewSortColumn[int]("id", DESC), NewSortColumn[string]("name", DESC), NewSortColumn[int]("age", DESC))
sortCols, err := newSortColumns(NewSortColumn("id", DESC), NewSortColumn("name", DESC), NewSortColumn("age", DESC))
require.NoError(t, err)
return sortCols
},
Expand Down Expand Up @@ -374,7 +374,7 @@ func TestHeap(t *testing.T) {
})
},
sortCols: func() sortColumns {
sortCols, err := newSortColumns(NewSortColumn[int]("id", ASC), NewSortColumn[string]("name", ASC), NewSortColumn[int]("age", ASC))
sortCols, err := newSortColumns(NewSortColumn("id", ASC), NewSortColumn("name", ASC), NewSortColumn("age", ASC))
require.NoError(t, err)
return sortCols
},
Expand Down
18 changes: 9 additions & 9 deletions internal/merger/sortmerger/merger.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,12 @@ func convertAssign(dest, src any) error
type SortColumn struct {
name string
order Order
typ reflect.Type
}

func NewSortColumn[T Ordered](colName string, order Order) SortColumn {
var t T
typ := reflect.TypeOf(t)
func NewSortColumn(colName string, order Order) SortColumn {
return SortColumn{
name: colName,
order: order,
typ: typ,
}
}

Expand All @@ -70,8 +66,8 @@ func (s sortColumns) Has(name string) bool {
return ok
}

func (s sortColumns) Find(name string) (SortColumn, int) {
return s.columns[s.colMap[name]], s.colMap[name]
func (s sortColumns) Find(name string) int {
return s.colMap[name]
}

func (s sortColumns) Get(index int) SortColumn {
Expand Down Expand Up @@ -198,8 +194,12 @@ func newNode(row *sql.Rows, sortCols sortColumns, index int) (*node, error) {
for _, colInfo := range colsInfo {
colName := colInfo.Name()
if sortCols.Has(colName) {
sortCol, sortIndex := sortCols.Find(colName)
sortColumn := reflect.New(sortCol.typ).Interface()
sortIndex := sortCols.Find(colName)
colType := colInfo.ScanType()
for colType.Kind() == reflect.Ptr {
colType = colType.Elem()
}
sortColumn := reflect.New(colType).Interface()
sortColumns[sortIndex] = sortColumn
columns = append(columns, sortColumn)
} else {
Expand Down
Loading

0 comments on commit aae597a

Please sign in to comment.