Skip to content

Commit

Permalink
Use sqlparser for all dynamic query building in VDiff2 (#13319)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattlord authored Jun 18, 2023
1 parent fd6ace9 commit 52dcb8f
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 61 deletions.
67 changes: 55 additions & 12 deletions go/vt/vttablet/tabletmanager/vdiff/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/google/uuid"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/vterrors"

Expand Down Expand Up @@ -97,7 +98,10 @@ func (vde *Engine) getVDiffSummary(vdiffID int64, dbClient binlogplayer.DBClient
var qr *sqltypes.Result
var err error

query := fmt.Sprintf(sqlVDiffSummary, vdiffID)
query, err := sqlparser.ParseAndBind(sqlVDiffSummary, sqltypes.Int64BindVariable(vdiffID))
if err != nil {
return nil, err
}
if qr, err = dbClient.ExecuteFetch(query, -1); err != nil {
return nil, err
}
Expand Down Expand Up @@ -144,10 +148,12 @@ func (vde *Engine) getDefaultCell() (string, error) {

func (vde *Engine) handleCreateResumeAction(ctx context.Context, dbClient binlogplayer.DBClient, action VDiffAction, req *tabletmanagerdatapb.VDiffRequest, resp *tabletmanagerdatapb.VDiffResponse) error {
var qr *sqltypes.Result
var err error
options := req.Options

query := fmt.Sprintf(sqlGetVDiffID, encodeString(req.VdiffUuid))
query, err := sqlparser.ParseAndBind(sqlGetVDiffID, sqltypes.StringBindVariable(req.VdiffUuid))
if err != nil {
return err
}
if qr, err = dbClient.ExecuteFetch(query, 1); err != nil {
return err
}
Expand All @@ -173,9 +179,18 @@ func (vde *Engine) handleCreateResumeAction(ctx context.Context, dbClient binlog
return err
}
if action == CreateAction {
query := fmt.Sprintf(sqlNewVDiff,
encodeString(req.Keyspace), encodeString(req.Workflow), "pending", encodeString(string(optionsJSON)),
vde.thisTablet.Shard, topoproto.TabletDbName(vde.thisTablet), req.VdiffUuid)
query, err := sqlparser.ParseAndBind(sqlNewVDiff,
sqltypes.StringBindVariable(req.Keyspace),
sqltypes.StringBindVariable(req.Workflow),
sqltypes.StringBindVariable("pending"),
sqltypes.StringBindVariable(string(optionsJSON)),
sqltypes.StringBindVariable(vde.thisTablet.Shard),
sqltypes.StringBindVariable(topoproto.TabletDbName(vde.thisTablet)),
sqltypes.StringBindVariable(req.VdiffUuid),
)
if err != nil {
return err
}
if qr, err = dbClient.ExecuteFetch(query, 1); err != nil {
return err
}
Expand All @@ -185,7 +200,13 @@ func (vde *Engine) handleCreateResumeAction(ctx context.Context, dbClient binlog
}
resp.Id = int64(qr.InsertID)
} else {
query := fmt.Sprintf(sqlResumeVDiff, encodeString(string(optionsJSON)), encodeString(req.VdiffUuid))
query, err := sqlparser.ParseAndBind(sqlResumeVDiff,
sqltypes.StringBindVariable(string(optionsJSON)),
sqltypes.StringBindVariable(req.VdiffUuid),
)
if err != nil {
return err
}
if qr, err = dbClient.ExecuteFetch(query, 1); err != nil {
return err
}
Expand Down Expand Up @@ -219,7 +240,13 @@ func (vde *Engine) handleShowAction(ctx context.Context, dbClient binlogplayer.D
vdiffUUID := ""

if req.ActionArg == LastActionArg {
query := fmt.Sprintf(sqlGetMostRecentVDiff, encodeString(req.Keyspace), encodeString(req.Workflow))
query, err := sqlparser.ParseAndBind(sqlGetMostRecentVDiff,
sqltypes.StringBindVariable(req.Keyspace),
sqltypes.StringBindVariable(req.Workflow),
)
if err != nil {
return err
}
if qr, err = dbClient.ExecuteFetch(query, 1); err != nil {
return err
}
Expand All @@ -234,7 +261,14 @@ func (vde *Engine) handleShowAction(ctx context.Context, dbClient binlogplayer.D
}
if vdiffUUID != "" {
resp.VdiffUuid = vdiffUUID
query := fmt.Sprintf(sqlGetVDiffByKeyspaceWorkflowUUID, encodeString(req.Keyspace), encodeString(req.Workflow), encodeString(vdiffUUID))
query, err := sqlparser.ParseAndBind(sqlGetVDiffByKeyspaceWorkflowUUID,
sqltypes.StringBindVariable(req.Keyspace),
sqltypes.StringBindVariable(req.Workflow),
sqltypes.StringBindVariable(vdiffUUID),
)
if err != nil {
return err
}
if qr, err = dbClient.ExecuteFetch(query, 1); err != nil {
return err
}
Expand Down Expand Up @@ -278,7 +312,7 @@ func (vde *Engine) handleStopAction(ctx context.Context, dbClient binlogplayer.D
if controller.uuid == req.VdiffUuid {
controller.Stop()
if err := controller.markStoppedByRequest(); err != nil {
return err
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "encountered an error marking vdiff %s as stopped: %v", controller.uuid, err)
}
break
}
Expand All @@ -292,13 +326,22 @@ func (vde *Engine) handleDeleteAction(ctx context.Context, dbClient binlogplayer

switch req.ActionArg {
case AllActionArg:
query = fmt.Sprintf(sqlDeleteVDiffs, encodeString(req.Keyspace), encodeString(req.Workflow))
query, err = sqlparser.ParseAndBind(sqlDeleteVDiffs,
sqltypes.StringBindVariable(req.Keyspace),
sqltypes.StringBindVariable(req.Workflow),
)
if err != nil {
return err
}
default:
uuid, err := uuid.Parse(req.ActionArg)
if err != nil {
return fmt.Errorf("action argument %s not supported", req.ActionArg)
}
query = fmt.Sprintf(sqlDeleteVDiffByUUID, encodeString(uuid.String()))
query, err = sqlparser.ParseAndBind(sqlDeleteVDiffByUUID, sqltypes.StringBindVariable(uuid.String()))
if err != nil {
return err
}
}
if _, err = dbClient.ExecuteFetch(query, 1); err != nil {
return err
Expand Down
27 changes: 18 additions & 9 deletions go/vt/vttablet/tabletmanager/vdiff/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"time"

"vitess.io/vitess/go/vt/proto/tabletmanagerdata"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"

"google.golang.org/protobuf/encoding/prototext"
Expand Down Expand Up @@ -165,8 +166,13 @@ func (ct *controller) updateState(dbClient binlogplayer.DBClient, state VDiffSta
// Clear out any previous error for the vdiff on this shard
err = errors.New("")
}
query := fmt.Sprintf(sqlUpdateVDiffState, encodeString(string(state)), encodeString(err.Error()), extraCols, ct.id)
if _, err := dbClient.ExecuteFetch(query, 1); err != nil {
query := sqlparser.BuildParsedQuery(sqlUpdateVDiffState,
encodeString(string(state)),
encodeString(err.Error()),
extraCols,
ct.id,
)
if _, err := dbClient.ExecuteFetch(query.Query, 1); err != nil {
return err
}
insertVDiffLog(ct.vde.ctx, dbClient, ct.id, fmt.Sprintf("State changed to: %s", state))
Expand All @@ -179,9 +185,10 @@ func (ct *controller) start(ctx context.Context, dbClient binlogplayer.DBClient)
return vterrors.Errorf(vtrpcpb.Code_CANCELED, "context has expired")
default:
}
ct.workflowFilter = fmt.Sprintf("where workflow = %s and db_name = %s", encodeString(ct.workflow), encodeString(ct.vde.dbName))
query := fmt.Sprintf(sqlGetVReplicationEntry, ct.workflowFilter)
qr, err := dbClient.ExecuteFetch(query, -1)
ct.workflowFilter = fmt.Sprintf("where workflow = %s and db_name = %s", encodeString(ct.workflow),
encodeString(ct.vde.dbName))
query := sqlparser.BuildParsedQuery(sqlGetVReplicationEntry, ct.workflowFilter)
qr, err := dbClient.ExecuteFetch(query.Query, -1)
if err != nil {
return err
}
Expand Down Expand Up @@ -248,15 +255,17 @@ func (ct *controller) start(ctx context.Context, dbClient binlogplayer.DBClient)
func (ct *controller) markStoppedByRequest() error {
dbClient := ct.vde.dbClientFactoryFiltered()
if err := dbClient.Connect(); err != nil {
return fmt.Errorf("encountered an error marking vdiff %s as stopped: %v", ct.uuid, err)
return err
}
defer dbClient.Close()

query := fmt.Sprintf(sqlUpdateVDiffStopped, ct.id)
query, err := sqlparser.ParseAndBind(sqlUpdateVDiffStopped, sqltypes.Int64BindVariable(ct.id))
if err != nil {
return err
}
var res *sqltypes.Result
var err error
if res, err = dbClient.ExecuteFetch(query, 1); err != nil {
return fmt.Errorf("encountered an error marking vdiff %s as stopped: %v", ct.uuid, err)
return err
}
// We don't mark it as stopped if it's already completed
if res.RowsAffected > 0 {
Expand Down
13 changes: 11 additions & 2 deletions go/vt/vttablet/tabletmanager/vdiff/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/proto/tabletmanagerdata"
"vitess.io/vitess/go/vt/proto/topodata"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vttablet/tabletmanager/vreplication"
"vitess.io/vitess/go/vt/vttablet/tmclient"

Expand Down Expand Up @@ -297,7 +298,11 @@ func (vde *Engine) getVDiffsToRetry(ctx context.Context, dbClient binlogplayer.D
}

func (vde *Engine) getVDiffByID(ctx context.Context, dbClient binlogplayer.DBClient, id int64) (*sqltypes.Result, error) {
qr, err := dbClient.ExecuteFetch(fmt.Sprintf(sqlGetVDiffByID, id), -1)
query, err := sqlparser.ParseAndBind(sqlGetVDiffByID, sqltypes.Int64BindVariable(id))
if err != nil {
return nil, err
}
qr, err := dbClient.ExecuteFetch(query, -1)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -340,7 +345,11 @@ func (vde *Engine) retryVDiffs(ctx context.Context) error {
return err
}
log.Infof("Retrying vdiff %s that had an ephemeral error of '%v'", uuid, lastError)
if _, err = dbClient.ExecuteFetch(fmt.Sprintf(sqlRetryVDiff, id), 1); err != nil {
query, err := sqlparser.ParseAndBind(sqlRetryVDiff, sqltypes.Int64BindVariable(id))
if err != nil {
return err
}
if _, err = dbClient.ExecuteFetch(query, 1); err != nil {
return err
}
options := &tabletmanagerdata.VDiffOptions{}
Expand Down
40 changes: 20 additions & 20 deletions go/vt/vttablet/tabletmanager/vdiff/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,48 @@ package vdiff

const (
sqlAnalyzeTable = "analyze table `%s`.`%s`"
sqlNewVDiff = "insert into _vt.vdiff(keyspace, workflow, state, options, shard, db_name, vdiff_uuid) values(%s, %s, '%s', %s, '%s', '%s', '%s')"
sqlResumeVDiff = `update _vt.vdiff as vd, _vt.vdiff_table as vdt set vd.options = %s, vd.started_at = NULL, vd.completed_at = NULL, vd.state = 'pending',
vdt.state = 'pending' where vd.vdiff_uuid = %s and vd.id = vdt.vdiff_id and vd.state in ('completed', 'stopped')
sqlNewVDiff = "insert into _vt.vdiff(keyspace, workflow, state, options, shard, db_name, vdiff_uuid) values(%a, %a, %a, %a, %a, %a, %a)"
sqlResumeVDiff = `update _vt.vdiff as vd, _vt.vdiff_table as vdt set vd.options = %a, vd.started_at = NULL, vd.completed_at = NULL, vd.state = 'pending',
vdt.state = 'pending' where vd.vdiff_uuid = %a and vd.id = vdt.vdiff_id and vd.state in ('completed', 'stopped')
and vdt.state in ('completed', 'stopped')`
sqlRetryVDiff = `update _vt.vdiff as vd left join _vt.vdiff_table as vdt on (vd.id = vdt.vdiff_id) set vd.state = 'pending',
vd.last_error = '', vdt.state = 'pending' where vd.id = %d and (vd.state = 'error' or vdt.state = 'error')`
sqlGetVDiffByKeyspaceWorkflowUUID = "select * from _vt.vdiff where keyspace = %s and workflow = %s and vdiff_uuid = %s"
sqlGetMostRecentVDiff = "select * from _vt.vdiff where keyspace = %s and workflow = %s order by id desc limit 1"
sqlGetVDiffByID = "select * from _vt.vdiff where id = %d"
vd.last_error = '', vdt.state = 'pending' where vd.id = %a and (vd.state = 'error' or vdt.state = 'error')`
sqlGetVDiffByKeyspaceWorkflowUUID = "select * from _vt.vdiff where keyspace = %a and workflow = %a and vdiff_uuid = %a"
sqlGetMostRecentVDiff = "select * from _vt.vdiff where keyspace = %a and workflow = %a order by id desc limit 1"
sqlGetVDiffByID = "select * from _vt.vdiff where id = %a"
sqlDeleteVDiffs = `delete from vd, vdt, vdl using _vt.vdiff as vd left join _vt.vdiff_table as vdt on (vd.id = vdt.vdiff_id)
left join _vt.vdiff_log as vdl on (vd.id = vdl.vdiff_id)
where vd.keyspace = %s and vd.workflow = %s`
where vd.keyspace = %a and vd.workflow = %a`
sqlDeleteVDiffByUUID = `delete from vd, vdt using _vt.vdiff as vd left join _vt.vdiff_table as vdt on (vd.id = vdt.vdiff_id)
where vd.vdiff_uuid = %s`
where vd.vdiff_uuid = %a`
sqlVDiffSummary = `select vd.state as vdiff_state, vd.last_error as last_error, vdt.table_name as table_name,
vd.vdiff_uuid as 'uuid', vdt.state as table_state, vdt.table_rows as table_rows,
vd.started_at as started_at, vdt.rows_compared as rows_compared, vd.completed_at as completed_at,
IF(vdt.mismatch = 1, 1, 0) as has_mismatch, vdt.report as report
from _vt.vdiff as vd left join _vt.vdiff_table as vdt on (vd.id = vdt.vdiff_id)
where vd.id = %d`
where vd.id = %a`
// sqlUpdateVDiffState has a penultimate placeholder for any additional columns you want to update, e.g. `, foo = 1`
sqlUpdateVDiffState = "update _vt.vdiff set state = %s, last_error = %s %s where id = %d"
sqlUpdateVDiffStopped = `update _vt.vdiff as vd, _vt.vdiff_table as vdt set vd.state = 'stopped', vdt.state = 'stopped', vd.last_error = ''
where vd.id = vdt.vdiff_id and vd.id = %d and vd.state != 'completed'`
where vd.id = vdt.vdiff_id and vd.id = %a and vd.state != 'completed'`
sqlGetVReplicationEntry = "select * from _vt.vreplication %s"
sqlGetVDiffsToRun = "select * from _vt.vdiff where state in ('started','pending')" // what VDiffs have not been stopped or completed
sqlGetVDiffsToRetry = "select * from _vt.vdiff where state = 'error' and json_unquote(json_extract(options, '$.core_options.auto_retry')) = 'true'"
sqlGetVDiffID = "select id as id from _vt.vdiff where vdiff_uuid = %s"
sqlGetVDiffID = "select id as id from _vt.vdiff where vdiff_uuid = %a"
sqlGetAllVDiffs = "select * from _vt.vdiff order by id desc"
sqlGetTableRows = "select table_rows as table_rows from INFORMATION_SCHEMA.TABLES where table_schema = %a and table_name = %a"
sqlGetAllTableRows = "select table_name as table_name, table_rows as table_rows from INFORMATION_SCHEMA.TABLES where table_schema = %s and table_name in (%s)"

sqlNewVDiffTable = "insert into _vt.vdiff_table(vdiff_id, table_name, state, table_rows) values(%d, %s, 'pending', %d)"
sqlNewVDiffTable = "insert into _vt.vdiff_table(vdiff_id, table_name, state, table_rows) values(%a, %a, 'pending', %a)"
sqlGetVDiffTable = `select vdt.lastpk as lastpk, vdt.mismatch as mismatch, vdt.report as report
from _vt.vdiff as vd inner join _vt.vdiff_table as vdt on (vd.id = vdt.vdiff_id)
where vdt.vdiff_id = %d and vdt.table_name = %s`
where vdt.vdiff_id = %a and vdt.table_name = %a`
sqlUpdateTableRows = "update _vt.vdiff_table set table_rows = %a where vdiff_id = %a and table_name = %a"
sqlUpdateTableProgress = "update _vt.vdiff_table set rows_compared = %d, lastpk = %s, report = %s where vdiff_id = %d and table_name = %s"
sqlUpdateTableNoProgress = "update _vt.vdiff_table set rows_compared = %d, report = %s where vdiff_id = %d and table_name = %s"
sqlUpdateTableState = "update _vt.vdiff_table set state = %s where vdiff_id = %d and table_name = %s"
sqlUpdateTableStateAndReport = "update _vt.vdiff_table set state = %s, rows_compared = %d, report = %s where vdiff_id = %d and table_name = %s"
sqlUpdateTableMismatch = "update _vt.vdiff_table set mismatch = true where vdiff_id = %d and table_name = %s"
sqlUpdateTableProgress = "update _vt.vdiff_table set rows_compared = %a, lastpk = %a, report = %a where vdiff_id = %a and table_name = %a"
sqlUpdateTableNoProgress = "update _vt.vdiff_table set rows_compared = %a, report = %a where vdiff_id = %a and table_name = %a"
sqlUpdateTableState = "update _vt.vdiff_table set state = %a where vdiff_id = %a and table_name = %a"
sqlUpdateTableStateAndReport = "update _vt.vdiff_table set state = %a, rows_compared = %a, report = %a where vdiff_id = %a and table_name = %a"
sqlUpdateTableMismatch = "update _vt.vdiff_table set mismatch = true where vdiff_id = %a and table_name = %a"

sqlGetIncompleteTables = "select table_name as table_name from _vt.vdiff_table where vdiff_id = %d and state != 'completed'"
sqlGetIncompleteTables = "select table_name as table_name from _vt.vdiff_table where vdiff_id = %a and state != 'completed'"
)
Loading

0 comments on commit 52dcb8f

Please sign in to comment.