Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ddl, parser: make generated column and expression index same as MySQL #39888

Merged
merged 16 commits into from
Dec 30, 2022
Merged
37 changes: 37 additions & 0 deletions ddl/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2373,6 +2373,43 @@ func TestSqlFunctionsInGeneratedColumns(t *testing.T) {
tk.MustExec("create table t (a int, b int as ((a)))")
}

func TestSchemaNameAndTableNameInGeneratedExpr(t *testing.T) {
store := testkit.CreateMockStore(t, mockstore.WithDDLChecker())

tk := testkit.NewTestKit(t, store)
tk.MustExec("create database if not exists test")
tk.MustExec("use test")
tk.MustExec("drop table if exists t")

tk.MustExec("create table t(a int, b int as (lower(test.t.a)))")
tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" +
" `a` int(11) DEFAULT NULL,\n" +
" `b` int(11) GENERATED ALWAYS AS (lower(`a`)) VIRTUAL\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))

tk.MustExec("drop table t")
tk.MustExec("create table t(a int)")
tk.MustExec("alter table t add column b int as (lower(test.t.a))")
tk.MustQuery("show create table t").Check(testkit.Rows("t CREATE TABLE `t` (\n" +
" `a` int(11) DEFAULT NULL,\n" +
" `b` int(11) GENERATED ALWAYS AS (lower(`a`)) VIRTUAL\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"))

tk.MustGetErrCode("alter table t add index idx((lower(test.t1.a)))", errno.ErrBadField)

tk.MustExec("drop table t")
tk.MustGetErrCode("create table t(a int, b int as (lower(test1.t.a)))", errno.ErrWrongDBName)

tk.MustExec("create table t(a int)")
tk.MustGetErrCode("alter table t add column b int as (lower(test.t1.a))", errno.ErrWrongTableName)

tk.MustExec("alter table t add column c int")
tk.MustGetErrCode("alter table t modify column c int as (test.t1.a + 1) stored", errno.ErrWrongTableName)

tk.MustExec("alter table t add column d int as (lower(test.T.a))")
tk.MustExec("alter table t add column e int as (lower(Test.t.a))")
}

func TestParserIssue284(t *testing.T) {
store := testkit.CreateMockStore(t, mockstore.WithDDLChecker())

Expand Down
27 changes: 18 additions & 9 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o

var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)

for _, v := range colDef.Options {
Expand Down Expand Up @@ -1142,7 +1142,10 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o
}
col.GeneratedExprString = sb.String()
col.GeneratedStored = v.Stored
_, dependColNames := findDependedColumnNames(colDef)
_, dependColNames, err := findDependedColumnNames(model.NewCIStr(""), model.NewCIStr(""), colDef)
if err != nil {
return nil, nil, errors.Trace(err)
}
col.Dependences = dependColNames
case ast.ColumnOptionCollate:
if field_types.HasCharset(colDef.Tp) {
Expand Down Expand Up @@ -1561,7 +1564,7 @@ func IsAutoRandomColumnID(tblInfo *model.TableInfo, colID int64) bool {
return false
}

func checkGeneratedColumn(ctx sessionctx.Context, colDefs []*ast.ColumnDef) error {
func checkGeneratedColumn(ctx sessionctx.Context, schemaName model.CIStr, tableName model.CIStr, colDefs []*ast.ColumnDef) error {
var colName2Generation = make(map[string]columnGenerationInDDL, len(colDefs))
var exists bool
var autoIncrementColumn string
Expand All @@ -1576,7 +1579,10 @@ func checkGeneratedColumn(ctx sessionctx.Context, colDefs []*ast.ColumnDef) erro
if containsColumnOption(colDef, ast.ColumnOptionAutoIncrement) {
exists, autoIncrementColumn = true, colDef.Name.Name.L
}
generated, depCols := findDependedColumnNames(colDef)
generated, depCols, err := findDependedColumnNames(schemaName, tableName, colDef)
if err != nil {
return errors.Trace(err)
}
if !generated {
colName2Generation[colDef.Name.Name.L] = columnGenerationInDDL{
position: i,
Expand Down Expand Up @@ -2094,7 +2100,7 @@ func CheckTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo
func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) {
// All of these rely on the AST structure of expressions, which were
// lost in the model (got serialized into strings).
if err := checkGeneratedColumn(ctx, s.Cols); err != nil {
if err := checkGeneratedColumn(ctx, s.Table.Schema, tbInfo.Name, s.Cols); err != nil {
return errors.Trace(err)
}

Expand Down Expand Up @@ -3672,7 +3678,10 @@ func CreateNewColumn(ctx sessionctx.Context, ti ast.Ident, schema *model.DBInfo,
return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Adding generated stored column through ALTER TABLE")
}

_, dependColNames := findDependedColumnNames(specNewColumn)
_, dependColNames, err := findDependedColumnNames(schema.Name, t.Meta().Name, specNewColumn)
if err != nil {
return nil, errors.Trace(err)
}
if !ctx.GetSessionVars().EnableAutoIncrementInGenerated {
if err := checkAutoIncrementRef(specNewColumn.Name.Name.L, dependColNames, t.Meta()); err != nil {
return nil, errors.Trace(err)
Expand Down Expand Up @@ -4485,7 +4494,7 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col
func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error {
var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutSchemaName
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)

var hasDefaultValue, setOnUpdateNow bool
Expand Down Expand Up @@ -4827,7 +4836,7 @@ func GetModifiableColumnJob(
}

// As same with MySQL, we don't support modifying the stored status for generated columns.
if err = checkModifyGeneratedColumn(sctx, t, col, newCol, specNewColumn, spec.Position); err != nil {
if err = checkModifyGeneratedColumn(sctx, schema.Name, t, col, newCol, specNewColumn, spec.Position); err != nil {
return nil, errors.Trace(err)
}

Expand Down Expand Up @@ -6306,7 +6315,7 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as

var sb strings.Builder
restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes |
format.RestoreSpacesAroundBinaryOperation
format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName
restoreCtx := format.NewRestoreCtx(restoreFlags, &sb)
sb.Reset()
err := idxPart.Expr.Restore(restoreCtx)
Expand Down
15 changes: 12 additions & 3 deletions ddl/generated_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,19 @@ func findPositionRelativeColumn(cols []*table.Column, pos *ast.ColumnPosition) (

// findDependedColumnNames returns a set of string, which indicates
// the names of the columns that are depended by colDef.
func findDependedColumnNames(colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}) {
func findDependedColumnNames(schemaName model.CIStr, tableName model.CIStr, colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}, err error) {
colsMap = make(map[string]struct{})
for _, option := range colDef.Options {
if option.Tp == ast.ColumnOptionGenerated {
generated = true
colNames := FindColumnNamesInExpr(option.Expr)
for _, depCol := range colNames {
if depCol.Schema.L != "" && schemaName.L != "" && depCol.Schema.L != schemaName.L {
return false, nil, dbterror.ErrWrongDBName.GenWithStackByArgs(depCol.Schema.O)
}
if depCol.Table.L != "" && tableName.L != "" && depCol.Table.L != tableName.L {
return false, nil, dbterror.ErrWrongTableName.GenWithStackByArgs(depCol.Table.O)
}
colsMap[depCol.Name.L] = struct{}{}
}
break
Expand Down Expand Up @@ -192,7 +198,7 @@ func (c *generatedColumnChecker) Leave(inNode ast.Node) (node ast.Node, ok bool)
// 3. check if the modified expr contains non-deterministic functions
// 4. check whether new column refers to any auto-increment columns.
// 5. check if the new column is indexed or stored
func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef, pos *ast.ColumnPosition) error {
func checkModifyGeneratedColumn(sctx sessionctx.Context, schemaName model.CIStr, tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef, pos *ast.ColumnPosition) error {
// rule 1.
oldColIsStored := !oldCol.IsGenerated() || oldCol.GeneratedStored
newColIsStored := !newCol.IsGenerated() || newCol.GeneratedStored
Expand Down Expand Up @@ -252,7 +258,10 @@ func checkModifyGeneratedColumn(sctx sessionctx.Context, tbl table.Table, oldCol
}

// rule 4.
_, dependColNames := findDependedColumnNames(newColDef)
_, dependColNames, err := findDependedColumnNames(schemaName, tbl.Meta().Name, newColDef)
if err != nil {
return errors.Trace(err)
}
if !sctx.GetSessionVars().EnableAutoIncrementInGenerated {
if err := checkAutoIncrementRef(newColDef.Name.Name.L, dependColNames, tbl.Meta()); err != nil {
return errors.Trace(err)
Expand Down
15 changes: 15 additions & 0 deletions parser/ast/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,21 @@ func TestDDLColumnOptionRestore(t *testing.T) {
runNodeRestoreTest(t, testCases, "CREATE TABLE child (id INT %s)", extractNodeFunc)
}

func TestGeneratedRestore(t *testing.T) {
testCases := []NodeRestoreTestCase{
{"generated always as(id + 1)", "GENERATED ALWAYS AS(`id`+1) VIRTUAL"},
{"generated always as(id + 1) virtual", "GENERATED ALWAYS AS(`id`+1) VIRTUAL"},
{"generated always as(id + 1) stored", "GENERATED ALWAYS AS(`id`+1) STORED"},
{"generated always as(lower(id)) stored", "GENERATED ALWAYS AS(LOWER(`id`)) STORED"},
{"generated always as(lower(child.id)) stored", "GENERATED ALWAYS AS(LOWER(`id`)) STORED"},
}
extractNodeFunc := func(node Node) Node {
return node.(*CreateTableStmt).Cols[0].Options[0]
}
runNodeRestoreTestWithFlagsStmtChange(t, testCases, "CREATE TABLE child (id INT %s)", extractNodeFunc,
format.DefaultRestoreFlags|format.RestoreWithoutSchemaName|format.RestoreWithoutTableName)
}

func TestDDLColumnDefRestore(t *testing.T) {
testCases := []NodeRestoreTestCase{
// for type
Expand Down
18 changes: 10 additions & 8 deletions parser/ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,15 +288,17 @@ func (*TableName) resultSet() {}

// Restore implements Node interface.
func (n *TableName) restoreName(ctx *format.RestoreCtx) {
// restore db name
if n.Schema.String() != "" {
ctx.WriteName(n.Schema.String())
ctx.WritePlain(".")
} else if ctx.DefaultDB != "" {
// Try CTE, for a CTE table name, we shouldn't write the database name.
if !ctx.IsCTETableName(n.Name.L) {
ctx.WriteName(ctx.DefaultDB)
if !ctx.Flags.HasWithoutSchemaNameFlag() {
// restore db name
if n.Schema.String() != "" {
ctx.WriteName(n.Schema.String())
ctx.WritePlain(".")
} else if ctx.DefaultDB != "" {
// Try CTE, for a CTE table name, we shouldn't write the database name.
if !ctx.IsCTETableName(n.Name.L) {
ctx.WriteName(ctx.DefaultDB)
ctx.WritePlain(".")
}
}
}
// restore table name
Expand Down
4 changes: 2 additions & 2 deletions parser/ast/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,11 @@ type ColumnName struct {

// Restore implements Node interface.
func (n *ColumnName) Restore(ctx *format.RestoreCtx) error {
if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) {
if n.Schema.O != "" && !ctx.IsCTETableName(n.Table.L) && !ctx.Flags.HasWithoutSchemaNameFlag() {
ctx.WriteName(n.Schema.O)
ctx.WritePlain(".")
}
if n.Table.O != "" {
if n.Table.O != "" && !ctx.Flags.HasWithoutTableNameFlag() {
ctx.WriteName(n.Table.O)
ctx.WritePlain(".")
}
Expand Down
12 changes: 12 additions & 0 deletions parser/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ const (
RestoreTiDBSpecialComment
SkipPlacementRuleForRestore
RestoreWithTTLEnableOff
RestoreWithoutSchemaName
RestoreWithoutTableName
)

const (
Expand All @@ -247,6 +249,16 @@ func (rf RestoreFlags) has(flag RestoreFlags) bool {
return rf&flag != 0
}

// HasWithoutSchemaNameFlag returns a boolean indicating when `rf` has `RestoreWithoutSchemaName` flag.
func (rf RestoreFlags) HasWithoutSchemaNameFlag() bool {
return rf.has(RestoreWithoutSchemaName)
}

// HasWithoutTableNameFlag returns a boolean indicating when `rf` has `RestoreWithoutTableName` flag.
func (rf RestoreFlags) HasWithoutTableNameFlag() bool {
return rf.has(RestoreWithoutTableName)
}

// HasStringSingleQuotesFlag returns a boolean indicating when `rf` has `RestoreStringSingleQuotes` flag.
func (rf RestoreFlags) HasStringSingleQuotesFlag() bool {
return rf.has(RestoreStringSingleQuotes)
Expand Down