diff --git a/cdc/sink/mysql/helper.go b/cdc/sink/mysql/helper.go
new file mode 100644
index 00000000000..df1e874a21e
--- /dev/null
+++ b/cdc/sink/mysql/helper.go
@@ -0,0 +1,43 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package mysql
+
+import (
+	"context"
+	"database/sql"
+
+	dmysql "github.com/go-sql-driver/mysql"
+	"github.com/pingcap/errors"
+	"github.com/pingcap/log"
+	tmysql "github.com/pingcap/tidb/parser/mysql"
+	"go.uber.org/zap"
+)
+
+// CheckIsTiDB checks if the db connects to a TiDB.
+func CheckIsTiDB(ctx context.Context, db *sql.DB) (bool, error) {
+	var tidbVer string
+	row := db.QueryRowContext(ctx, "select tidb_version()")
+	err := row.Scan(&tidbVer)
+	if err != nil {
+		log.Error("check tidb version error", zap.Error(err))
+		// downstream is not TiDB, do nothing
+		if mysqlErr, ok := errors.Cause(err).(*dmysql.MySQLError); ok &&
+			(mysqlErr.Number == tmysql.ErrNoDB ||
+				mysqlErr.Number == tmysql.ErrSpDoesNotExist) {
+			return false, nil
+		}
+		return false, errors.Trace(err)
+	}
+	return true, nil
+}
diff --git a/cdc/sink/mysql/mysql.go b/cdc/sink/mysql/mysql.go
index 4230df2d31f..7c0643a5be6 100644
--- a/cdc/sink/mysql/mysql.go
+++ b/cdc/sink/mysql/mysql.go
@@ -183,6 +183,13 @@ func NewMySQLSink(
 			"some types of DDL may fail to be executed",
 			zap.String("hostname", hostName), zap.String("port", port))
 	}
+
+	isTiDB, err := CheckIsTiDB(ctx, testDB)
+	if err != nil {
+		return nil, err
+	}
+	params.isTiDB = isTiDB
+
 	db, err := GetDBConnImpl(ctx, dsnStr)
 	if err != nil {
 		return nil, err
@@ -800,6 +807,7 @@ func convert2RowChanges(
 			tableInfo, nil, nil)
 	}
+	res.SetApproximateDataSize(row.ApproximateDataSize)
 	return res
 }
@@ -876,7 +884,7 @@ func (s *mysqlSink) groupRowsByType(
 			updateRow = append(
 				updateRow,
 				convert2RowChanges(row, tableInfo, sqlmodel.RowChangeUpdate))
-			if len(updateRow) >= s.params.maxTxnRow {
+			if len(updateRow) >= s.params.batchUpdateRowCount {
 				updateRows = append(updateRows, updateRow)
 				updateRow = make([]*sqlmodel.RowChange, 0, s.params.maxTxnRow)
 			}
@@ -915,11 +923,22 @@ func (s *mysqlSink) batchSingleTxnDmls(
 	// handle update
 	if len(updateRows) > 0 {
 		// TODO: use sql.GenUpdateSQL to generate update sql after we optimize the func.
-		for _, rows := range updateRows {
-			for _, row := range rows {
-				sql, value := row.GenSQL(sqlmodel.DMLUpdate)
-				sqls = append(sqls, sql)
-				values = append(values, value)
+		if s.params.isTiDB {
+			for _, rows := range updateRows {
+				s, v := s.genUpdateSQL(rows...)
+				sqls = append(sqls, s...)
+				values = append(values, v...)
+			}
+		} else {
+			// The behavior of batch update statement differs between TiDB and MySQL.
+			// So we don't use batch update statement when downstream is MySQL.
+			// Ref:https://docs.pingcap.com/tidb/stable/sql-statement-update#mysql-compatibility
+			for _, rows := range updateRows {
+				for _, row := range rows {
+					sql, value := row.GenSQL(sqlmodel.DMLUpdate)
+					sqls = append(sqls, sql)
+					values = append(values, value)
+				}
 			}
 		}
 	}
@@ -1128,6 +1147,28 @@ func (s *mysqlSink) execDMLs(ctx context.Context, txns []*model.SingleTableTxn,
 	return nil
 }
 
+func (s *mysqlSink) genUpdateSQL(rows ...*sqlmodel.RowChange) ([]string, [][]interface{}) {
+	size, count := 0, 0
+	for _, r := range rows {
+		size += int(r.GetApproximateDataSize())
+		count++
+	}
+	if size < defaultMaxBatchUpdateRowSize*count {
+		// use batch update
+		sql, value := sqlmodel.GenUpdateSQLFast(rows...)
+		return []string{sql}, [][]interface{}{value}
+	}
+	// otherwise, generate an independent update SQL for each row.
+	sqls := make([]string, 0, len(rows))
+	values := make([][]interface{}, 0, len(rows))
+	for _, row := range rows {
+		sql, value := row.GenSQL(sqlmodel.DMLUpdate)
+		sqls = append(sqls, sql)
+		values = append(values, value)
+	}
+	return sqls, values
+}
+
 // if the column value type is []byte and charset is not binary, we get its string
 // representation. Because if we use the byte array respresentation, the go-sql-driver
 // will automatically set `_binary` charset for that column, which is not expected.
diff --git a/cdc/sink/mysql/mysql_params.go b/cdc/sink/mysql/mysql_params.go
index 9f2eb3b46d7..a68d6d23d2f 100644
--- a/cdc/sink/mysql/mysql_params.go
+++ b/cdc/sink/mysql/mysql_params.go
@@ -54,6 +54,14 @@ const (
 	defaultTxnIsolationRC = "READ-COMMITTED"
 	defaultCharacterSet   = "utf8mb4"
 	defaultBatchDMLEnable = true
+	// defaultMaxBatchUpdateRowCount is the default max number of rows in a
+	// single batch update SQL.
+	defaultMaxBatchUpdateRowCount = 40
+	maxMaxBatchUpdateRowCount     = 1024
+	// defaultMaxBatchUpdateRowSize (1KB) is the default size threshold of a single row.
+	// When the average row size is larger than defaultMaxBatchUpdateRowSize,
+	// batch update is disabled; otherwise it is enabled.
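+	// For example, with this default a batch whose rows average 2KB each falls
+	// back to per-row UPDATE statements, while rows averaging 512 bytes are batched.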
+	defaultMaxBatchUpdateRowSize = 1024
 )
 
 var (
@@ -72,6 +80,7 @@ var defaultParams = &sinkParams{
 	dialTimeout:    defaultDialTimeout,
 	safeMode:       defaultSafeMode,
 	batchDMLEnable: defaultBatchDMLEnable,
+	batchUpdateRowCount: defaultMaxBatchUpdateRowCount,
 }
 
 var validSchemes = map[string]bool{
@@ -97,6 +106,8 @@ type sinkParams struct {
 	timezone            string
 	tls                 string
 	batchDMLEnable      bool
+	batchUpdateRowCount int
+	isTiDB              bool
 }
 
 func (s *sinkParams) Clone() *sinkParams {
@@ -266,6 +277,24 @@ func parseSinkURIToParams(ctx context.Context,
 		params.batchDMLEnable = enable
 	}
 
+	s = sinkURI.Query().Get("max-multi-update-row")
+	if s != "" {
+		c, err := strconv.Atoi(s)
+		if err != nil {
+			return nil, cerror.WrapError(cerror.ErrMySQLInvalidConfig, err)
+		}
+		if c <= 0 {
+			return nil, cerror.WrapError(cerror.ErrMySQLInvalidConfig,
+				fmt.Errorf("invalid max-multi-update-row %d, which must be greater than 0", c))
+		}
+		if c > maxMaxBatchUpdateRowCount {
+			log.Warn("max-multi-update-row too large",
+				zap.Int("original", c), zap.Int("override", maxMaxBatchUpdateRowCount))
+			c = maxMaxBatchUpdateRowCount
+		}
+		params.batchUpdateRowCount = c
+	}
+
 	return params, nil
 }
 
diff --git a/cdc/sink/mysql/mysql_params_test.go b/cdc/sink/mysql/mysql_params_test.go
index b92ac5edc81..c9e983124f7 100644
--- a/cdc/sink/mysql/mysql_params_test.go
+++ b/cdc/sink/mysql/mysql_params_test.go
@@ -45,6 +45,7 @@ func TestSinkParamsClone(t *testing.T) {
 		dialTimeout:    defaultDialTimeout,
 		safeMode:       defaultSafeMode,
 		batchDMLEnable: defaultBatchDMLEnable,
+		batchUpdateRowCount: defaultMaxBatchUpdateRowCount,
 	}, param1)
 	require.Equal(t, &sinkParams{
 		changefeedID: model.DefaultChangeFeedID("123"),
@@ -58,6 +59,7 @@ func TestSinkParamsClone(t *testing.T) {
 		dialTimeout:    defaultDialTimeout,
 		safeMode:       defaultSafeMode,
 		batchDMLEnable: defaultBatchDMLEnable,
+		batchUpdateRowCount: defaultMaxBatchUpdateRowCount,
 	}, param2)
 }
@@ -211,9 +213,11 @@ func TestParseSinkURIToParams(t *testing.T) {
 	expected.changefeedID = model.DefaultChangeFeedID("cf-id")
 	expected.captureAddr = "127.0.0.1:8300"
 	expected.tidbTxnMode = "pessimistic"
+	expected.batchUpdateRowCount = 80
 	uriStr := "mysql://127.0.0.1:3306/?worker-count=64&max-txn-row=20" +
 		"&batch-replace-enable=true&batch-replace-size=50&safe-mode=false" +
-		"&tidb-txn-mode=pessimistic"
+		"&tidb-txn-mode=pessimistic" +
+		"&max-multi-update-row=80"
 	opts := map[string]string{
 		metrics.OptCaptureAddr: expected.captureAddr,
 	}
@@ -256,22 +260,29 @@ func TestParseSinkURIOverride(t *testing.T) {
 	cases := []struct {
 		uri     string
 		checker func(*sinkParams)
-	}{{
-		uri: "mysql://127.0.0.1:3306/?worker-count=2147483648", // int32 max
-		checker: func(sp *sinkParams) {
-			require.EqualValues(t, sp.workerCount, maxWorkerCount)
+	}{
+		{
+			uri: "mysql://127.0.0.1:3306/?worker-count=2147483648", // int32 max
+			checker: func(sp *sinkParams) {
+				require.EqualValues(t, sp.workerCount, maxWorkerCount)
+			},
+		}, {
+			uri: "mysql://127.0.0.1:3306/?max-txn-row=2147483648", // int32 max
+			checker: func(sp *sinkParams) {
+				require.EqualValues(t, sp.maxTxnRow, maxMaxTxnRow)
+			},
+		}, {
+			uri: "mysql://127.0.0.1:3306/?tidb-txn-mode=badmode",
+			checker: func(sp *sinkParams) {
+				require.EqualValues(t, sp.tidbTxnMode, defaultTiDBTxnMode)
+			},
+		}, {
+			uri: "mysql://127.0.0.1:3306/?max-multi-update-row=2147483648", // int32 max
+			checker: func(sp *sinkParams) {
+				require.EqualValues(t, sp.batchUpdateRowCount, maxMaxBatchUpdateRowCount)
+			},
 		},
-	}, {
-		uri: "mysql://127.0.0.1:3306/?max-txn-row=2147483648", // int32 max
-		checker: func(sp *sinkParams) {
-			require.EqualValues(t, sp.maxTxnRow, maxMaxTxnRow)
-		},
-	}, {
-		uri: "mysql://127.0.0.1:3306/?tidb-txn-mode=badmode",
-		checker: func(sp *sinkParams) {
-			require.EqualValues(t, sp.tidbTxnMode, defaultTiDBTxnMode)
-		},
-	}}
+	}
 	ctx := context.TODO()
 	var uri *url.URL
 	var err error
diff --git a/cdc/sink/mysql/mysql_test.go b/cdc/sink/mysql/mysql_test.go
index 4eb09f5e19d..00b4e1d6632 100644
--- a/cdc/sink/mysql/mysql_test.go
+++ b/cdc/sink/mysql/mysql_test.go
@@ -985,6 +985,10 @@ func mockTestDBWithSQLMode(adjustSQLMode bool, sqlMode interface{}) (*sql.DB, er
 		"where character_set_name = 'gbk';").WillReturnRows(
 		sqlmock.NewRows([]string{"character_set_name"}).AddRow("gbk"),
 	)
+	mock.ExpectQuery("select tidb_version()").WillReturnError(&dmysql.MySQLError{
+		Number:  1305,
+		Message: "FUNCTION test.tidb_version does not exist",
+	})
 	mock.ExpectClose()
 
 	return db, nil
diff --git a/pkg/applier/redo_test.go b/pkg/applier/redo_test.go
index 1329f63c1fb..05b785af72e 100644
--- a/pkg/applier/redo_test.go
+++ b/pkg/applier/redo_test.go
@@ -20,6 +20,7 @@ import (
 	"testing"
 
 	"github.com/DATA-DOG/go-sqlmock"
+	dmysql "github.com/go-sql-driver/mysql"
 	"github.com/phayes/freeport"
 	"github.com/pingcap/tiflow/cdc/model"
 	"github.com/pingcap/tiflow/cdc/redo/common"
@@ -149,6 +150,10 @@ func TestApplyDMLs(t *testing.T) {
 			"where character_set_name = 'gbk';").WillReturnRows(
 			sqlmock.NewRows([]string{"character_set_name"}).AddRow("gbk"),
 		)
+		mock.ExpectQuery("select tidb_version()").WillReturnError(&dmysql.MySQLError{
+			Number:  1305,
+			Message: "FUNCTION test.tidb_version does not exist",
+		})
 		mock.ExpectClose()
 		return db, nil
 	}
diff --git a/pkg/sqlmodel/multivalue.go b/pkg/sqlmodel/multirow.go
similarity index 67%
rename from pkg/sqlmodel/multivalue.go
rename to pkg/sqlmodel/multirow.go
index b1dc05ab381..8e61ff1d942 100644
--- a/pkg/sqlmodel/multivalue.go
+++ b/pkg/sqlmodel/multirow.go
@@ -1,4 +1,4 @@
-// Copyright 2022 PingCAP, Inc.
+// Copyright 2023 PingCAP, Inc.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -16,16 +16,14 @@ package sqlmodel
 import (
 	"strings"
 
+	"github.com/pingcap/log"
 	"github.com/pingcap/tidb/parser/ast"
 	"github.com/pingcap/tidb/parser/format"
 	"github.com/pingcap/tidb/parser/model"
 	"github.com/pingcap/tidb/parser/opcode"
 	driver "github.com/pingcap/tidb/types/parser_driver"
-
-	"go.uber.org/zap"
-
-	"github.com/pingcap/tiflow/dm/pkg/log"
 	"github.com/pingcap/tiflow/pkg/quotes"
+	"go.uber.org/zap"
 )
 
 // SameTypeTargetAndColumns check whether two row changes have same type, target
@@ -309,7 +307,7 @@ func GenUpdateSQL(changes ...*RowChange) (string, []interface{}) {
 	// a simple check about different number of WHERE values, not trying to
 	// cover all cases
 	if len(whereValues) != len(whereColumns) {
-		log.L().DPanic("len(whereValues) != len(whereColumns)",
+		log.Panic("len(whereValues) != len(whereColumns)",
 			zap.Int("len(whereValues)", len(whereValues)),
 			zap.Int("len(whereColumns)", len(whereColumns)),
 			zap.Any("whereValues", whereValues),
@@ -346,3 +344,149 @@ func GenUpdateSQL(changes ...*RowChange) (string, []interface{}) {
 	}
 	return buf.String(), args
 }
+
+// GenUpdateSQLFast generates the UPDATE SQL and its arguments.
+// Input `changes` should have the same target table and the same columns for WHERE
+// (typically the same PK/NOT NULL UK), otherwise the behaviour is undefined.
+// It is a faster version compared with GenUpdateSQL.
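+// For example, two changes on a table with primary key (c, c2) produce a statement
+// of the form (see TestGenUpdateMultiRows):
+//   UPDATE `db`.`tb` SET
+//     `c3`=CASE WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? END, ...
+//   WHERE ROW(`c`,`c2`) IN (ROW(?,?),ROW(?,?))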
+func GenUpdateSQLFast(changes ...*RowChange) (string, []any) {
+	if len(changes) == 0 {
+		log.L().DPanic("row changes is empty")
+		return "", nil
+	}
+	var buf strings.Builder
+	buf.Grow(1024)
+
+	// Generate UPDATE `db`.`table` SET
+	first := changes[0]
+	buf.WriteString("UPDATE ")
+	buf.WriteString(first.targetTable.QuoteString())
+	buf.WriteString(" SET ")
+
+	// Pre-generate essential sub statements used after WHEN, WHERE and IN.
+	var (
+		whereCaseStmt string
+		whenCaseStmt  string
+		inCaseStmt    string
+	)
+	whereColumns, _ := first.whereColumnsAndValues()
+	if len(whereColumns) == 1 {
+		// one field PK or UK, use `field`=? directly.
+		whereCaseStmt = quotes.QuoteName(whereColumns[0])
+		whenCaseStmt = whereCaseStmt + "=?"
+		inCaseStmt = valuesHolder(len(changes))
+	} else {
+		// multiple fields PK or UK, use ROW(...fields) expression.
+		whereValuesHolder := valuesHolder(len(whereColumns))
+		whereCaseStmt = "ROW("
+		for i, column := range whereColumns {
+			whereCaseStmt += quotes.QuoteName(column)
+			if i != len(whereColumns)-1 {
+				whereCaseStmt += ","
+			} else {
+				whereCaseStmt += ")"
+				whenCaseStmt = whereCaseStmt + "=ROW" + whereValuesHolder
+			}
+		}
+		var inCaseStmtBuf strings.Builder
+		// inCaseStmt sample:     IN (ROW(?,?,?),ROW(?,?,?))
+		//                            ^                    ^
+		// Buffer size count between |--------------------|
+		// equals to 3 * len(changes) for each `ROW`
+		// plus 1 * len(changes) - 1 for each `,` between every two ROW(?,?,?)
+		// plus len(whereValuesHolder) * len(changes)
+		// plus 2 for `(` and `)`
+		inCaseStmtBuf.Grow((4+len(whereValuesHolder))*len(changes) + 1)
+		inCaseStmtBuf.WriteString("(")
+		for i := range changes {
+			inCaseStmtBuf.WriteString("ROW")
+			inCaseStmtBuf.WriteString(whereValuesHolder)
+			if i != len(changes)-1 {
+				inCaseStmtBuf.WriteString(",")
+			} else {
+				inCaseStmtBuf.WriteString(")")
+			}
+		}
+		inCaseStmt = inCaseStmtBuf.String()
+	}
+
+	// Generate `ColumnName`=CASE WHEN .. THEN .. END
+	// Use this flag to identify the first CaseWhenThen line, because generated
+	// columns can appear anywhere and are skipped.
+	isFirstCaseWhenThenLine := true
+	for _, column := range first.targetTableInfo.Columns {
+		if isGenerated(first.targetTableInfo.Columns, column.Name) {
+			continue
+		}
+		if !isFirstCaseWhenThenLine {
+			// insert ", " after END of each line except for the first line.
+			buf.WriteString(", ")
+		}
+
+		buf.WriteString(quotes.QuoteName(column.Name.String()) + "=CASE")
+		for range changes {
+			buf.WriteString(" WHEN ")
+			buf.WriteString(whenCaseStmt)
+			buf.WriteString(" THEN ?")
+		}
+		buf.WriteString(" END")
+		isFirstCaseWhenThenLine = false
+	}
+
+	// Generate WHERE .. IN ..
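+	// e.g. WHERE `c` IN (?,?) for a single-column key,
+	// or WHERE ROW(`c`,`c2`) IN (ROW(?,?),ROW(?,?)) for a multi-column key.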
+ buf.WriteString(" WHERE ") + buf.WriteString(whereCaseStmt) + buf.WriteString(" IN ") + buf.WriteString(inCaseStmt) + + // Build args of the UPDATE SQL + var assignValueColumnCount int + var skipColIdx []int + for i, col := range first.sourceTableInfo.Columns { + if isGenerated(first.targetTableInfo.Columns, col.Name) { + skipColIdx = append(skipColIdx, i) + continue + } + assignValueColumnCount++ + } + args := make([]any, 0, + assignValueColumnCount*len(changes)*(len(whereColumns)+1)+len(changes)*len(whereColumns)) + argsPerCol := make([][]any, assignValueColumnCount) + for i := 0; i < assignValueColumnCount; i++ { + argsPerCol[i] = make([]any, 0, len(changes)*(len(whereColumns)+1)) + } + whereValuesAtTheEnd := make([]any, 0, len(changes)*len(whereColumns)) + for _, change := range changes { + _, whereValues := change.whereColumnsAndValues() + // a simple check about different number of WHERE values, not trying to + // cover all cases + if len(whereValues) != len(whereColumns) { + log.Panic("len(whereValues) != len(whereColumns)", + zap.Int("len(whereValues)", len(whereValues)), + zap.Int("len(whereColumns)", len(whereColumns)), + zap.Any("whereValues", whereValues), + zap.Stringer("sourceTable", change.sourceTable)) + return "", nil + } + + whereValuesAtTheEnd = append(whereValuesAtTheEnd, whereValues...) + + i := 0 // used as index of skipColIdx + writeableCol := 0 + for j, val := range change.postValues { + if i < len(skipColIdx) && skipColIdx[i] == j { + i++ + continue + } + argsPerCol[writeableCol] = append(argsPerCol[writeableCol], whereValues...) + argsPerCol[writeableCol] = append(argsPerCol[writeableCol], val) + writeableCol++ + } + } + for _, a := range argsPerCol { + args = append(args, a...) + } + args = append(args, whereValuesAtTheEnd...) + + return buf.String(), args +} diff --git a/pkg/sqlmodel/multirow_bench_test.go b/pkg/sqlmodel/multirow_bench_test.go new file mode 100644 index 00000000000..496758274b3 --- /dev/null +++ b/pkg/sqlmodel/multirow_bench_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package sqlmodel
+
+import (
+	"fmt"
+	"testing"
+	"time"
+
+	cdcmodel "github.com/pingcap/tiflow/cdc/model"
+)
+
+func prepareDataOneColoumnPK(t *testing.T, batch int) []*RowChange {
+	source := &cdcmodel.TableName{Schema: "db", Table: "tb"}
+	target := &cdcmodel.TableName{Schema: "db", Table: "tb"}
+
+	sourceTI := mockTableInfo(t, `CREATE TABLE tb (c INT, c2 INT, c3 INT,
+		c4 VARCHAR(10), c5 VARCHAR(100), c6 VARCHAR(1000), PRIMARY KEY (c))`)
+	targetTI := mockTableInfo(t, `CREATE TABLE tb (c INT, c2 INT, c3 INT,
+		c4 VARCHAR(10), c5 VARCHAR(100), c6 VARCHAR(1000), PRIMARY KEY (c))`)
+
+	changes := make([]*RowChange, 0, batch)
+	for i := 0; i < batch; i++ {
+		change := NewRowChange(source, target,
+			[]interface{}{i + 1, i + 2, i + 3, "c4", "c5", "c6"},
+			[]interface{}{i + 10, i + 20, i + 30, "c4", "c5", "c6"},
+			sourceTI, targetTI, nil)
+		changes = append(changes, change)
+	}
+	return changes
+}
+
+func prepareDataMultiColumnsPK(t *testing.T, batch int) []*RowChange {
+	source := &cdcmodel.TableName{Schema: "db", Table: "tb"}
+	target := &cdcmodel.TableName{Schema: "db", Table: "tb"}
+
+	sourceTI := mockTableInfo(t, `CREATE TABLE tb (c1 INT, c2 INT, c3 INT, c4 INT,
+		c5 VARCHAR(10), c6 VARCHAR(100), c7 VARCHAR(1000), c8 timestamp, c9 timestamp,
+		PRIMARY KEY (c1, c2, c3, c4))`)
+	targetTI := mockTableInfo(t, `CREATE TABLE tb (c1 INT, c2 INT, c3 INT, c4 INT,
+		c5 VARCHAR(10), c6 VARCHAR(100), c7 VARCHAR(1000), c8 timestamp, c9 timestamp,
+		PRIMARY KEY (c1, c2, c3, c4))`)
+
+	changes := make([]*RowChange, 0, batch)
+	for i := 0; i < batch; i++ {
+		change := NewRowChange(source, target,
+			[]interface{}{i + 1, i + 2, i + 3, i + 4, "c4", "c5", "c6", "c7", time.Time{}, time.Time{}},
+			[]interface{}{i + 10, i + 20, i + 30, i + 40, "c4", "c5", "c6", "c7", time.Time{}, time.Time{}},
+			sourceTI, targetTI, nil)
+		changes = append(changes, change)
+	}
+	return changes
+}
+
+// bench cmd: go test -run='^$' -benchmem -bench '^(BenchmarkGenUpdate)$' github.com/pingcap/tiflow/pkg/sqlmodel
+func BenchmarkGenUpdate(b *testing.B) {
+	t := &testing.T{}
+	type genCase struct {
+		name    string
+		fn      genSQLFunc
+		prepare func(t *testing.T, batch int) []*RowChange
+	}
+	batchSizes := []int{
+		1, 2, 4, 8, 16, 32, 64, 128,
+	}
+	benchCases := []genCase{
+		{
+			name:    "OneColumnPK-GenUpdateSQL",
+			fn:      GenUpdateSQL,
+			prepare: prepareDataOneColoumnPK,
+		},
+		{
+			name:    "OneColumnPK-GenUpdateSQLFast",
+			fn:      GenUpdateSQLFast,
+			prepare: prepareDataOneColoumnPK,
+		},
+		{
+			name:    "MultiColumnsPK-GenUpdateSQL",
+			fn:      GenUpdateSQL,
+			prepare: prepareDataMultiColumnsPK,
+		},
+		{
+			name:    "MultiColumnsPK-GenUpdateSQLFast",
+			fn:      GenUpdateSQLFast,
+			prepare: prepareDataMultiColumnsPK,
+		},
+	}
+	for _, bc := range benchCases {
+		for _, batch := range batchSizes {
+			name := fmt.Sprintf("%s-Batch%d", bc.name, batch)
+			b.Run(name, func(b *testing.B) {
+				changes := bc.prepare(t, batch)
+				for i := 0; i < b.N; i++ {
+					bc.fn(changes...)
+				}
+			})
+		}
+	}
+}
diff --git a/pkg/sqlmodel/multirow_test.go b/pkg/sqlmodel/multirow_test.go
new file mode 100644
index 00000000000..3fd84ec9a2e
--- /dev/null
+++ b/pkg/sqlmodel/multirow_test.go
@@ -0,0 +1,240 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlmodel
+
+import (
+	"testing"
+
+	cdcmodel "github.com/pingcap/tiflow/cdc/model"
+	"github.com/stretchr/testify/require"
+)
+
+type genSQLFunc func(changes ...*RowChange) (string, []interface{})
+
+func TestGenDeleteMultiRows(t *testing.T) {
+	t.Parallel()
+
+	source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"}
+	source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"}
+	target := &cdcmodel.TableName{Schema: "db", Table: "tb"}
+
+	sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)")
+	sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT)")
+	targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT PRIMARY KEY, c2 INT)")
+
+	change1 := NewRowChange(source1, target, []interface{}{1, 2}, nil, sourceTI1, targetTI, nil)
+	change2 := NewRowChange(source2, target, []interface{}{3, 4}, nil, sourceTI2, targetTI, nil)
+	sql, args := GenDeleteSQL(change1, change2)
+
+	require.Equal(t, "DELETE FROM `db`.`tb` WHERE (`c`) IN ((?),(?))", sql)
+	require.Equal(t, []interface{}{1, 3}, args)
+}
+
+func TestGenUpdateMultiRows(t *testing.T) {
+	t.Parallel()
+	testGenUpdateMultiRows(t, GenUpdateSQL)
+	testGenUpdateMultiRows(t, GenUpdateSQLFast)
+}
+
+func TestGenUpdateMultiRowsOneColPK(t *testing.T) {
+	t.Parallel()
+	testGenUpdateMultiRowsOneColPK(t, GenUpdateSQL)
+	testGenUpdateMultiRowsOneColPK(t, GenUpdateSQLFast)
+}
+
+func TestGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T) {
+	t.Parallel()
+	testGenUpdateMultiRowsWithVirtualGeneratedColumn(t, GenUpdateSQL)
+	testGenUpdateMultiRowsWithVirtualGeneratedColumn(t, GenUpdateSQLFast)
+	testGenUpdateMultiRowsWithVirtualGeneratedColumns(t, GenUpdateSQL)
+	testGenUpdateMultiRowsWithVirtualGeneratedColumns(t, GenUpdateSQLFast)
+}
+
+func TestGenUpdateMultiRowsWithStoredGeneratedColumn(t *testing.T) {
+	t.Parallel()
+	testGenUpdateMultiRowsWithStoredGeneratedColumn(t, GenUpdateSQL)
+	testGenUpdateMultiRowsWithStoredGeneratedColumn(t, GenUpdateSQLFast)
+}
+
+func testGenUpdateMultiRows(t *testing.T, genUpdate genSQLFunc) {
+	source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"}
+	source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"}
+	target := &cdcmodel.TableName{Schema: "db", Table: "tb"}
+
+	sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c2 INT, c3 INT, PRIMARY KEY (c, c2))")
+	sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT, c2 INT, c3 INT, PRIMARY KEY (c, c2))")
+	targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT, c2 INT, c3 INT, PRIMARY KEY (c, c2))")
+
+	change1 := NewRowChange(source1, target, []interface{}{1, 2, 3}, []interface{}{10, 20, 30}, sourceTI1, targetTI, nil)
+	change2 := NewRowChange(source2, target, []interface{}{4, 5, 6}, []interface{}{40, 50, 60}, sourceTI2, targetTI, nil)
+	sql, args := genUpdate(change1, change2)
+
+	expectedSQL := "UPDATE `db`.`tb` SET " +
+		"`c`=CASE WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? END, " +
+		"`c2`=CASE WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? END, " +
+		"`c3`=CASE WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? END " +
END " + + "WHERE ROW(`c`,`c2`) IN (ROW(?,?),ROW(?,?))" + expectedArgs := []interface{}{ + 1, 2, 10, 4, 5, 40, + 1, 2, 20, 4, 5, 50, + 1, 2, 30, 4, 5, 60, + 1, 2, 4, 5, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func testGenUpdateMultiRowsOneColPK(t *testing.T, genUpdate genSQLFunc) { + source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c2 INT, c3 INT, PRIMARY KEY (c))") + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT, c2 INT, c3 INT, PRIMARY KEY (c))") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT, c2 INT, c3 INT, PRIMARY KEY (c))") + + change1 := NewRowChange(source1, target, []interface{}{1, 2, 3}, []interface{}{10, 20, 30}, sourceTI1, targetTI, nil) + change2 := NewRowChange(source2, target, []interface{}{4, 5, 6}, []interface{}{40, 50, 60}, sourceTI2, targetTI, nil) + sql, args := genUpdate(change1, change2) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + + "`c2`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + + "`c3`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? END " + + "WHERE `c` IN (?,?)" + expectedArgs := []interface{}{ + 1, 10, 4, 40, + 1, 20, 4, 50, + 1, 30, 4, 60, + 1, 4, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func testGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T, genUpdate genSQLFunc) { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c1 int as (c+100) virtual not null, c2 INT, c3 INT, PRIMARY KEY (c))") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT, c1 int as (c+100) virtual not null, c2 INT, c3 INT, PRIMARY KEY (c))") + + change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3}, []interface{}{10, 110, 20, 30}, sourceTI, targetTI, nil) + change2 := NewRowChange(source, target, []interface{}{4, 104, 5, 6}, []interface{}{40, 140, 50, 60}, sourceTI, targetTI, nil) + change3 := NewRowChange(source, target, []interface{}{7, 107, 8, 9}, []interface{}{70, 170, 80, 90}, sourceTI, targetTI, nil) + sql, args := genUpdate(change1, change2, change3) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + + "`c2`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + + "`c3`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? WHEN `c`=? THEN ? 
END " + + "WHERE `c` IN (?,?,?)" + expectedArgs := []interface{}{ + 1, 10, 4, 40, 7, 70, + 1, 20, 4, 50, 7, 80, + 1, 30, 4, 60, 7, 90, + 1, 4, 7, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +// multiple generated columns test case +func testGenUpdateMultiRowsWithVirtualGeneratedColumns(t *testing.T, genUpdate genSQLFunc) { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, `CREATE TABLE tb1 (c0 int as (c4*c4) virtual not null, + c1 int as (c+100) virtual not null, c2 INT, c3 INT, c4 INT, PRIMARY KEY (c4))`) + targetTI := mockTableInfo(t, `CREATE TABLE tb (c0 int as (c4*c4) virtual not null, + c1 int as (c+100) virtual not null, c2 INT, c3 INT, c4 INT, PRIMARY KEY (c4))`) + + change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3, 1}, []interface{}{100, 110, 20, 30, 10}, sourceTI, targetTI, nil) + change2 := NewRowChange(source, target, []interface{}{16, 104, 5, 6, 4}, []interface{}{1600, 140, 50, 60, 40}, sourceTI, targetTI, nil) + change3 := NewRowChange(source, target, []interface{}{49, 107, 8, 9, 7}, []interface{}{4900, 170, 80, 90, 70}, sourceTI, targetTI, nil) + sql, args := genUpdate(change1, change2, change3) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c2`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? END, " + + "`c3`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? END, " + + "`c4`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? END " + + "WHERE `c4` IN (?,?,?)" + expectedArgs := []interface{}{ + 1, 20, 4, 50, 7, 80, + 1, 30, 4, 60, 7, 90, + 1, 10, 4, 40, 7, 70, + 1, 4, 7, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func testGenUpdateMultiRowsWithStoredGeneratedColumn(t *testing.T, genUpdate genSQLFunc) { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c1 int as (c+100) stored, c2 INT, c3 INT, PRIMARY KEY (c1))") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT, c1 int as (c+100) stored, c2 INT, c3 INT, PRIMARY KEY (c1))") + + change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3}, []interface{}{10, 110, 20, 30}, sourceTI, targetTI, nil) + change2 := NewRowChange(source, target, []interface{}{4, 104, 5, 6}, []interface{}{40, 140, 50, 60}, sourceTI, targetTI, nil) + change3 := NewRowChange(source, target, []interface{}{7, 107, 8, 9}, []interface{}{70, 170, 80, 90}, sourceTI, targetTI, nil) + sql, args := genUpdate(change1, change2, change3) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c`=CASE WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? END, " + + "`c2`=CASE WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? END, " + + "`c3`=CASE WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? 
END " + + "WHERE `c1` IN (?,?,?)" + expectedArgs := []interface{}{ + 101, 10, 104, 40, 107, 70, + 101, 20, 104, 50, 107, 80, + 101, 30, 104, 60, 107, 90, + 101, 104, 107, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func TestGenInsertMultiRows(t *testing.T) { + t.Parallel() + + source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + targetTI := mockTableInfo(t, "CREATE TABLE tb (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + + change1 := NewRowChange(source1, target, nil, []interface{}{2, 1, 2}, sourceTI1, targetTI, nil) + change2 := NewRowChange(source2, target, nil, []interface{}{4, 3, 4}, sourceTI2, targetTI, nil) + + sql, args := GenInsertSQL(DMLInsert, change1, change2) + require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) + + sql, args = GenInsertSQL(DMLReplace, change1, change2) + require.Equal(t, "REPLACE INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) + + sql, args = GenInsertSQL(DMLInsertOnDuplicateUpdate, change1, change2) + require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?) ON DUPLICATE KEY UPDATE `c`=VALUES(`c`),`c2`=VALUES(`c2`)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) +} diff --git a/pkg/sqlmodel/multivalue_test.go b/pkg/sqlmodel/multivalue_test.go deleted file mode 100644 index a06326d5ee9..00000000000 --- a/pkg/sqlmodel/multivalue_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-package sqlmodel
-
-import (
-	"testing"
-
-	"github.com/stretchr/testify/require"
-
-	cdcmodel "github.com/pingcap/tiflow/cdc/model"
-)
-
-func TestGenDeleteMultiValue(t *testing.T) {
-	t.Parallel()
-
-	source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"}
-	source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"}
-	target := &cdcmodel.TableName{Schema: "db", Table: "tb"}
-
-	sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)")
-	sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT)")
-	targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT PRIMARY KEY, c2 INT)")
-
-	change1 := NewRowChange(source1, target, []interface{}{1, 2}, nil, sourceTI1, targetTI, nil)
-	change2 := NewRowChange(source2, target, []interface{}{3, 4}, nil, sourceTI2, targetTI, nil)
-	sql, args := GenDeleteSQL(change1, change2)
-
-	require.Equal(t, "DELETE FROM `db`.`tb` WHERE (`c`) IN ((?),(?))", sql)
-	require.Equal(t, []interface{}{1, 3}, args)
-}
-
-func TestGenInsertMultiValue(t *testing.T) {
-	t.Parallel()
-
-	source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"}
-	source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"}
-	target := &cdcmodel.TableName{Schema: "db", Table: "tb"}
-
-	sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)")
-	sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)")
-	targetTI := mockTableInfo(t, "CREATE TABLE tb (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)")
-
-	change1 := NewRowChange(source1, target, nil, []interface{}{2, 1, 2}, sourceTI1, targetTI, nil)
-	change2 := NewRowChange(source2, target, nil, []interface{}{4, 3, 4}, sourceTI2, targetTI, nil)
-
-	sql, args := GenInsertSQL(DMLInsert, change1, change2)
-	require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql)
-	require.Equal(t, []interface{}{1, 2, 3, 4}, args)
-
-	sql, args = GenInsertSQL(DMLReplace, change1, change2)
-	require.Equal(t, "REPLACE INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql)
-	require.Equal(t, []interface{}{1, 2, 3, 4}, args)
-
-	sql, args = GenInsertSQL(DMLInsertOnDuplicateUpdate, change1, change2)
-	require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?) ON DUPLICATE KEY UPDATE `c`=VALUES(`c`),`c2`=VALUES(`c2`)", sql)
-	require.Equal(t, []interface{}{1, 2, 3, 4}, args)
-}
diff --git a/pkg/sqlmodel/row_change.go b/pkg/sqlmodel/row_change.go
index 20d5a8767a1..51e3481c54c 100644
--- a/pkg/sqlmodel/row_change.go
+++ b/pkg/sqlmodel/row_change.go
@@ -70,6 +70,8 @@ type RowChange struct {
 
 	tp          RowChangeType
 	whereHandle *WhereHandle
+
+	approximateDataSize int64
 }
 
 // NewRowChange creates a new RowChange.
@@ -196,6 +198,17 @@ func (r *RowChange) SetWhereHandle(whereHandle *WhereHandle) {
 	r.whereHandle = whereHandle
 }
 
+// GetApproximateDataSize returns the internal approximateDataSize; it could be
+// zero if the value has not been set.
+func (r *RowChange) GetApproximateDataSize() int64 {
+	return r.approximateDataSize
+}
+
+// SetApproximateDataSize sets the approximate size of the row change.
+func (r *RowChange) SetApproximateDataSize(approximateDataSize int64) {
+	r.approximateDataSize = approximateDataSize
+}
+
 func (r *RowChange) lazyInitWhereHandle() {
 	if r.whereHandle != nil {
 		return