Skip to content

Commit

Permalink
executor: optimize load data assignment expression (#46563)
Browse files Browse the repository at this point in the history
close #46081
  • Loading branch information
lance6716 authored Sep 4, 2023
1 parent 4cb399e commit 6f1bdf7
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 8 deletions.
1 change: 1 addition & 0 deletions executor/importer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ go_library(
"//parser/terror",
"//planner/core",
"//sessionctx",
"//sessionctx/stmtctx",
"//sessionctx/variable",
"//table",
"//table/tables",
Expand Down
22 changes: 22 additions & 0 deletions executor/importer/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ import (
"github.com/pingcap/tidb/br/pkg/lightning/mydump"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/executor/asyncloaddata"
"github.com/pingcap/tidb/expression"
tidbkv "github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -936,6 +938,26 @@ func (e *LoadDataController) toMyDumpFiles() []mydump.FileInfo {
return res
}

// CreateColAssignExprs creates the column assignment expressions using session context.
// RewriteAstExpr will write ast node in place(due to xxNode.Accept), but it doesn't change node content,
// so we sync it.
func (e *LoadDataController) CreateColAssignExprs(sctx sessionctx.Context) ([]expression.Expression, []stmtctx.SQLWarn, error) {
res := make([]expression.Expression, 0, len(e.ColumnAssignments))
allWarnings := []stmtctx.SQLWarn{}
for _, assign := range e.ColumnAssignments {
newExpr, err := expression.RewriteAstExpr(sctx, assign.Expr, nil, nil, false)
// col assign expr warnings is static, we should generate it for each row processed.
// so we save it and clear it here.
allWarnings = append(allWarnings, sctx.GetSessionVars().StmtCtx.GetWarnings()...)
sctx.GetSessionVars().StmtCtx.SetWarnings(nil)
if err != nil {
return nil, nil, err
}
res = append(res, newExpr)
}
return res, allWarnings, nil
}

// JobImportParam is the param of the job import.
type JobImportParam struct {
Job *asyncloaddata.Job
Expand Down
30 changes: 22 additions & 8 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,16 @@ func (ji *logicalJobImporter) initEncodeCommitWorkers(e *LoadDataWorker) (err er
return err2
}
createdSessions = append(createdSessions, commitCore.ctx)
colAssignExprs, exprWarnings, err2 := e.controller.CreateColAssignExprs(encodeCore.ctx)
if err2 != nil {
return err2
}
encode := &encodeWorker{
InsertValues: encodeCore,
controller: e.controller,
killed: &e.UserSctx.GetSessionVars().Killed,
InsertValues: encodeCore,
controller: e.controller,
colAssignExprs: colAssignExprs,
exprWarnings: exprWarnings,
killed: &e.UserSctx.GetSessionVars().Killed,
}
encode.resetBatch()
encodeWorkers = append(encodeWorkers, encode)
Expand Down Expand Up @@ -627,9 +633,13 @@ func (ji *logicalJobImporter) Close() error {
// encodeWorker is a sub-worker of LoadDataWorker that dedicated to encode data.
type encodeWorker struct {
*InsertValues
controller *importer.LoadDataController
killed *uint32
rows [][]types.Datum
controller *importer.LoadDataController
colAssignExprs []expression.Expression
// sessionCtx generate warnings when rewrite AST node into expression.
// we should generate such warnings for each row encoded.
exprWarnings []stmtctx.SQLWarn
killed *uint32
rows [][]types.Datum
}

// processStream always trys to build a parser from channel and process it. When
Expand Down Expand Up @@ -818,9 +828,9 @@ func (w *encodeWorker) parserData2TableData(

row = append(row, parserData[i])
}
for i := 0; i < len(w.controller.ColumnAssignments); i++ {
for i := 0; i < len(w.colAssignExprs); i++ {
// eval expression of `SET` clause
d, err := expression.EvalAstExpr(w.ctx, w.controller.ColumnAssignments[i].Expr)
d, err := w.colAssignExprs[i].Eval(chunk.Row{})
if err != nil {
if w.controller.Restrictive {
return nil, err
Expand All @@ -830,6 +840,10 @@ func (w *encodeWorker) parserData2TableData(
row = append(row, d)
}

if len(w.exprWarnings) > 0 {
w.ctx.GetSessionVars().StmtCtx.AppendWarnings(w.exprWarnings)
}

// a new row buffer will be allocated in getRow
newRow, err := w.getRow(ctx, row)
if err != nil {
Expand Down
67 changes: 67 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,73 @@ func (cli *testServerClient) runTestLoadData(t *testing.T, server *Server) {
require.NoError(t, rows.Close())
dbt.MustExec("drop table if exists pn")
})

err = fp.Close()
require.NoError(t, err)
err = os.Remove(path)
require.NoError(t, err)

fp, err = os.Create(path)
require.NoError(t, err)
require.NotNil(t, fp)

_, err = fp.WriteString(
`1,2` + "\n" +
`1,2,,4` + "\n" +
`1,2,3` + "\n" +
`,,,` + "\n" +
`,,3` + "\n" +
`1,,,4` + "\n")
require.NoError(t, err)

nullInt32 := func(val int32, valid bool) sql.NullInt32 {
return sql.NullInt32{Int32: val, Valid: valid}
}
expects := []struct {
col1 sql.NullInt32
col2 sql.NullInt32
col3 sql.NullInt32
col4 sql.NullInt32
}{
{nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(0, false)},
{nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(4, true)},
{nullInt32(1, true), nullInt32(2, true), nullInt32(3, true), nullInt32(0, false)},
{nullInt32(0, true), nullInt32(0, false), nullInt32(0, false), nullInt32(0, false)},
{nullInt32(0, true), nullInt32(0, false), nullInt32(3, true), nullInt32(0, false)},
{nullInt32(1, true), nullInt32(0, false), nullInt32(0, false), nullInt32(4, true)},
}

cli.runTestsOnNewDB(t, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Params["sql_mode"] = "''"
}, "LoadData", func(dbt *testkit.DBTestKit) {
dbt.MustExec("drop table if exists pn")
dbt.MustExec("create table pn (c1 int, c2 int, c3 int, c4 int)")
dbt.MustExec("set @@tidb_dml_batch_size = 1")
_, err1 := dbt.GetDB().Exec(fmt.Sprintf(`load data local infile %q into table pn FIELDS TERMINATED BY ',' (c1, @val2, @val3, @val4)
SET c2 = NULLIF(@val2, ''), c3 = NULLIF(@val3, ''), c4 = NULLIF(@val4, '')`, path))
require.NoError(t, err1)
var (
a sql.NullInt32
b sql.NullInt32
c sql.NullInt32
d sql.NullInt32
)
rows := dbt.MustQuery("select * from pn")
for _, expect := range expects {
require.Truef(t, rows.Next(), "unexpected data")
err = rows.Scan(&a, &b, &c, &d)
require.NoError(t, err)
require.Equal(t, expect.col1, a)
require.Equal(t, expect.col2, b)
require.Equal(t, expect.col3, c)
require.Equal(t, expect.col4, d)
}

require.Falsef(t, rows.Next(), "unexpected data")
require.NoError(t, rows.Close())
dbt.MustExec("drop table if exists pn")
})
}

func (cli *testServerClient) runTestConcurrentUpdate(t *testing.T) {
Expand Down

0 comments on commit 6f1bdf7

Please sign in to comment.