diff --git a/executor/builder.go b/executor/builder.go index 946d25a12d100..582a0fd0c75e3 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -944,6 +944,11 @@ func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) Executor { b.err = err return nil } + err = loadDataInfo.initColAssignExprs() + if err != nil { + b.err = err + return nil + } loadDataExec := &LoadDataExec{ baseExecutor: newBaseExecutor(b.ctx, nil, v.ID()), IsLocal: v.IsLocal, diff --git a/executor/load_data.go b/executor/load_data.go index a5db464ce705e..c98bcff077bdc 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -29,6 +29,7 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -116,7 +117,11 @@ type LoadDataInfo struct { rows [][]types.Datum Drained bool - ColumnAssignments []*ast.Assignment + ColumnAssignments []*ast.Assignment + ColumnAssignmentExprs []expression.Expression + // sessionCtx generate warnings when rewrite AST node into expression. + // we should generate such warnings for each row encoded. + exprWarnings []stmtctx.SQLWarn ColumnsAndUserVars []*ast.ColumnNameOrUserVar FieldMappings []*FieldMapping @@ -211,6 +216,23 @@ func (e *LoadDataInfo) initLoadColumns(columnNames []string) error { return nil } +// initColAssignExprs creates the column assignment expressions using session context. +// RewriteAstExpr will write ast node in place(due to xxNode.Accept), but it doesn't change node content, +func (e *LoadDataInfo) initColAssignExprs() error { + for _, assign := range e.ColumnAssignments { + newExpr, err := expression.RewriteAstExpr(e.Ctx, assign.Expr, nil, nil) + if err != nil { + return err + } + // col assign expr warnings is static, we should generate it for each row processed. + // so we save it and clear it here. + e.exprWarnings = append(e.exprWarnings, e.Ctx.GetSessionVars().StmtCtx.GetWarnings()...) + e.Ctx.GetSessionVars().StmtCtx.SetWarnings(nil) + e.ColumnAssignmentExprs = append(e.ColumnAssignmentExprs, newExpr) + } + return nil +} + // initFieldMappings make a field mapping slice to implicitly map input field to table column or user defined variable // the slice's order is the same as the order of the input fields. // Returns a slice of same ordered column names without user defined variable names. @@ -664,15 +686,19 @@ func (e *LoadDataInfo) colsToRow(ctx context.Context, cols []field) []types.Datu row = append(row, types.NewDatum(string(cols[i].str))) } - for i := 0; i < len(e.ColumnAssignments); i++ { + + for i := 0; i < len(e.ColumnAssignmentExprs); i++ { // eval expression of `SET` clause - d, err := expression.EvalAstExpr(e.Ctx, e.ColumnAssignments[i].Expr) + d, err := e.ColumnAssignmentExprs[i].Eval(chunk.Row{}) if err != nil { e.handleWarning(err) return nil } row = append(row, d) } + if len(e.exprWarnings) > 0 { + e.Ctx.GetSessionVars().StmtCtx.AppendWarnings(e.exprWarnings) + } // a new row buffer will be allocated in getRow newRow, err := e.getRow(ctx, row) diff --git a/server/server_test.go b/server/server_test.go index fe63a6847670c..7471516f62d80 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1576,6 +1576,73 @@ func (cli *testServerClient) runTestLoadData(t *testing.T, server *Server) { require.NoError(t, rows.Close()) dbt.MustExec("drop table if exists pn") }) + + err = fp.Close() + require.NoError(t, err) + err = os.Remove(path) + require.NoError(t, err) + + fp, err = os.Create(path) + require.NoError(t, err) + require.NotNil(t, fp) + + _, err = fp.WriteString( + `1,2` + "\n" + + `1,2,,4` + "\n" + + `1,2,3` + "\n" + + `,,,` + "\n" + + `,,3` + "\n" + + `1,,,4` + "\n") + require.NoError(t, err) + + nullInt32 := func(val int32, valid bool) sql.NullInt32 { + return sql.NullInt32{Int32: val, Valid: valid} + } + expects := []struct { + col1 sql.NullInt32 + col2 sql.NullInt32 + col3 sql.NullInt32 + col4 sql.NullInt32 + }{ + {nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(0, false)}, + {nullInt32(1, true), nullInt32(2, true), nullInt32(0, false), nullInt32(4, true)}, + {nullInt32(1, true), nullInt32(2, true), nullInt32(3, true), nullInt32(0, false)}, + {nullInt32(0, true), nullInt32(0, false), nullInt32(0, false), nullInt32(0, false)}, + {nullInt32(0, true), nullInt32(0, false), nullInt32(3, true), nullInt32(0, false)}, + {nullInt32(1, true), nullInt32(0, false), nullInt32(0, false), nullInt32(4, true)}, + } + + cli.runTestsOnNewDB(t, func(config *mysql.Config) { + config.AllowAllFiles = true + config.Params["sql_mode"] = "''" + }, "LoadData", func(dbt *testkit.DBTestKit) { + dbt.MustExec("drop table if exists pn") + dbt.MustExec("create table pn (c1 int, c2 int, c3 int, c4 int)") + dbt.MustExec("set @@tidb_dml_batch_size = 1") + _, err1 := dbt.GetDB().Exec(fmt.Sprintf(`load data local infile %q into table pn FIELDS TERMINATED BY ',' (c1, @val2, @val3, @val4) + SET c2 = NULLIF(@val2, ''), c3 = NULLIF(@val3, ''), c4 = NULLIF(@val4, '')`, path)) + require.NoError(t, err1) + var ( + a sql.NullInt32 + b sql.NullInt32 + c sql.NullInt32 + d sql.NullInt32 + ) + rows := dbt.MustQuery("select * from pn") + for _, expect := range expects { + require.Truef(t, rows.Next(), "unexpected data") + err = rows.Scan(&a, &b, &c, &d) + require.NoError(t, err) + require.Equal(t, expect.col1, a) + require.Equal(t, expect.col2, b) + require.Equal(t, expect.col3, c) + require.Equal(t, expect.col4, d) + } + + require.Falsef(t, rows.Next(), "unexpected data") + require.NoError(t, rows.Close()) + dbt.MustExec("drop table if exists pn") + }) } func (cli *testServerClient) runTestConcurrentUpdate(t *testing.T) {