Skip to content

Commit

Permalink
support both case sensitive and insensitive tables for MySQL and Post…
Browse files Browse the repository at this point in the history
…gres, fix #286
  • Loading branch information
qianyiwen2019 committed Dec 13, 2024
1 parent 266b0eb commit 90637ef
Show file tree
Hide file tree
Showing 69 changed files with 1,187 additions and 138 deletions.
124 changes: 85 additions & 39 deletions dt-common/src/meta/ddl_meta/ddl_parser.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{config::config_enums::DbType, error::Error};
use crate::{config::config_enums::DbType, error::Error, utils::sql_util::SqlUtil};
use anyhow::bail;
use nom::{
branch::alt,
Expand All @@ -7,7 +7,7 @@ use nom::{
complete::{multispace0, multispace1},
is_alphanumeric,
},
combinator::{map, not, opt, peek},
combinator::{map, not, opt, peek, recognize},
multi::many1,
sequence::{delimited, pair, preceded, tuple},
IResult,
Expand Down Expand Up @@ -106,7 +106,7 @@ impl DdlParser {
))(i)?;

let statement = CreateDatabaseStatement {
db: String::from_utf8_lossy(database).to_string(),
db: self.identifier_to_string(database),
if_not_exists: if_not_exists.is_some(),
unparsed: to_string(remaining_input),
};
Expand All @@ -131,7 +131,7 @@ impl DdlParser {
))(i)?;

let statement = DropDatabaseStatement {
db: String::from_utf8_lossy(database).to_string(),
db: self.identifier_to_string(database),
if_exists: if_exists.is_some(),
unparsed: to_string(remaining_input),
};
Expand All @@ -155,7 +155,7 @@ impl DdlParser {
))(i)?;

let statement = AlterDatabaseStatement {
db: String::from_utf8_lossy(database).to_string(),
db: self.identifier_to_string(database),
unparsed: to_string(remaining_input),
};

Expand All @@ -179,7 +179,7 @@ impl DdlParser {
))(i)?;

let statement = CreateSchemaStatement {
schema: String::from_utf8_lossy(schema).to_string(),
schema: self.identifier_to_string(schema),
if_not_exists: if_not_exists.is_some(),
unparsed: to_string(remaining_input),
};
Expand All @@ -204,7 +204,7 @@ impl DdlParser {
))(i)?;

let statement = DropSchemaStatement {
schema: String::from_utf8_lossy(schema).to_string(),
schema: self.identifier_to_string(schema),
if_exists: if_exists.is_some(),
unparsed: to_string(remaining_input),
};
Expand All @@ -228,7 +228,7 @@ impl DdlParser {
))(i)?;

let statement = AlterSchemaStatement {
schema: String::from_utf8_lossy(schema).to_string(),
schema: self.identifier_to_string(schema),
unparsed: to_string(remaining_input),
};

Expand Down Expand Up @@ -261,7 +261,7 @@ impl DdlParser {
))(i)?;

// temporary tables won't be in binlog
let (schema, tb) = parse_table(table);
let (schema, tb) = self.parse_table(table);
let statement = MysqlCreateTableStatement {
db: schema,
tb,
Expand Down Expand Up @@ -316,7 +316,7 @@ impl DdlParser {
))(i)?;

// temporary tables won't be in binlog
let (schema, tb) = parse_table(table);
let (schema, tb) = self.parse_table(table);
let statement = PgCreateTableStatement {
schema,
tb,
Expand Down Expand Up @@ -348,7 +348,7 @@ impl DdlParser {

let mut schema_tbs = Vec::new();
for table in table_list {
schema_tbs.push(parse_table(table))
schema_tbs.push(self.parse_table(table))
}

let statement = DropMultiTableStatement {
Expand Down Expand Up @@ -386,7 +386,7 @@ impl DdlParser {
|i| self.schema_table(i),
multispace0,
))(i)?;
Ok((remaining_input, parse_table(new_table)))
Ok((remaining_input, self.parse_table(new_table)))
};

let (remaining_input, (_, _, _, _, table, _, rename_to, _)) = tuple((
Expand All @@ -400,7 +400,7 @@ impl DdlParser {
multispace0,
))(i)?;

let (db, tb) = parse_table(table);
let (db, tb) = self.parse_table(table);
if let Some((new_db, new_tb)) = rename_to {
let statement = MysqlAlterTableRenameStatement {
db,
Expand Down Expand Up @@ -441,7 +441,7 @@ impl DdlParser {
|i| self.schema_table(i),
multispace0,
))(i)?;
Ok((remaining_input, parse_table(new_table)))
Ok((remaining_input, self.parse_table(new_table)))
};

let set_schema = |i: &'a [u8]| -> IResult<&'a [u8], String> {
Expand All @@ -453,7 +453,7 @@ impl DdlParser {
|i| self.sql_identifier(i),
multispace0,
))(i)?;
Ok((remaining_input, to_string(new_schema)))
Ok((remaining_input, self.identifier_to_string(new_schema)))
};

let (
Expand All @@ -473,7 +473,7 @@ impl DdlParser {
multispace0,
))(i)?;

let (schema, tb) = parse_table(table);
let (schema, tb) = self.parse_table(table);
if let Some((new_schema, new_tb)) = rename_to_res {
let statement = PgAlterTableRenameStatement {
schema,
Expand Down Expand Up @@ -543,7 +543,7 @@ impl DdlParser {
multispace0,
))(i)?;

let (db, tb) = parse_table(table);
let (db, tb) = self.parse_table(table);
let statement = MysqlTruncateTableStatement {
db,
tb,
Expand All @@ -570,7 +570,7 @@ impl DdlParser {
multispace0,
))(i)?;

let (schema, tb) = parse_table(table);
let (schema, tb) = self.parse_table(table);
let statement = PgTruncateTableStatement {
schema,
tb,
Expand Down Expand Up @@ -599,8 +599,8 @@ impl DdlParser {
let mut schema_tbs = Vec::new();
let mut new_schema_tbs = Vec::new();
for (from, to) in table_to_table_list {
let from = parse_table(from);
let to = parse_table(to);
let from = self.parse_table(from);
let to = self.parse_table(to);
schema_tbs.push(from);
new_schema_tbs.push(to);
}
Expand Down Expand Up @@ -657,7 +657,7 @@ impl DdlParser {
multispace0,
))(i)?;

let (db, tb) = parse_table(table);
let (db, tb) = self.parse_table(table);
let index_kind_str = if let Some((index_kind, _)) = index_kind {
Some(to_string(index_kind))
} else {
Expand All @@ -674,7 +674,7 @@ impl DdlParser {
tb,
index_kind: index_kind_str,
index_type: index_type_str,
index_name: to_string(index_name),
index_name: self.identifier_to_string(index_name),
unparsed: to_string(remaining_input),
};

Expand Down Expand Up @@ -709,12 +709,12 @@ impl DdlParser {
))(i)?;

let (if_not_exists, index_name) = if let Some(name) = name {
(name.0.is_some(), Some(to_string(name.1)))
(name.0.is_some(), Some(self.identifier_to_string(name.1)))
} else {
(false, None)
};

let (schema, tb) = parse_table(table);
let (schema, tb) = self.parse_table(table);
let statement = PgCreateIndexStatement {
schema,
tb,
Expand Down Expand Up @@ -757,11 +757,11 @@ impl DdlParser {
multispace0,
))(i)?;

let (db, tb) = parse_table(table);
let (db, tb) = self.parse_table(table);
let statement = MysqlDropIndexStatement {
db,
tb,
index_name: to_string(index_name),
index_name: self.identifier_to_string(index_name),
unparsed: to_string(remaining_input),
};

Expand Down Expand Up @@ -789,7 +789,7 @@ impl DdlParser {

let mut index_names: Vec<String> = Vec::new();
for name in index_name_list.iter() {
index_names.push(to_string(name));
index_names.push(self.identifier_to_string(name));
}

let statement = PgDropMultiIndexStatement {
Expand Down Expand Up @@ -860,11 +860,18 @@ impl DdlParser {
not(peek(|i| self.sql_keyword(i))),
take_while1(is_sql_identifier),
),
delimited(
// delimited(
// tag("\""),
// take_while1(is_escaped_sql_identifier_2),
// tag("\""),
// );

// keep tag("\""), input: "Abc", return: "Abc"
recognize(tuple((
tag("\""),
take_while1(is_escaped_sql_identifier_2),
tag("\""),
),
))),
))(i);
}

Expand All @@ -873,6 +880,7 @@ impl DdlParser {
not(peek(|i| self.sql_keyword(i))),
take_while1(is_sql_identifier),
),
// remove tag("`"), input: `Abc`, return: Abc
delimited(tag("`"), take_while1(is_escaped_sql_identifier_1), tag("`")),
))(i)
}
Expand Down Expand Up @@ -900,6 +908,32 @@ impl DdlParser {
keyword_s_to_z,
))(i)
}

/// Converts a parsed `(optional schema, table)` pair of raw identifier bytes
/// into owned strings.
///
/// Both parts go through `identifier_to_string`. A missing schema yields an
/// empty string.
fn parse_table(&self, table: (Option<Vec<u8>>, Vec<u8>)) -> (String, String) {
    let (schema_raw, tb_raw) = table;
    // Schema is converted first, then the table name (same order as before).
    let schema = schema_raw
        .as_deref()
        .map(|raw| self.identifier_to_string(raw))
        .unwrap_or_default();
    let tb = self.identifier_to_string(&tb_raw);
    (schema, tb)
}

/// Turns raw identifier bytes into the identifier name used in statements.
///
/// For every database type except PostgreSQL the lossily-decoded text is
/// returned unchanged. For PostgreSQL:
/// - an identifier wrapped in the first escape pair reported by
///   `SqlUtil::get_escape_pairs` is unescaped and keeps its letter case;
/// - any other identifier is folded to lower case, matching how PostgreSQL
///   treats identifiers that are not double-quoted.
fn identifier_to_string(&self, i: &[u8]) -> String {
    let raw = to_string(i);
    if self.db_type != DbType::Pg {
        return raw;
    }

    let quote_pair = SqlUtil::get_escape_pairs(&self.db_type)[0];
    if !SqlUtil::is_escaped(&raw, &quote_pair) {
        // Unquoted PostgreSQL identifiers are case-insensitive: fold them.
        return raw.to_lowercase();
    }
    // Quoted identifiers retain upper case letters; only strip the quoting.
    SqlUtil::unescape(&raw, &quote_pair)
}
}

#[inline]
Expand Down Expand Up @@ -943,16 +977,6 @@ fn ws_sep_comma(i: &[u8]) -> IResult<&[u8], &[u8]> {
delimited(multispace0, tag(","), multispace0)(i)
}

fn parse_table(table: (Option<Vec<u8>>, Vec<u8>)) -> (String, String) {
let schema = if let Some(schema_raw) = &table.0 {
to_string(schema_raw)
} else {
String::new()
};
let tb = to_string(&table.1);
(schema, tb)
}

/// Decodes `i` as UTF-8 into an owned `String`, replacing any invalid byte
/// sequences with U+FFFD (the replacement character).
fn to_string(i: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(i);
    decoded.into_owned()
}
Expand Down Expand Up @@ -1598,6 +1622,28 @@ mod test_pg {
}
}

/// PostgreSQL CREATE TABLE with mixed-case schema/table names: unquoted
/// names are expected back lower-cased, double-quoted names keep their case.
#[test]
fn test_create_table_with_schema_with_upper_case_pg() {
    // Each case is (input DDL, expected regenerated DDL).
    let cases = [
        (
            r#"CREATE TABLE IF NOT EXISTS Test_DB.Test_TB(id int, "Value" int);"#,
            r#"CREATE TABLE IF NOT EXISTS "test_db"."test_tb" (id int, "Value" int);"#,
        ),
        (
            r#"CREATE TABLE IF NOT EXISTS "Test_DB".Test_TB(id int, "Value" int);"#,
            r#"CREATE TABLE IF NOT EXISTS "Test_DB"."test_tb" (id int, "Value" int);"#,
        ),
        (
            r#"CREATE TABLE IF NOT EXISTS "Test_DB"."Test_TB"(id int, "Value" int);"#,
            r#"CREATE TABLE IF NOT EXISTS "Test_DB"."Test_TB" (id int, "Value" int);"#,
        ),
    ];

    let parser = DdlParser::new(DbType::Pg);
    for &(sql, expected) in cases.iter() {
        let parsed = parser.parse(sql).unwrap();
        assert_eq!(parsed.ddl_type, DdlType::CreateTable);
        assert_eq!(parsed.to_sql(), expected);
    }
}

#[test]
fn test_create_table_with_schema_with_special_characters_pg() {
let sqls = [
Expand Down
14 changes: 9 additions & 5 deletions dt-common/src/meta/ddl_meta/ddl_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -747,11 +747,11 @@ impl PgDropMultiIndexStatement {
}

fn append_tb(sql: &str, schema: &str, tb: &str, db_type: &DbType) -> String {
let tb = SqlUtil::escape_by_db_type(tb, db_type);
let tb = escape_identifier(tb, db_type);
if schema.is_empty() {
format!("{} {}", sql, tb)
} else {
let schema = SqlUtil::escape_by_db_type(schema, db_type);
let schema = escape_identifier(schema, db_type);
format!("{} {}.{}", sql, schema, tb)
}
}
Expand All @@ -770,11 +770,11 @@ fn append_identifier(
with_white_space: bool,
db_type: &DbType,
) -> String {
let escaped_identifier = SqlUtil::escape_by_db_type(identifier, db_type);
let identifier = escape_identifier(identifier, db_type);
if with_white_space {
format!("{} {}", sql, escaped_identifier)
format!("{} {}", sql, identifier)
} else {
format!("{}{}", sql, escaped_identifier)
format!("{}{}", sql, identifier)
}
}

Expand All @@ -784,3 +784,7 @@ fn append_unparsed(sql: String, unparsed: &str) -> String {
}
sql
}

/// Escapes `identifier` for the given `db_type`.
///
/// This is a thin wrapper that delegates directly to
/// `SqlUtil::escape_by_db_type`; the statement-building helpers in this file
/// (`append_tb`, `append_identifier`) call it instead of calling `SqlUtil`
/// themselves.
fn escape_identifier(identifier: &str, db_type: &DbType) -> String {
SqlUtil::escape_by_db_type(identifier, db_type)
}
Loading

0 comments on commit 90637ef

Please sign in to comment.