From 5900a7d8990abef9cd742d609ad1d7a88a258cb9 Mon Sep 17 00:00:00 2001 From: Hangjie Mo Date: Fri, 30 Dec 2022 20:20:17 +0800 Subject: [PATCH] This is an automated cherry-pick of #39888 Signed-off-by: ti-chi-bot --- ddl/db_integration_test.go | 44 ++++++++++++++++++++++++++++++++++++++ ddl/ddl_api.go | 27 +++++++++++++++-------- ddl/generated_column.go | 15 ++++++++++--- parser/ast/ddl_test.go | 15 +++++++++++++ parser/ast/dml.go | 13 +++++++++++ parser/ast/expressions.go | 4 ++-- parser/format/format.go | 17 +++++++++++++++ 7 files changed, 121 insertions(+), 14 deletions(-) diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index 8f8e33054a73e..4a7ad444a5b2a 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -2306,8 +2306,52 @@ func (s *testIntegrationSuite3) TestSqlFunctionsInGeneratedColumns(c *C) { tk.MustExec("create table t (a int, b int as ((a)))") } +<<<<<<< HEAD func (s *testIntegrationSuite3) TestParserIssue284(c *C) { tk := testkit.NewTestKit(c, s.store) +======= +func TestSchemaNameAndTableNameInGeneratedExpr(t *testing.T) { + store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("create database if not exists test") + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + + tk.MustExec("create table t(a int, b int as (lower(test.t.a)))") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` int(11) GENERATED ALWAYS AS (lower(`a`)) VIRTUAL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + tk.MustExec("drop table t") + tk.MustExec("create table t(a int)") + tk.MustExec("alter table t add column b int as (lower(test.t.a))") + tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` int(11) GENERATED ALWAYS AS (lower(`a`)) VIRTUAL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + tk.MustGetErrCode("alter table t add index idx((lower(test.t1.a)))", errno.ErrBadField) + + tk.MustExec("drop table t") + tk.MustGetErrCode("create table t(a int, b int as (lower(test1.t.a)))", errno.ErrWrongDBName) + + tk.MustExec("create table t(a int)") + tk.MustGetErrCode("alter table t add column b int as (lower(test.t1.a))", errno.ErrWrongTableName) + + tk.MustExec("alter table t add column c int") + tk.MustGetErrCode("alter table t modify column c int as (test.t1.a + 1) stored", errno.ErrWrongTableName) + + tk.MustExec("alter table t add column d int as (lower(test.T.a))") + tk.MustExec("alter table t add column e int as (lower(Test.t.a))") +} + +func TestParserIssue284(t *testing.T) { + store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) + + tk := testkit.NewTestKit(t, store) +>>>>>>> 702a5598f9 (ddl, parser: make generated column and expression index same as MySQL (#39888)) tk.MustExec("use test") tk.MustExec("create table test.t_parser_issue_284(c1 int not null primary key)") _, err := tk.Exec("create table test.t_parser_issue_284_2(id int not null primary key, c1 int not null, constraint foreign key (c1) references t_parser_issue_284(c1))") diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 88716d1942069..20ea027cb537c 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -763,7 +763,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o var sb strings.Builder restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) for _, v := range colDef.Options { @@ -824,7 +824,10 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o } col.GeneratedExprString = sb.String() col.GeneratedStored = v.Stored - _, dependColNames := findDependedColumnNames(colDef) + _, dependColNames, err := findDependedColumnNames(model.NewCIStr(""), model.NewCIStr(""), colDef) + if err != nil { + return nil, nil, errors.Trace(err) + } col.Dependences = dependColNames case ast.ColumnOptionCollate: if field_types.HasCharset(colDef.Tp) { @@ -1197,7 +1200,7 @@ func IsAutoRandomColumnID(tblInfo *model.TableInfo, colID int64) bool { return tblInfo.PKIsHandle && tblInfo.ContainsAutoRandomBits() && tblInfo.GetPkColInfo().ID == colID } -func checkGeneratedColumn(ctx sessionctx.Context, colDefs []*ast.ColumnDef) error { +func checkGeneratedColumn(ctx sessionctx.Context, schemaName model.CIStr, tableName model.CIStr, colDefs []*ast.ColumnDef) error { var colName2Generation = make(map[string]columnGenerationInDDL, len(colDefs)) var exists bool var autoIncrementColumn string @@ -1212,7 +1215,10 @@ func checkGeneratedColumn(ctx sessionctx.Context, colDefs []*ast.ColumnDef) erro if containsColumnOption(colDef, ast.ColumnOptionAutoIncrement) { exists, autoIncrementColumn = true, colDef.Name.Name.L } - generated, depCols := findDependedColumnNames(colDef) + generated, depCols, err := findDependedColumnNames(schemaName, tableName, colDef) + if err != nil { + return errors.Trace(err) + } if !generated { colName2Generation[colDef.Name.Name.L] = columnGenerationInDDL{ position: i, @@ -1732,7 +1738,7 @@ func checkTableInfoValidExtra(tbInfo *model.TableInfo) error { func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { // All of these rely on the AST structure of expressions, which were // lost in the model (got serialized into strings). - if err := checkGeneratedColumn(ctx, s.Cols); err != nil { + if err := checkGeneratedColumn(ctx, s.Table.Schema, tbInfo.Name, s.Cols); err != nil { return errors.Trace(err) } if tbInfo.Partition != nil { @@ -3042,7 +3048,10 @@ func checkAndCreateNewColumn(ctx sessionctx.Context, ti ast.Ident, schema *model return nil, ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Adding generated stored column through ALTER TABLE") } - _, dependColNames := findDependedColumnNames(specNewColumn) + _, dependColNames, err := findDependedColumnNames(schema.Name, t.Meta().Name, specNewColumn) + if err != nil { + return nil, errors.Trace(err) + } if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { if err = checkAutoIncrementRef(specNewColumn.Name.Name.L, dependColNames, t.Meta()); err != nil { return nil, errors.Trace(err) @@ -4002,7 +4011,7 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error { var sb strings.Builder restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutSchemaName restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) var hasDefaultValue, setOnUpdateNow bool @@ -4242,7 +4251,7 @@ func (d *ddl) getModifiableColumnJob(ctx context.Context, sctx sessionctx.Contex } // As same with MySQL, we don't support modifying the stored status for generated columns. - if err = checkModifyGeneratedColumn(sctx, t, col, newCol, specNewColumn, spec.Position); err != nil { + if err = checkModifyGeneratedColumn(sctx, schema.Name, t, col, newCol, specNewColumn, spec.Position); err != nil { return nil, errors.Trace(err) } @@ -5348,7 +5357,7 @@ func buildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as var sb strings.Builder restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) sb.Reset() err := idxPart.Expr.Restore(restoreCtx) diff --git a/ddl/generated_column.go b/ddl/generated_column.go index 232828a0c107e..aaba7d7d7be42 100644 --- a/ddl/generated_column.go +++ b/ddl/generated_column.go @@ -121,13 +121,19 @@ func findPositionRelativeColumn(cols []*table.Column, pos *ast.ColumnPosition) ( // findDependedColumnNames returns a set of string, which indicates // the names of the columns that are depended by colDef. -func findDependedColumnNames(colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}) { +func findDependedColumnNames(schemaName model.CIStr, tableName model.CIStr, colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}, err error) { colsMap = make(map[string]struct{}) for _, option := range colDef.Options { if option.Tp == ast.ColumnOptionGenerated { generated = true colNames := findColumnNamesInExpr(option.Expr) for _, depCol := range colNames { + if depCol.Schema.L != "" && schemaName.L != "" && depCol.Schema.L != schemaName.L { + return false, nil, dbterror.ErrWrongDBName.GenWithStackByArgs(depCol.Schema.O) + } + if depCol.Table.L != "" && tableName.L != "" && depCol.Table.L != tableName.L { + return false, nil, dbterror.ErrWrongTableName.GenWithStackByArgs(depCol.Table.O) + } colsMap[depCol.Name.L] = struct{}{} } break @@ -191,7 +197,7 @@ func (c *generatedColumnChecker) Leave(inNode ast.Node) (node ast.Node, ok bool) // 3. check if the modified expr contains non-deterministic functions // 4. check whether new column refers to any auto-increment columns. // 5. check if the new column is indexed or stored -func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef, pos *ast.ColumnPosition) error { +func checkModifyGeneratedColumn(sctx sessionctx.Context, schemaName model.CIStr, tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef, pos *ast.ColumnPosition) error { // rule 1. oldColIsStored := !oldCol.IsGenerated() || oldCol.GeneratedStored newColIsStored := !newCol.IsGenerated() || newCol.GeneratedStored @@ -251,7 +257,10 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol } // rule 4. - _, dependColNames := findDependedColumnNames(newColDef) + _, dependColNames, err := findDependedColumnNames(schemaName, tbl.Meta().Name, newColDef) + if err != nil { + return errors.Trace(err) + } if !sctx.GetSessionVars().EnableAutoIncrementInGenerated { if err := checkAutoIncrementRef(newColDef.Name.Name.L, dependColNames, tbl.Meta()); err != nil { return errors.Trace(err) diff --git a/parser/ast/ddl_test.go b/parser/ast/ddl_test.go index af4ecc5a6ca85..02192e0a0cded 100644 --- a/parser/ast/ddl_test.go +++ b/parser/ast/ddl_test.go @@ -249,6 +249,21 @@ func TestDDLColumnOptionRestore(t *testing.T) { runNodeRestoreTest(t, testCases, "CREATE TABLE child (id INT %s)", extractNodeFunc) } +func TestGeneratedRestore(t *testing.T) { + testCases := []NodeRestoreTestCase{ + {"generated always as(id + 1)", "GENERATED ALWAYS AS(`id`+1) VIRTUAL"}, + {"generated always as(id + 1) virtual", "GENERATED ALWAYS AS(`id`+1) VIRTUAL"}, + {"generated always as(id + 1) stored", "GENERATED ALWAYS AS(`id`+1) STORED"}, + {"generated always as(lower(id)) stored", "GENERATED ALWAYS AS(LOWER(`id`)) STORED"}, + {"generated always as(lower(child.id)) stored", "GENERATED ALWAYS AS(LOWER(`id`)) STORED"}, + } + extractNodeFunc := func(node Node) Node { + return node.(*CreateTableStmt).Cols[0].Options[0] + } + runNodeRestoreTestWithFlagsStmtChange(t, testCases, "CREATE TABLE child (id INT %s)", extractNodeFunc, + format.DefaultRestoreFlags|format.RestoreWithoutSchemaName|format.RestoreWithoutTableName) +} + func TestDDLColumnDefRestore(t *testing.T) { testCases := []NodeRestoreTestCase{ // for type diff --git a/parser/ast/dml.go b/parser/ast/dml.go index 5b8c633ec61b0..ceb79a19bdff5 100644 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -279,6 +279,7 @@ func (*TableName) resultSet() {} // Restore implements Node interface. func (n *TableName) restoreName(ctx *format.RestoreCtx) { +<<<<<<< HEAD if n.Schema.String() != "" { ctx.WriteName(n.Schema.String()) ctx.WritePlain(".") @@ -286,7 +287,19 @@ func (n *TableName) restoreName(ctx *format.RestoreCtx) { // Try CTE, for a CTE table name, we shouldn't write the database name. if !ctx.IsCTETableName(n.Name.L) { ctx.WriteName(ctx.DefaultDB) +======= + if !ctx.Flags.HasWithoutSchemaNameFlag() { + // restore db name + if n.Schema.String() != "" { + ctx.WriteName(n.Schema.String()) +>>>>>>> 702a5598f9 (ddl, parser: make generated column and expression index same as MySQL (#39888)) ctx.WritePlain(".") + } else if ctx.DefaultDB != "" { + // Try CTE, for a CTE table name, we shouldn't write the database name. + if !ctx.IsCTETableName(n.Name.L) { + ctx.WriteName(ctx.DefaultDB) + ctx.WritePlain(".") + } } } ctx.WriteName(n.Name.String()) diff --git a/parser/ast/expressions.go b/parser/ast/expressions.go index 66f40eb205952..0d787fb65d642 100644 --- a/parser/ast/expressions.go +++ b/parser/ast/expressions.go @@ -512,11 +512,11 @@ type ColumnName struct { // Restore implements Node interface. func (n *ColumnName) Restore(ctx *format.RestoreCtx) error { - if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) { + if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) && !ctx.Flags.HasWithoutSchemaNameFlag() { ctx.WriteName(n.Schema.O) ctx.WritePlain(".") } - if n.Table.O != "" { + if n.Table.O != "" && !ctx.Flags.HasWithoutTableNameFlag() { ctx.WriteName(n.Table.O) ctx.WritePlain(".") } diff --git a/parser/format/format.go b/parser/format/format.go index 4141b0baf119a..47c786b416fee 100644 --- a/parser/format/format.go +++ b/parser/format/format.go @@ -224,6 +224,13 @@ const ( RestoreStringWithoutDefaultCharset RestoreTiDBSpecialComment +<<<<<<< HEAD +======= + SkipPlacementRuleForRestore + RestoreWithTTLEnableOff + RestoreWithoutSchemaName + RestoreWithoutTableName +>>>>>>> 702a5598f9 (ddl, parser: make generated column and expression index same as MySQL (#39888)) ) const ( @@ -234,6 +241,16 @@ func (rf RestoreFlags) has(flag RestoreFlags) bool { return rf&flag != 0 } +// HasWithoutSchemaNameFlag returns a boolean indicating when `rf` has `RestoreWithoutSchemaName` flag. +func (rf RestoreFlags) HasWithoutSchemaNameFlag() bool { + return rf.has(RestoreWithoutSchemaName) +} + +// HasWithoutTableNameFlag returns a boolean indicating when `rf` has `RestoreWithoutTableName` flag. +func (rf RestoreFlags) HasWithoutTableNameFlag() bool { + return rf.has(RestoreWithoutTableName) +} + // HasStringSingleQuotesFlag returns a boolean indicating when `rf` has `RestoreStringSingleQuotes` flag. func (rf RestoreFlags) HasStringSingleQuotesFlag() bool { return rf.has(RestoreStringSingleQuotes)