From 60aeb16df52841b1283fe36cdc30b06d37bdd480 Mon Sep 17 00:00:00 2001 From: jiyfhust Date: Thu, 20 Apr 2023 00:11:19 +0800 Subject: [PATCH] executor: fix uint type overflow on generated column not compatible with mysql (#40157) close pingcap/tidb#40066 --- executor/insert_common.go | 11 ++++++++- executor/write_test.go | 50 +++++++++++++++++++++++++++++++++++++++ table/column.go | 19 +++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/executor/insert_common.go b/executor/insert_common.go index 751f83c071eda..ab463544fffd0 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -690,6 +690,8 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue return nil, err } } + sc := e.ctx.GetSessionVars().StmtCtx + warnCnt := int(sc.WarningCount()) for i, gCol := range gCols { colIdx := gCol.ColumnInfo.Offset val, err := e.GenExprs[i].Eval(chunk.MutRowFromDatums(row).ToRow()) @@ -697,9 +699,16 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue return nil, err } row[colIdx], err = table.CastValue(e.ctx, val, gCol.ToInfo(), false, false) - if err != nil { + if err = e.handleErr(gCol, &val, rowIdx, err); err != nil { return nil, err } + if newWarnings := sc.TruncateWarnings(warnCnt); len(newWarnings) > 0 { + for k := range newWarnings { + newWarnings[k].Err = completeInsertErr(gCol.ColumnInfo, &val, rowIdx, newWarnings[k].Err) + } + sc.AppendWarnings(newWarnings) + warnCnt += len(newWarnings) + } // Handle the bad null error. if err = gCol.HandleBadNull(&row[colIdx], e.ctx.GetSessionVars().StmtCtx); err != nil { return nil, err diff --git a/executor/write_test.go b/executor/write_test.go index ce7647a58be7e..1f74662d05d74 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -4300,6 +4300,56 @@ func TestIssueInsertPrefixIndexForNonUTF8Collation(t *testing.T) { tk.MustGetErrCode("insert into t3 select 'abc d'", 1062) } +func TestIssue40066(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("create database test_40066") + defer tk.MustExec("drop database test_40066") + tk.MustExec("use test_40066") + tk.MustExec("set @orig_sql_mode = @@sql_mode;") + defer tk.MustExec("set @@sql_mode = @orig_sql_mode;") + + tk.MustExec(`create table t_int(column1 int, column2 int unsigned generated always as(column1-100));`) + tk.MustExec("set @@sql_mode = DEFAULT;") + tk.MustGetErrMsg("insert into t_int(column1) values (99);", "[types:1264]Out of range value for column 'column2' at row 1") + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t_int(column1) values (99);") + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1264 Out of range value for column 'column2' at row 1")) + tk.MustQuery("select * from t_int;").Check(testkit.Rows("99 0")) + + tk.MustExec(`create table t_float(column1 float, column2 int unsigned generated always as(column1-100));`) + tk.MustExec("set @@sql_mode = DEFAULT;") + tk.MustGetErrMsg("insert into t_float(column1) values (12.95);", "[types:1264]Out of range value for column 'column2' at row 1") + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t_float(column1) values (12.95);") + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1264 Out of range value for column 'column2' at row 1")) + tk.MustQuery("select * from t_float;").Check(testkit.Rows("12.95 0")) + + tk.MustExec(`create table t_decimal(column1 decimal(20,10), column2 int unsigned generated always as(column1-100));`) + tk.MustExec("set @@sql_mode = DEFAULT;") + tk.MustGetErrMsg("insert into t_decimal(column1) values (123.456e-2);", "[types:1264]Out of range value for column 'column2' at row 1") + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t_decimal(column1) values (123.456e-2);") + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1264 Out of range value for column 'column2' at row 1")) + tk.MustQuery("select * from t_decimal;").Check(testkit.Rows("1.2345600000 0")) + + tk.MustExec(`create table t_varchar(column1 varchar(10), column2 int unsigned generated always as(column1-100));`) + tk.MustExec("set @@sql_mode = DEFAULT;") + tk.MustGetErrMsg("insert into t_varchar(column1) values ('87.12');", "[types:1264]Out of range value for column 'column2' at row 1") + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t_varchar(column1) values ('87.12');") + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1264 Out of range value for column 'column2' at row 1")) + tk.MustQuery("select * from t_varchar;").Check(testkit.Rows("87.12 0")) + + tk.MustExec(`create table t_union(column1 float, column2 int unsigned generated always as(column1-100), column3 float unsigned generated always as(column1-100));`) + tk.MustExec("set @@sql_mode = DEFAULT;") + tk.MustGetErrMsg("insert into t_union(column1) values (12.95);", "[types:1264]Out of range value for column 'column2' at row 1") + tk.MustExec("set @@sql_mode = '';") + tk.MustExec("insert into t_union(column1) values (12.95);") + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1264 Out of range value for column 'column2' at row 1", "Warning 1264 Out of range value for column 'column3' at row 1")) + tk.MustQuery("select * from t_union;").Check(testkit.Rows("12.95 0 0")) +} + func TestMutipleReplaceAndInsertInOneSession(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/table/column.go b/table/column.go index 649b366d98de8..4a349de55ec44 100644 --- a/table/column.go +++ b/table/column.go @@ -696,6 +696,25 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd if err != nil { return err } + + // Clip to zero if get negative value after cast to unsigned. + if mysql.HasUnsignedFlag(colInfos[idx].FieldType.GetFlag()) && !castDatum.IsNull() && !sctx.GetSessionVars().StmtCtx.ShouldClipToZero() { + switch datum.Kind() { + case types.KindInt64: + if datum.GetInt64() < 0 { + castDatum = GetZeroValue(colInfos[idx]) + } + case types.KindFloat32, types.KindFloat64: + if types.RoundFloat(datum.GetFloat64()) < 0 { + castDatum = GetZeroValue(colInfos[idx]) + } + case types.KindMysqlDecimal: + if datum.GetMysqlDecimal().IsNegative() { + castDatum = GetZeroValue(colInfos[idx]) + } + } + } + // Handle the bad null error. if (mysql.HasNotNullFlag(colInfos[idx].GetFlag()) || mysql.HasPreventNullInsertFlag(colInfos[idx].GetFlag())) && castDatum.IsNull() { castDatum = GetZeroValue(colInfos[idx])