Skip to content

Commit

Permalink
rows: 同库事务语句合并执行,提前读取所有数据
Browse files Browse the repository at this point in the history
  • Loading branch information
flycash committed Oct 1, 2023
1 parent 07dc416 commit a499ce5
Show file tree
Hide file tree
Showing 20 changed files with 462 additions and 154 deletions.
2 changes: 1 addition & 1 deletion .CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
- [eorm: 分库分表:datasource-简单的分布式事务方案支持](https://github.com/ecodeclub/eorm/pull/204)
- [merger: 使用 sqlx.Scanner 来读取数据](https://github.com/ecodeclub/eorm/pull/216)
- [rows, merger: 使用 sqlx.Rows 作为接口,并重构 merger 包 ](https://github.com/ecodeclub/eorm/pull/217)

- [rows: 同库事务语句合并执行,提前读取所有数据](https://github.com/ecodeclub/eorm/pull/219)
## v0.0.1:
- [Init Project](https://github.com/ecodeclub/eorm/pull/1)
- [Selector Definition](https://github.com/ecodeclub/eorm/pull/2)
Expand Down
34 changes: 14 additions & 20 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type DBOption func(db *DB)

// DB represents a database
type DB struct {
core
baseSession
ds datasource.DataSource
}

Expand All @@ -62,14 +62,6 @@ func UseReflection() DBOption {
}
}

func (db *DB) queryContext(ctx context.Context, q datasource.Query) (*sql.Rows, error) {
return db.ds.Query(ctx, q)
}

func (db *DB) execContext(ctx context.Context, q datasource.Query) (sql.Result, error) {
return db.ds.Exec(ctx, q)
}

// Open 创建一个 ORM 实例
// 注意该实例是一个无状态的对象,你应该尽可能复用它
func Open(driver string, dsn string, opts ...DBOption) (*DB, error) {
Expand All @@ -86,12 +78,15 @@ func OpenDS(driver string, ds datasource.DataSource, opts ...DBOption) (*DB, err
return nil, err
}
orm := &DB{
core: core{
metaRegistry: model.NewMetaRegistry(),
dialect: dl,
// 可以设为默认,因为原本这里也有默认
valCreator: valuer.PrimitiveCreator{
Creator: valuer.NewUnsafeValue,
baseSession: baseSession{
executor: ds,
core: core{
metaRegistry: model.NewMetaRegistry(),
dialect: dl,
// 可以设为默认,因为原本这里也有默认
valCreator: valuer.PrimitiveCreator{
Creator: valuer.NewUnsafeValue,
},
},
},
ds: ds,
Expand All @@ -111,13 +106,12 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
if err != nil {
return nil, err
}
return &Tx{tx: tx, core: db.getCore()}, nil
return &Tx{tx: tx, baseSession: baseSession{
executor: tx,
core: db.core,
}}, nil
}

func (db *DB) Close() error {
return db.ds.Close()
}

func (db *DB) getCore() core {
return db.core
}
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ go 1.20

require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994
github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b
github.com/go-sql-driver/mysql v1.6.0
github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d
github.com/mattn/go-sqlite3 v1.14.15
github.com/stretchr/testify v1.8.1
github.com/valyala/bytebufferpool v1.0.0
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994 h1:4Rp8WrJhISj8GDtnueoD22ygPuppajnCVZuEfRjg6w8=
github.com/ecodeclub/ekit v0.0.4-0.20230904153403-e76aae064994/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs=
github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b h1:T1OvEeJJEOhkrhkg55//A5kzX7lgdeX9gDJuVDahSpw=
github.com/ecodeclub/ekit v0.0.8-0.20231001021557-856d32ae850b/go.mod h1:OqTojKeKFTxeeAAUwNIPKu339SRkX6KAuoK/8A5BCEs=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d h1:kmDgYRZ06UifBqAfew+cj02juQQ3Ko349NzsDIZ0QPw=
github.com/gotomicro/ekit v0.0.0-20230224040531-869798da3c4d/go.mod h1:ISYxgxcx3SOYGm/Hg9+M+pHVhN5G6W7p91/Pn7x6Hz8=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
Expand Down
7 changes: 1 addition & 6 deletions internal/datasource/transaction/delay_transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,6 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() {
rows := s.mockMaster2.NewRows([]string{"order_id", "item_id", "using_col1", "using_col2"})
s.mockMaster2.ExpectQuery(regexp.QuoteMeta("SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_2` WHERE (`order_id`=?) OR (`order_id`=?);SELECT `order_id`,`item_id`,`using_col1`,`using_col2` FROM `order_detail_db_1`.`order_detail_tab_1` WHERE (`order_id`=?) OR (`order_id`=?);")).
WithArgs(199, 299, 199, 299).WillReturnRows(rows)

queryVal := s.findTgt(t, values)
var wantOds []*test.OrderDetail
assert.ElementsMatch(t, wantOds, queryVal)
},
},
}
Expand All @@ -496,10 +492,9 @@ func (s *TestDelayTxTestSuite) TestExecute_Commit_Or_Rollback() {
tx, err := tc.txFunc()
require.NoError(t, err)

// TODO GetMultiV2 待将 table 维度改成 db 维度
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").NEQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
Expand Down
3 changes: 1 addition & 2 deletions internal/datasource/transaction/transaction_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ func (s *ShardingTransactionSuite) findTgt(t *testing.T, values []*test.OrderDet
od = values[i]
pre = pre.Or(eorm.C(s.shardingKey).EQ(od.OrderId))
}
// TODO GetMultiV2 待将 table 维度改成 db 维度
querySet, err := eorm.NewShardingSelector[test.OrderDetail](s.shardingDB).
Where(pre).GetMultiV2(masterslave.UseMaster(context.Background()))
Where(pre).GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
return querySet
}
Expand Down
6 changes: 3 additions & 3 deletions internal/integration/sharding_delay_transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ func (s *ShardingDelayTxTestSuite) TestDoubleShardingSelect() {
defer tx.Commit()
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").NEQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)

querySet, err = eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").NEQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)
})
Expand Down Expand Up @@ -228,7 +228,7 @@ func (s *ShardingDelayTxTestSuite) TestShardingSelectUpdateInsert_Commit_Or_Roll
tx := tc.txFunc(t)
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").NEQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)

Expand Down
8 changes: 4 additions & 4 deletions internal/integration/sharding_single_transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ func (s *ShardingSingleTxTestSuite) TestDoubleShardingSelect() {
defer tx.Commit()
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").EQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)

querySet, err = eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").EQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)
})
Expand Down Expand Up @@ -137,7 +137,7 @@ func (s *ShardingSingleTxTestSuite) TestShardingSelectInsert_Commit_Or_Rollback(
tx := tc.txFunc(t)
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").EQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)
res := eorm.NewShardingInsert[test.OrderDetail](tx).
Expand Down Expand Up @@ -220,7 +220,7 @@ func (s *ShardingSingleTxTestSuite) TestShardingSelectUpdate_Commit_Or_Rollback(
tx := tc.txFunc(t)
querySet, err := eorm.NewShardingSelector[test.OrderDetail](tx).
Where(eorm.C("OrderId").EQ(123)).
GetMultiV2(masterslave.UseMaster(context.Background()))
GetMulti(masterslave.UseMaster(context.Background()))
require.NoError(t, err)
assert.ElementsMatch(t, tc.querySet, querySet)
res := eorm.NewShardingUpdater[test.OrderDetail](tx).Update(tc.target).
Expand Down
4 changes: 2 additions & 2 deletions internal/merger/groupby_merger/aggregator_merger.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ import (

"go.uber.org/multierr"

"github.com/ecodeclub/ekit/mapx"
"github.com/ecodeclub/eorm/internal/merger"
"github.com/ecodeclub/eorm/internal/merger/aggregatemerger/aggregator"
"github.com/ecodeclub/eorm/internal/merger/internal/errs"
"github.com/gotomicro/ekit/mapx"
)

type AggregatorMerger struct {
Expand Down Expand Up @@ -109,7 +109,7 @@ func (a *AggregatorMerger) getCols(rowsList []rows.Rows) (*mapx.TreeMap[Key, [][
val, ok := treeMap.Get(key)
if ok {
val = append(val, colData)
err = treeMap.Set(key, val)
err = treeMap.Put(key, val)
if err != nil {
return nil, nil, err
}
Expand Down
1 change: 0 additions & 1 deletion internal/merger/internal/errs/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ var (
ErrMergerAggregateHasEmptyRows = errors.New("merger: 聚合函数计算时rowsList有一个或多个为空")
ErrMergerInvalidAggregateColumnIndex = errors.New("merger: ColumnInfo的index不合法")
ErrMergerAggregateFuncNotFound = errors.New("merger: 聚合函数方法未找到")
ErrMergerNullable = errors.New("merger: 接收数据的类型需要为sql.Nullable")
)

func NewRepeatSortColumn(column string) error {
Expand Down
10 changes: 10 additions & 0 deletions internal/rows/convert_assign.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package rows

import (
"database/sql"
"database/sql/driver"
_ "unsafe"
)
Expand All @@ -31,5 +32,14 @@ func ConvertAssign(dest, src any) error {
return err
}
}
// 预处理一下 sqlConvertAssign 不支持的转换,遇到一个加一个
switch sv := src.(type) {
case sql.RawBytes:
switch dv := dest.(type) {
case *string:
*dv = string(sv)
return nil
}
}
return sqlConvertAssign(dest, src)
}
88 changes: 88 additions & 0 deletions internal/rows/data_rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rows

import (
"database/sql"

"github.com/ecodeclub/eorm/internal/errs"
)

var _ Rows = (*DataRows)(nil)

// DataRows 直接传入数据,伪装成了一个 Rows
// 非线程安全实现
type DataRows struct {
data [][]any
len int
columns []string
columnTypes []*sql.ColumnType
// 第几行
idx int
}

func (*DataRows) NextResultSet() bool {
return false
}

func (d *DataRows) ColumnTypes() ([]*sql.ColumnType, error) {
return d.columnTypes, nil
}

func NewDataRows(data [][]any, columns []string, columnTypes []*sql.ColumnType) *DataRows {
// 这里并没有什么必要检查 data 和 columns 的输入
// 因为只有在很故意的情况下,data 和 columns 才可能会有问题
return &DataRows{
data: data,
len: len(data),
columns: columns,
idx: -1,
columnTypes: columnTypes,
}
}

func (d *DataRows) Next() bool {
if d.idx >= d.len-1 {
return false
}
d.idx++
return true
}

func (d *DataRows) Scan(dest ...any) error {
// 不需要检测,作为内部代码我们可以预期用户会主动控制
data := d.data[d.idx]
if len(data) != len(dest) {
return errs.NewErrScanWrongDestinationArguments(len(data), len(dest))
}
for idx, dst := range dest {
if err := ConvertAssign(dst, data[idx]); err != nil {
return err
}
}
return nil
}

func (*DataRows) Close() error {
return nil
}

func (d *DataRows) Columns() ([]string, error) {
return d.columns, nil
}

func (*DataRows) Err() error {
return nil
}
Loading

0 comments on commit a499ce5

Please sign in to comment.