Skip to content

Commit 5acf810

Browse files
authored
Fix two regexes related to SQL create statement parsing (#140)
* fix: use \s in tableRegexp to match multiline create statements * fix: improve tableReg regex to correctly match table name * tests: add test cases for different spacing cases
1 parent 3aa841d commit 5acf810

File tree

3 files changed

+71
-2
lines changed

3 files changed

+71
-2
lines changed

ddlmod.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
var (
1515
sqliteSeparator = "`|\"|'|\t"
1616
indexRegexp = regexp.MustCompile(fmt.Sprintf("(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\\w\\d-]+[%v]? ON (.*)$", sqliteSeparator, sqliteSeparator))
17-
tableRegexp = regexp.MustCompile(fmt.Sprintf("(?is)(CREATE TABLE [%v]?[\\w\\d-]+[%v]?)(?: \\((.*)\\))?", sqliteSeparator, sqliteSeparator))
17+
tableRegexp = regexp.MustCompile(fmt.Sprintf("(?is)(CREATE TABLE [%v]?[\\w\\d-]+[%v]?)(?:\\s*\\((.*)\\))?", sqliteSeparator, sqliteSeparator))
1818
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
1919
columnsRegexp = regexp.MustCompile(fmt.Sprintf("[(,][%v]?(\\w+)[%v]?", sqliteSeparator, sqliteSeparator))
2020
columnRegexp = regexp.MustCompile(fmt.Sprintf("^[%v]?([\\w\\d]+)[%v]?\\s+([\\w\\(\\)\\d]+)(.*)$", sqliteSeparator, sqliteSeparator))

ddlmod_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,75 @@ func TestParseDDL(t *testing.T) {
9898
}
9999
}
100100

101+
func TestParseDDL_Whitespaces(t *testing.T) {
102+
testColumns := []migrator.ColumnType{
103+
{
104+
NameValue: sql.NullString{String: "id", Valid: true},
105+
DataTypeValue: sql.NullString{String: "integer", Valid: true},
106+
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
107+
NullableValue: sql.NullBool{Bool: false, Valid: true},
108+
DefaultValueValue: sql.NullString{Valid: false},
109+
UniqueValue: sql.NullBool{Bool: true, Valid: true},
110+
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
111+
},
112+
{
113+
NameValue: sql.NullString{String: "dark_mode", Valid: true},
114+
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
115+
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
116+
NullableValue: sql.NullBool{Valid: true},
117+
DefaultValueValue: sql.NullString{String: "true", Valid: true},
118+
UniqueValue: sql.NullBool{Bool: false, Valid: true},
119+
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
120+
},
121+
}
122+
123+
params := []struct {
124+
name string
125+
sql []string
126+
nFields int
127+
columns []migrator.ColumnType
128+
}{
129+
{
130+
"with_newline",
131+
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
132+
2,
133+
testColumns,
134+
},
135+
{
136+
"with_newline_2",
137+
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
138+
2,
139+
testColumns,
140+
},
141+
{
142+
"with_missing_space",
143+
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
144+
2,
145+
testColumns,
146+
},
147+
{
148+
"with_many_spaces",
149+
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
150+
2,
151+
testColumns,
152+
},
153+
}
154+
for _, p := range params {
155+
t.Run(p.name, func(t *testing.T) {
156+
ddl, err := parseDDL(p.sql...)
157+
158+
if err != nil {
159+
panic(err.Error())
160+
}
161+
162+
if len(ddl.fields) != p.nFields {
163+
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
164+
}
165+
tests.AssertEqual(t, ddl.columns, p.columns)
166+
})
167+
}
168+
}
169+
101170
func TestParseDDL_error(t *testing.T) {
102171
params := []struct {
103172
name string

migrator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
390390
return nil
391391
}
392392

393-
tableReg, err := regexp.Compile(" ('|`|\"| )" + table + "('|`|\"| ) ")
393+
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
394394
if err != nil {
395395
return err
396396
}

0 commit comments

Comments
 (0)