Skip to content

Commit

Permalink
fix json类型字符集检查, fix #7
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchuanchuan committed Mar 10, 2019
1 parent cf3a3b4 commit 26c6fca
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 43 deletions.
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ type Inc struct {
EnableAutoIncrementUnsigned bool `toml:"enable_autoincrement_unsigned" json:"enable_autoincrement_unsigned"`
EnableBlobType bool `toml:"enable_blob_type" json:"enable_blob_type"`
EnableColumnCharset bool `toml:"enable_column_charset" json:"enable_column_charset"`
EnableDropDatabase bool `toml:"enable_drop_database" json:"enable_drop_database"`
EnableDropTable bool `toml:"enable_drop_table" json:"enable_drop_table"` // 允许删除表
EnableEnumSetBit bool `toml:"enable_enum_set_bit" json:"enable_enum_set_bit"`
EnableForeignKey bool `toml:"enable_foreign_key" json:"enable_foreign_key"`
Expand Down
3 changes: 3 additions & 0 deletions session/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ const (
ER_REMOVED_SPACES
ER_CHANGE_COLUMN_TYPE
ER_CANT_DROP_TABLE
ER_CANT_DROP_DATABASE
ER_WRONG_TABLE_NAME
ER_CANT_SET_CHARSET
ER_CANT_SET_COLLATION
Expand Down Expand Up @@ -317,6 +318,7 @@ var MyErrors = map[int]string{
ER_REMOVED_SPACES: "Leading spaces are removed from name '%s'",
ER_CHANGE_COLUMN_TYPE: "类型转换警告: 列 '%s' %s -> %s.",
ER_CANT_DROP_TABLE: "禁用【DROP】|【TRUNCATE】删除/清空表 '%s', 请改用RENAME重写.",
ER_CANT_DROP_DATABASE: "命令禁止! 无法删除数据库'%s'.",
ER_WRONG_TABLE_NAME: "Incorrect table name '%-.100s'",
ER_CANT_SET_CHARSET: "禁止指定字符集: '%s'",
ER_CANT_SET_COLLATION: "禁止指定排序规则: '%s'",
Expand Down Expand Up @@ -425,6 +427,7 @@ func GetErrorLevel(errorNo int) uint8 {
ER_VIEW_SELECT_CLAUSE,
ER_NOT_SUPPORTED_ITEM_TYPE,
ER_CANT_DROP_TABLE,
ER_CANT_DROP_DATABASE,
ER_CANT_DROP_FIELD_OR_KEY,
ER_NOT_SUPPORTED_YET,
ER_TABLE_MUST_INNODB,
Expand Down
57 changes: 45 additions & 12 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -1572,6 +1572,33 @@ func (s *session) mysqlShowCreateTable(t *TableInfo) {
}
}

// mysqlShowCreateDatabase 生成回滚语句
func (s *session) mysqlShowCreateDatabase(name string) {

sql := fmt.Sprintf("SHOW CREATE DATABASE `%s`;", name)

var res string

rows, err := s.db.Raw(sql).Rows()
if rows != nil {
defer rows.Close()
}

if err != nil {
if myErr, ok := err.(*mysqlDriver.MySQLError); ok {
s.AppendErrorMessage(myErr.Message)
} else {
s.AppendErrorMessage(err.Error())
}
} else if rows != nil {
for rows.Next() {
rows.Scan(&res, &res)
}
s.myRecord.DDLRollback = res
s.myRecord.DDLRollback += ";"
}
}

func (s *session) checkRenameTable(node *ast.RenameTableStmt, sql string) {

log.Debug("checkRenameTable")
Expand Down Expand Up @@ -2531,7 +2558,9 @@ func (s *session) mysqlCheckField(t *TableInfo, field *ast.ColumnDef) {
}

if field.Tp.Charset != "" || field.Tp.Collate != "" {
s.AppendErrorNo(ER_CHARSET_ON_COLUMN, tableName, field.Name.Name)
if field.Tp.Charset != "binary" {
s.AppendErrorNo(ER_CHARSET_ON_COLUMN, tableName, field.Name.Name)
}
}
}

Expand Down Expand Up @@ -3048,8 +3077,8 @@ func (s *session) checkDBExists(db string, reportNotExists bool) bool {
return false
}

if _, ok := s.dbCacheList[db]; ok {
return true
if v, ok := s.dbCacheList[strings.ToLower(db)]; ok {
return v
}

sql := "show databases like '%s';"
Expand Down Expand Up @@ -3079,7 +3108,7 @@ func (s *session) checkDBExists(db string, reportNotExists bool) bool {
}
return false
} else {
s.dbCacheList[db] = true
s.dbCacheList[strings.ToLower(db)] = true
return true
}

Expand Down Expand Up @@ -3292,22 +3321,26 @@ func (s *session) checkInsert(node *ast.InsertStmt, sql string) {
}

func (s *session) checkDropDB(node *ast.DropDatabaseStmt) {

log.Debug("checkDropDB")

// log.Infof("%#v \n", node)
if !s.Inc.EnableDropDatabase {
s.AppendErrorNo(ER_CANT_DROP_DATABASE, node.Name)
return
}

s.AppendErrorMessage(fmt.Sprintf("命令禁止! 无法删除数据库'%s'.", node.Name))
if s.checkDBExists(node.Name, !node.IfExists) {
if s.opt.execute {
// 生成回滚语句
s.mysqlShowCreateDatabase(node.Name)
}
s.dbCacheList[strings.ToLower(node.Name)] = false
}
}

func (s *session) executeInceptionSet(node *ast.InceptionSetStmt, sql string) ([]ast.RecordSet, error) {

log.Debug("executeInceptionSet")

// IsGlobal bool
// IsSystem bool
for _, v := range node.Variables {

if !v.IsSystem {
return nil, errors.New("无效参数")
}
Expand Down Expand Up @@ -3771,7 +3804,7 @@ func (s *session) checkCreateDB(node *ast.CreateDatabaseStmt) {
return
}

s.dbCacheList[node.Name] = true
s.dbCacheList[strings.ToLower(node.Name)] = true

if s.opt.execute {
s.myRecord.DDLRollback = fmt.Sprintf("DROP DATABASE `%s`;", node.Name)
Expand Down
18 changes: 9 additions & 9 deletions session/session_inception_exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -832,11 +832,11 @@ func (s *testSessionIncExecSuite) TestAlterTableDropColumn(c *C) {
c.Assert(row[2], Equals, "0")

// // drop column
sql = "create table t1(id int null);alter table t2 drop c1"
sql = "drop table if exists t1;create table t1(id int null);alter table t1 drop c1"
s.testErrorCode(c, sql,
session.NewErr(session.ER_COLUMN_NOT_EXISTED, "t2.c1"))
session.NewErr(session.ER_COLUMN_NOT_EXISTED, "t1.c1"))

sql = "create table t1(id int null);alter table t2 drop id;"
sql = "drop table if exists t1;create table t1(id int null);alter table t1 drop id;"
s.testErrorCode(c, sql,
session.NewErr(session.ErrCantRemoveAllFields))
}
Expand Down Expand Up @@ -1286,22 +1286,22 @@ func (s *testSessionIncExecSuite) TestSetVariables(c *C) {
result = tk.MustQueryInc("inception show variables like 'ghost_default_retries';")
result.Check(testkit.Rows("ghost_default_retries 70"))

tk.MustExecInc("inception set osc_max_running = 100;")
result = tk.MustQueryInc("inception show variables like 'osc_max_running';")
result.Check(testkit.Rows("osc_max_running 100"))
tk.MustExecInc("inception set osc_max_thread_running = 100;")
result = tk.MustQueryInc("inception show variables like 'osc_max_thread_running';")
result.Check(testkit.Rows("osc_max_thread_running 100"))

// 无效参数
res, err := tk.ExecInc("inception set osc_max_running1 = 100;")
res, err := tk.ExecInc("inception set osc_max_thread_running1 = 100;")
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, "无效参数")
if res != nil {
c.Assert(res.Close(), IsNil)
}

// 无效参数
res, err = tk.ExecInc("inception set osc_max_running = 'abc';")
res, err = tk.ExecInc("inception set osc_max_thread_running = 'abc';")
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, "[variable:1232]Incorrect argument type to variable 'osc_max_running'")
c.Assert(err.Error(), Equals, "[variable:1232]Incorrect argument type to variable 'osc_max_thread_running'")
if res != nil {
c.Assert(res.Close(), IsNil)
}
Expand Down
79 changes: 57 additions & 22 deletions session/session_inception_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,42 @@ func (s *testSessionIncSuite) TearDownTest(c *C) {
c.Skip("skipping test; in TRAVIS mode")
}

tk := testkit.NewTestKitWithInit(c, s.store)
r := tk.MustQuery("show tables")
for _, tb := range r.Rows() {
tableName := tb[0]
tk.MustExec(fmt.Sprintf("drop table %v", tableName))
if s.tk == nil {
s.tk = testkit.NewTestKitWithInit(c, s.store)
}

saved := config.GetGlobalConfig().Inc
defer func() {
config.GetGlobalConfig().Inc = saved
}()

config.GetGlobalConfig().Inc.EnableDropTable = true

res := makeSql(s.tk, "show tables")
c.Assert(int(s.tk.Se.AffectedRows()), Equals, 2)

row := res.Rows()[int(s.tk.Se.AffectedRows())-1]
sql := row[5]

exec := `/*--user=admin;--password=han123;--host=127.0.0.1;--execute=1;--backup=0;--port=3306;--enable-ignore-warnings;*/
inception_magic_start;
use test_inc;
%s;
inception_magic_commit;`
for _, name := range strings.Split(sql.(string), "\n") {
if strings.HasPrefix(name, "show tables:") {
continue
}
n := strings.Replace(name, "'", "", -1)
res := s.tk.MustQueryInc(fmt.Sprintf(exec, "drop table "+n))
// fmt.Println(res.Rows())
c.Assert(int(s.tk.Se.AffectedRows()), Equals, 2)
row := res.Rows()[int(s.tk.Se.AffectedRows())-1]
c.Assert(row[2], Equals, "0")
c.Assert(row[3], Equals, "Execute Successfully")
c.Assert(row[4], IsNil, row[4])
}

}

func makeSql(tk *testkit.TestKit, sql string) *testkit.Result {
Expand Down Expand Up @@ -649,6 +679,10 @@ func (s *testSessionIncSuite) TestAlterTableAddColumn(c *C) {
sql = "create table t1(c2 int on update current_timestamp);"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_ON_UPDATE, "c2"))

sql = "create table t1 (c1 int primary key);alter table t1 add c2 json;"
s.testErrorCode(c, sql)

}

func (s *testSessionIncSuite) TestAlterTableAlterColumn(c *C) {
Expand Down Expand Up @@ -1108,15 +1142,12 @@ func (s *testSessionIncSuite) TestDelete(c *C) {
}

func (s *testSessionIncSuite) TestCreateDataBase(c *C) {
saved := config.GetGlobalConfig().Inc
defer func() {
config.GetGlobalConfig().Inc = saved
}()

sql = "drop database if exists test1111111111111111111;create database test1111111111111111111;"
s.testErrorCode(c, sql)

// 存在
sql = "create database test1111111111111111111;create database test1111111111111111111;"
s.testErrorCode(c, sql,
session.NewErrf("数据库'test1111111111111111111'已存在."))

config.GetGlobalConfig().Inc.EnableDropDatabase = false
// 不存在
sql = "drop database if exists test1111111111111111111;"
s.testErrorCode(c, sql,
Expand All @@ -1125,15 +1156,19 @@ func (s *testSessionIncSuite) TestCreateDataBase(c *C) {
sql = "drop database test1111111111111111111;"
s.testErrorCode(c, sql,
session.NewErrf("命令禁止! 无法删除数据库'test1111111111111111111'."))
config.GetGlobalConfig().Inc.EnableDropDatabase = true

// if not exists 创建
sql = "create database if not exists test1111111111111111111;create database if not exists test1111111111111111111;"
sql = "drop database if exists test1111111111111111111;create database test1111111111111111111;"
s.testErrorCode(c, sql)

// if not exists 删除
sql = "drop database if exists test1111111111111111111;drop database if exists test1111111111111111111;"
// 存在
sql = "create database test1111111111111111111;create database test1111111111111111111;"
s.testErrorCode(c, sql,
session.NewErrf("命令禁止! 无法删除数据库'test1111111111111111111'."))
session.NewErrf("数据库'test1111111111111111111'已存在."))

// if not exists 创建
sql = "create database if not exists test1111111111111111111;create database if not exists test1111111111111111111;"
s.testErrorCode(c, sql)

// create database
sql := "create database aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
Expand All @@ -1147,24 +1182,24 @@ func (s *testSessionIncSuite) TestCreateDataBase(c *C) {
// 字符集
config.GetGlobalConfig().Inc.EnableSetCharset = false
config.GetGlobalConfig().Inc.SupportCharset = ""
sql = "create database test1 character set utf8;"
sql = "drop database test1;create database test1 character set utf8;"
s.testErrorCode(c, sql,
session.NewErr(session.ER_CANT_SET_CHARSET, "utf8"))

config.GetGlobalConfig().Inc.SupportCharset = "utf8mb4"
sql = "create database test1 character set utf8;"
sql = "drop database test1;create database test1 character set utf8;"
s.testErrorCode(c, sql,
session.NewErr(session.ER_CANT_SET_CHARSET, "utf8"),
session.NewErr(session.ER_NAMES_MUST_UTF8, "utf8mb4"))

config.GetGlobalConfig().Inc.EnableSetCharset = true
config.GetGlobalConfig().Inc.SupportCharset = "utf8,utf8mb4"
sql = "create database test1 character set utf8;"
sql = "drop database test1;create database test1 character set utf8;"
s.testErrorCode(c, sql)

config.GetGlobalConfig().Inc.EnableSetCharset = true
config.GetGlobalConfig().Inc.SupportCharset = "utf8,utf8mb4"
sql = "create database test1 character set laitn1;"
sql = "drop database test1;create database test1 character set laitn1;"
s.testErrorCode(c, sql,
session.NewErr(session.ER_NAMES_MUST_UTF8, "utf8,utf8mb4"))
}
Expand Down

0 comments on commit 26c6fca

Please sign in to comment.