Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: 分库分表:datasource-简单的分布式事务方案支持 #210

Merged
merged 13 commits into from
Jul 10, 2023
1 change: 1 addition & 0 deletions .CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
- [eorm: 分库分表:Inserter 支持分库分表](https://github.com/ecodeclub/eorm/pull/200)
- [eorm: ShardingInserter 修改为表维度执行](https://github.com/ecodeclub/eorm/pull/211)
- [eorm: 分库分表:ShardingUpdater 实现](https://github.com/ecodeclub/eorm/pull/201)
- [eorm: 分库分表:datasource-简单的分布式事务方案支持](https://github.com/ecodeclub/eorm/pull/204)

## v0.0.1:
- [Init Project](https://github.com/ecodeclub/eorm/pull/1)
Expand Down
2 changes: 1 addition & 1 deletion db.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func DBWithMiddlewares(ms ...Middleware) DBOption {
}
}

func DBOptionWithMetaRegistry(r model.MetaRegistry) DBOption {
func DBWithMetaRegistry(r model.MetaRegistry) DBOption {
return func(db *DB) {
db.metaRegistry = r
}
Expand Down
37 changes: 33 additions & 4 deletions internal/datasource/cluster/cluster_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ import (
"database/sql"
"fmt"

"github.com/ecodeclub/eorm/internal/datasource/transaction"

"github.com/ecodeclub/eorm/internal/datasource"
"github.com/ecodeclub/eorm/internal/datasource/masterslave"
"github.com/ecodeclub/eorm/internal/errs"
"go.uber.org/multierr"
)

var _ datasource.TxBeginner = &clusterDB{}
var _ datasource.DataSource = &clusterDB{}
var _ datasource.Finder = &clusterDB{}

// clusterDB 以 DB 名称作为索引目标数据库
type clusterDB struct {
Expand All @@ -34,17 +38,17 @@ type clusterDB struct {
}

func (c *clusterDB) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) {
ms, ok := c.masterSlavesDBs[query.DB]
if !ok {
return nil, errs.ErrNotFoundTargetDB
ms, err := c.getTgt(query)
if err != nil {
return nil, err
}
return ms.Query(ctx, query)
}

func (c *clusterDB) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) {
ms, ok := c.masterSlavesDBs[query.DB]
if !ok {
return nil, errs.ErrNotFoundTargetDB
return nil, errs.NewErrNotFoundTargetDB(query.DB)
}
return ms.Exec(ctx, query)
}
Expand All @@ -60,6 +64,31 @@ func (c *clusterDB) Close() error {
return err
}

func (c *clusterDB) FindTgt(_ context.Context, query datasource.Query) (datasource.TxBeginner, error) {
db, err := c.getTgt(query)
if err != nil {
return nil, err
}
return db, nil
}

func (c *clusterDB) getTgt(query datasource.Query) (*masterslave.MasterSlavesDB, error) {
db, ok := c.masterSlavesDBs[query.DB]
if !ok {
return nil, errs.NewErrNotFoundTargetDB(query.DB)
}
return db, nil
}

func (c *clusterDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) {
facade, err := transaction.NewTxFacade(ctx, c)
if err != nil {
return nil, err
}

return facade.BeginTx(ctx, opts)
}

func NewClusterDB(ms map[string]*masterslave.MasterSlavesDB) datasource.DataSource {
return &clusterDB{masterSlavesDBs: ms}
}
4 changes: 2 additions & 2 deletions internal/datasource/cluster/cluster_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (c *ClusterSuite) TestClusterDbQuery() {
masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_0": db}
return masterSlaves
}(),
wantErr: errs.ErrNotFoundTargetDB,
wantErr: errs.NewErrNotFoundTargetDB("order_db_1"),
},
{
name: "select default use slave",
Expand Down Expand Up @@ -219,7 +219,7 @@ func (c *ClusterSuite) TestClusterDbExec() {
masterSlaves := map[string]*masterslave.MasterSlavesDB{"order_db_0": db}
return masterSlaves
}(),
wantErr: errs.ErrNotFoundTargetDB,
wantErr: errs.NewErrNotFoundTargetDB("order_db_1"),
},
{
name: "null slave",
Expand Down
41 changes: 34 additions & 7 deletions internal/datasource/shardingsource/sharding_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"database/sql"
"fmt"

"github.com/ecodeclub/eorm/internal/datasource/transaction"

"github.com/ecodeclub/eorm/internal/datasource"
"go.uber.org/multierr"

Expand All @@ -27,29 +29,54 @@ import (

var _ datasource.TxBeginner = &ShardingDataSource{}
var _ datasource.DataSource = &ShardingDataSource{}
var _ datasource.Finder = &ShardingDataSource{}

type ShardingDataSource struct {
sources map[string]datasource.DataSource
}

func (s *ShardingDataSource) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) {
ds, ok := s.sources[query.Datasource]
if !ok {
return nil, errs.ErrNotFoundTargetDataSource
ds, err := s.getTgt(query)
if err != nil {
return nil, err
}
return ds.Query(ctx, query)
}

func (s *ShardingDataSource) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) {
ds, err := s.getTgt(query)
if err != nil {
return nil, err
}
return ds.Exec(ctx, query)
}

func (s *ShardingDataSource) FindTgt(ctx context.Context, query datasource.Query) (datasource.TxBeginner, error) {
ds, err := s.getTgt(query)
if err != nil {
return nil, err
}
f, ok := ds.(datasource.Finder)
if !ok {
return nil, errs.NewErrNotCompleteFinder(query.Datasource)
}
return f.FindTgt(ctx, query)
}

func (s *ShardingDataSource) getTgt(query datasource.Query) (datasource.DataSource, error) {
ds, ok := s.sources[query.Datasource]
if !ok {
return nil, errs.ErrNotFoundTargetDataSource
return nil, errs.NewErrNotFoundTargetDataSource(query.Datasource)
}
return ds.Exec(ctx, query)
return ds, nil
}

func (*ShardingDataSource) BeginTx(_ context.Context, _ *sql.TxOptions) (datasource.Tx, error) {
panic("`BeginTx` must be completed")
func (s *ShardingDataSource) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) {
facade, err := transaction.NewTxFacade(ctx, s)
if err != nil {
return nil, err
}
return facade.BeginTx(ctx, opts)
}

func NewShardingDataSource(m map[string]datasource.DataSource) datasource.DataSource {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (c *ShardingDataSourceSuite) TestClusterDbQuery() {
DB: "db_0",
Datasource: "2.db.cluster.company.com:3306",
},
wantErr: errs.ErrNotFoundTargetDataSource,
wantErr: errs.NewErrNotFoundTargetDataSource("2.db.cluster.company.com:3306"),
},
{
name: "cluster0 select default use slave",
Expand Down Expand Up @@ -280,7 +280,7 @@ func (c *ShardingDataSourceSuite) TestClusterDbExec() {
DB: "db_0",
Datasource: "2.db.cluster.company.com:3306",
},
wantErr: errs.ErrNotFoundTargetDataSource,
wantErr: errs.NewErrNotFoundTargetDataSource("2.db.cluster.company.com:3306"),
},
{
name: "cluster0 exec",
Expand Down
27 changes: 24 additions & 3 deletions internal/datasource/single/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ var _ datasource.DataSource = &DB{}

// DB represents a database
type DB struct {
db *sql.DB
db *sql.DB
multiStatements bool
}

func (db *DB) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) {
Expand All @@ -42,12 +43,22 @@ func (db *DB) Exec(ctx context.Context, query datasource.Query) (sql.Result, err
return db.db.ExecContext(ctx, query.SQL, query.Args...)
}

func OpenDB(driver string, dsn string) (*DB, error) {
func OpenDB(driver string, dsn string, opts ...Option) (*DB, error) {
res := &DB{}
for _, o := range opts {
o(res)
}

if res.multiStatements {
dsn = dsn + "?multiStatements=true"
}

db, err := sql.Open(driver, dsn)
if err != nil {
return nil, err
}
return &DB{db: db}, nil
res.db = db
return res, nil
}

func NewDB(db *sql.DB) *DB {
Expand Down Expand Up @@ -77,3 +88,13 @@ func (db *DB) Wait() error {
func (db *DB) Close() error {
return db.db.Close()
}

type Option func(db *DB)

// DBWithMultiStatements 在创建连接时 加入参数 multiStatements=true,允许多条语句查询
// 当然 multi statements 可能会增加sql注入的风险,故该操作只允许一次性业务操作,连接使用完成后需要关闭连接
func DBWithMultiStatements(m bool) Option {
return func(db *DB) {
db.multiStatements = m
}
}
114 changes: 114 additions & 0 deletions internal/datasource/transaction/delay_transaction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transaction

import (
"context"
"database/sql"
"fmt"
"sync"

"github.com/ecodeclub/eorm/internal/datasource"
"go.uber.org/multierr"
)

type DelayTxFactory struct{}

func (DelayTxFactory) TxOf(ctx Context, finder datasource.Finder) (datasource.Tx, error) {
return NewDelayTx(ctx, finder), nil
}

type DelayTx struct {
ctx Context
lock sync.RWMutex
txs map[string]datasource.Tx
finder datasource.Finder
}

func (t *DelayTx) findTgt(ctx context.Context, query datasource.Query) (datasource.TxBeginner, error) {
return t.finder.FindTgt(ctx, query)
}

func (t *DelayTx) findOrBeginTx(ctx context.Context, query datasource.Query) (datasource.Tx, error) {
t.lock.RLock()
tx, ok := t.txs[query.DB]
t.lock.RUnlock()
if ok {
return tx, nil
}
t.lock.Lock()
defer t.lock.Unlock()
if tx, ok = t.txs[query.DB]; ok {
return tx, nil
}
var err error
db, err := t.findTgt(ctx, query)
if err != nil {
return nil, err
}
tx, err = db.BeginTx(t.ctx.TxCtx, t.ctx.Opts)
if err != nil {
return nil, err
}
t.txs[query.DB] = tx
return tx, nil
}

func (t *DelayTx) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) {
// 防止 GetMulti 的查询重复创建多个事务
tx, err := t.findOrBeginTx(ctx, query)
if err != nil {
return nil, err
}
return tx.Query(ctx, query)
}

func (t *DelayTx) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) {
tx, err := t.findOrBeginTx(ctx, query)
if err != nil {
return nil, err
}
return tx.Exec(ctx, query)
}

func (t *DelayTx) Commit() error {
var err error
for name, tx := range t.txs {
if er := tx.Commit(); er != nil {
err = multierr.Combine(
err, fmt.Errorf("masterslave DB name [%s] Commit error: %w", name, er))
}
}
return err
}

func (t *DelayTx) Rollback() error {
var err error
for name, tx := range t.txs {
if er := tx.Rollback(); er != nil {
err = multierr.Combine(
err, fmt.Errorf("masterslave DB name [%s] Rollback error: %w", name, er))
}
}
return err
}

func NewDelayTx(ctx Context, finder datasource.Finder) *DelayTx {
return &DelayTx{
ctx: ctx,
finder: finder,
txs: make(map[string]datasource.Tx, 8),
}
}
Loading