-
Notifications
You must be signed in to change notification settings - Fork 5
[feat] introduce classifysql package #284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,358 @@ | ||
| package classifysql | ||
|
|
||
| import ( | ||
| "errors" | ||
| "regexp" | ||
| "strings" | ||
|
|
||
| "github.com/muir/sqltoken" | ||
| ) | ||
|
|
||
| // Dialect enumerates supported SQL dialects. | ||
| type Dialect int | ||
|
|
||
| const ( | ||
| DialectInvalid Dialect = iota | ||
| DialectMySQL | ||
| DialectPostgres | ||
| DialectSingleStore = DialectMySQL | ||
| ) | ||
|
|
||
| // Flag is a bitmask describing statement / script properties. | ||
| type Flag uint32 | ||
|
|
||
| // Individual flag bits. | ||
| const ( | ||
| IsDDL Flag = 1 << iota | ||
| IsDML | ||
| IsNonIdempotent | ||
| IsMultipleStatements | ||
| IsEasilyIdempotentFix | ||
| IsMustNonTx // Statement must run outside a transaction (e.g. PG CREATE INDEX CONCURRENTLY) | ||
| ) | ||
|
|
||
| var flagNameMap = map[Flag]string{ | ||
| IsDDL: "DDL", | ||
| IsDML: "DML", | ||
| IsNonIdempotent: "NonIdem", | ||
| IsMultipleStatements: "Multi", | ||
| IsEasilyIdempotentFix: "EasyFix", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's cool repo, |
||
| IsMustNonTx: "MustNonTx", | ||
| } | ||
|
|
||
| var flagsInOrder = []Flag{IsDDL, IsDML, IsNonIdempotent, IsEasilyIdempotentFix, IsMustNonTx, IsMultipleStatements} | ||
|
|
||
| // Statement represents one SQL statement with its original, unstripped tokens and classification flags. | ||
| type Statement struct { | ||
| Flags Flag | ||
| Tokens sqltoken.Tokens // unstripped tokens for this statement | ||
| dialect Dialect | ||
| version int | ||
| stripped sqltoken.Tokens // cached stripped tokens (nil/empty if none) | ||
| strippedString string // cached stripped string (empty if none), lower case | ||
| firstLower string // cached lowercase first token text (empty if none) | ||
| } | ||
|
|
||
| // Statements is a slice of Statement. | ||
| type Statements []Statement | ||
|
|
||
| // Postgres statements that inherently require execution outside a transaction. | ||
| // Lower-cased, whitespace-collapsed prefixes. | ||
| var pgMustNonTxPrefixes = []string{ | ||
| "create index concurrently", | ||
| "create unique index concurrently", | ||
| "drop index concurrently", | ||
| "refresh materialized view concurrently", | ||
| "reindex concurrently", | ||
| "vacuum full", | ||
| "cluster", | ||
| "create database", | ||
| "drop database", | ||
| "create tablespace", | ||
| "drop tablespace", | ||
| "create subscription", | ||
| "alter subscription", | ||
| "drop subscription", | ||
| } | ||
|
|
||
| var ( | ||
| mysqlCreateEasy = []string{"create table"} | ||
| mysqlDropEasy = []string{"drop table", "drop database"} | ||
| pgCreateEasy = []string{"create table", "create index", "create schema", "create sequence", "create extension"} | ||
| pgDropEasy = []string{"drop table", "drop index", "drop schema", "drop sequence", "drop extension"} | ||
| ) | ||
|
|
||
| // ClassifyTokens tokenizes and classifies a SQL script, returning per-statement flags and original tokens. | ||
| // Only error case is invalid dialect. | ||
| // strings.Join(ClassifyTokens().TokensList.Strings, ";") should return the original sqlString | ||
| func ClassifyTokens(d Dialect, majorVersion int, sqlString string) (Statements, error) { | ||
| if strings.TrimSpace(sqlString) == "" { | ||
| return nil, nil | ||
| } | ||
| var tokens sqltoken.Tokens | ||
| switch d { | ||
| case DialectPostgres: | ||
| // TokenizePostgreSQL preserves comments/whitespace in tokens | ||
| tokens = sqltoken.TokenizePostgreSQL(sqlString) | ||
| case DialectMySQL: | ||
| // TokenizeMySQL preserves comments/whitespace in tokens | ||
| tokens = sqltoken.TokenizeMySQL(sqlString) | ||
| default: | ||
| return nil, errors.New("invalid dialect") | ||
| } | ||
| split := tokens.CmdSplitUnstripped() | ||
| stmts := make(Statements, len(split)) | ||
| for i, raw := range split { | ||
| stmts[i].Tokens = raw | ||
| stmts[i].dialect = d | ||
| stmts[i].version = majorVersion | ||
| stripped := raw.Strip() | ||
| stmts[i].stripped = stripped | ||
| if len(stripped) == 0 { | ||
| // just comments/whitespace | ||
| continue | ||
| } | ||
| stmts[i].strippedString = strings.ToLower(stripped.String()) | ||
| stmts[i].firstLower = strings.ToLower(stripped[0].Text) | ||
| stmts[i].classify() | ||
| } | ||
| // Add IsMultipleStatements if there are multiple real statements | ||
| if stmts.CountNonEmpty() > 1 { | ||
| for i := range stmts { | ||
| stmts[i].Flags |= IsMultipleStatements | ||
| } | ||
| } | ||
| return stmts, nil | ||
| } | ||
|
|
||
| // StripString returns the lowercased stripped string | ||
| func (s Statement) StripString() string { return s.strippedString } | ||
|
|
||
| // FirstWord returns the lower case first word | ||
| func (s Statement) FirstWord() string { return s.firstLower } | ||
|
|
||
| func (s Statement) Strip() sqltoken.Tokens { return s.stripped } | ||
|
|
||
| // classifyFirstVerb applies verb-based classification logic | ||
| func (s *Statement) classify() { | ||
| switch s.firstLower { | ||
| case "create": | ||
| s.Flags = s.classifyCreate() | ||
| case "alter": | ||
| s.Flags = s.classifyAlter() | ||
| case "drop": | ||
| s.Flags = s.classifyDrop() | ||
| case "rename", "comment": | ||
| s.Flags = IsDDL | IsNonIdempotent | ||
| case "truncate": | ||
| s.Flags = IsDDL | ||
| case "insert", "update", "delete", "replace", "call", "do", "load", "handler", "import", "with": | ||
| s.Flags = IsDML | ||
| } | ||
| if s.dialect == DialectPostgres && isPostgresMustNonTxVersion(s.strippedString, s.version) { | ||
| s.Flags |= IsMustNonTx | ||
| } | ||
| } | ||
|
|
||
| func (s Statement) classifyCreate() Flag { | ||
| f := IsDDL | ||
| list := mysqlCreateEasy | ||
| if s.dialect == DialectPostgres { | ||
| list = pgCreateEasy | ||
| } | ||
| for _, p := range list { | ||
| if strings.HasPrefix(s.strippedString, p) { | ||
| if !ifExistsRE.MatchString(s.strippedString) { | ||
| f |= IsNonIdempotent | IsEasilyIdempotentFix | ||
| } | ||
| return f | ||
| } | ||
| } | ||
| if !ifExistsRE.MatchString(s.strippedString) { // generic CREATE without IF EXISTS | ||
| f |= IsNonIdempotent | ||
| } | ||
| return f | ||
| } | ||
|
|
||
| func (s Statement) classifyAlter() Flag { | ||
| f := IsDDL | ||
| if ifExistsRE.MatchString(s.strippedString) { | ||
| return f | ||
| } | ||
| f |= IsNonIdempotent | ||
| if s.dialect == DialectPostgres && strings.Contains(s.strippedString, " add column") { | ||
| f |= IsEasilyIdempotentFix | ||
| } | ||
| return f | ||
| } | ||
|
|
||
| func (s Statement) classifyDrop() Flag { | ||
| f := IsDDL | ||
| if ifExistsRE.MatchString(s.strippedString) { | ||
| return f | ||
| } | ||
| f |= IsNonIdempotent | ||
| list := mysqlDropEasy | ||
| if s.dialect == DialectPostgres { | ||
| list = pgDropEasy | ||
| } | ||
| for _, p := range list { | ||
| if strings.HasPrefix(s.strippedString, p) { | ||
| f |= IsEasilyIdempotentFix | ||
| break | ||
| } | ||
| } | ||
| return f | ||
| } | ||
|
|
||
| // Regroup returns groups of statements that are compatible to run together in a single migration. | ||
| // Algorithm (single pass): | ||
| // Postgres: statements with IsMustNonTx must be isolated; all others can group together. | ||
| // MySQL: each DDL (IsDDL) isolated; consecutive/any DML grouped together. | ||
| // Mixed order preserved by starting a new group when incompatibility detected. | ||
| func (s Statements) Regroup() []Statements { | ||
| if len(s) == 0 { | ||
| return nil | ||
| } | ||
| d := s[0].dialect | ||
| var groups []Statements | ||
| var current Statements | ||
| flush := func() { | ||
| if len(current) > 0 { | ||
| groups = append(groups, current) | ||
| current = nil | ||
| } | ||
| } | ||
| for _, st := range s { | ||
| if d == DialectPostgres { | ||
| if st.Flags&IsMustNonTx != 0 { // isolate | ||
| flush() | ||
| groups = append(groups, Statements{st}) | ||
| continue | ||
| } | ||
| current = append(current, st) | ||
| continue | ||
| } | ||
| // MySQL grouping rules | ||
| if st.Flags&IsDDL != 0 { // isolate DDL | ||
| flush() | ||
| groups = append(groups, Statements{st}) | ||
| continue | ||
| } | ||
| // DML -> can merge with existing current if it contains only DML | ||
| if len(current) > 0 { | ||
| current = append(current, st) | ||
| } else { | ||
| current = Statements{st} | ||
| } | ||
| } | ||
| flush() | ||
| // Clear Multi flag from any isolated single-statement group; keep in groups with >1 real statements. | ||
| for gi, g := range groups { | ||
| if g.CountNonEmpty() == 1 { | ||
| // isolated | ||
| for si := range g { | ||
| g[si].Flags &^= IsMultipleStatements | ||
| } | ||
| groups[gi] = g | ||
| } | ||
| } | ||
| return groups | ||
| } | ||
|
|
||
| // CountNonEmpty returns the number of real statements, excluding empty/comment-only and SET-leading statements. | ||
| func (s Statements) CountNonEmpty() int { | ||
| nonEmpty := 0 | ||
| for _, st := range s { | ||
| stripped := st.Tokens.Strip() | ||
| if len(stripped) == 0 { | ||
| continue | ||
| } | ||
| // Exclude SET statements from non-empty counting | ||
| first := strings.ToLower(stripped[0].Text) | ||
| if first == "set" { | ||
| continue | ||
| } | ||
| nonEmpty++ | ||
| } | ||
| return nonEmpty | ||
| } | ||
|
|
||
| // FirstReal returns a pointer to the first real statement (non-empty, non-SET) or nil if none. | ||
| func (s Statements) FirstReal() *Statement { | ||
| for i := range s { | ||
| stripped := s[i].Tokens.Strip() | ||
| if len(stripped) == 0 { | ||
| continue | ||
| } | ||
| first := strings.ToLower(stripped[0].Text) | ||
| if first == "set" { | ||
| continue | ||
| } | ||
| return &s[i] | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| func (s Statements) TokensList() sqltoken.TokensList { | ||
| // Preserve placeholder entries including empty statements to reconstruct original segmentation. | ||
| tl := make(sqltoken.TokensList, len(s)) | ||
| for i, stmt := range s { | ||
| tl[i] = stmt.Tokens | ||
| } | ||
| return tl | ||
| } | ||
|
|
||
| // Summary maps each flag to the first statement Tokens that exhibited it. IsMultipleStatements is synthetic. | ||
| type Summary map[Flag]sqltoken.Tokens | ||
|
|
||
| func (s Summary) Includes(flags ...Flag) bool { | ||
| for _, flag := range flags { | ||
| if _, ok := s[flag]; !ok { | ||
| return false | ||
| } | ||
| } | ||
| return true | ||
| } | ||
|
|
||
| // Summarize builds a Summary from the classified statements. | ||
| func (s Statements) Summarize() Summary { | ||
| out := make(Summary) | ||
| for _, st := range s { | ||
| for _, f := range flagsInOrder { | ||
| if st.Flags&f != 0 { | ||
| if _, exists := out[f]; !exists { | ||
| out[f] = st.Tokens | ||
| } | ||
| } | ||
| } | ||
| } | ||
| return out | ||
| } | ||
|
|
||
| // Names returns human-friendly names of set bits in the flag mask. | ||
| func (f Flag) Names() []string { | ||
| var names []string | ||
| for _, bit := range flagsInOrder { | ||
| if f&bit != 0 { | ||
| names = append(names, flagNameMap[bit]) | ||
| } | ||
| } | ||
| return names | ||
| } | ||
|
|
||
| // isPostgresMustNonTxVersion returns true when a statement must execute outside a transaction for the specified major version. | ||
| func isPostgresMustNonTxVersion(txt string, major int) bool { | ||
| norm := strings.Join(strings.Fields(strings.ToLower(strings.TrimSpace(txt))), " ") | ||
| if major > 0 && major < 12 && strings.HasPrefix(norm, "alter type") && strings.Contains(norm, " add value") { | ||
| return true | ||
| } | ||
| for _, p := range pgMustNonTxPrefixes { | ||
| if strings.HasPrefix(norm, p) { | ||
| return true | ||
| } | ||
| } | ||
| return false | ||
| } | ||
|
|
||
| var ifExistsRE = regexp.MustCompile(`(?i)\bIF\s+(?:NOT\s+)?EXISTS\b`) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just noticing "sql" is prefix in "sqltoken" and suffix in "classifysql". Should it be consistent? Should this package be simply "classify"?