Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
358 changes: 358 additions & 0 deletions classifysql/classify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,358 @@
package classifysql

import (
"errors"
"regexp"
"strings"

"github.com/muir/sqltoken"
	// NOTE(review): "sql" is a prefix in "sqltoken" but a suffix in
	// "classifysql" — consider renaming this package to simply "classify"
	// for consistency.
)

// Dialect enumerates supported SQL dialects.
type Dialect int

const (
	DialectInvalid Dialect = iota // zero value; rejected by ClassifyTokens
	DialectMySQL
	DialectPostgres
	DialectSingleStore = DialectMySQL // SingleStore is tokenized/classified identically to MySQL
)

// Flag is a bitmask describing statement / script properties.
type Flag uint32

// Individual flag bits.
const (
	IsDDL                 Flag = 1 << iota // data-definition statement (CREATE/ALTER/DROP/...)
	IsDML                                  // data-manipulation statement (INSERT/UPDATE/DELETE/...)
	IsNonIdempotent                        // re-running the statement is not guaranteed to be a no-op
	IsMultipleStatements                   // the script containing this statement has more than one real statement
	IsEasilyIdempotentFix                  // non-idempotent, but could be fixed by adding IF [NOT] EXISTS (see classifyCreate/classifyDrop)
	IsMustNonTx                            // Statement must run outside a transaction (e.g. PG CREATE INDEX CONCURRENTLY)
)

var flagNameMap = map[Flag]string{
IsDDL: "DDL",
IsDML: "DML",
IsNonIdempotent: "NonIdem",
IsMultipleStatements: "Multi",
IsEasilyIdempotentFix: "EasyFix",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's cool repo,
This one I am not quite understand, are you going to do some special action to 'fix'/make it idempotent ( I didn't found the fix code yet) ?

IsMustNonTx: "MustNonTx",
}

var flagsInOrder = []Flag{IsDDL, IsDML, IsNonIdempotent, IsEasilyIdempotentFix, IsMustNonTx, IsMultipleStatements}

// Statement represents one SQL statement with its original, unstripped tokens and classification flags.
type Statement struct {
	Flags          Flag            // classification bitmask for this statement
	Tokens         sqltoken.Tokens // unstripped tokens for this statement
	dialect        Dialect         // dialect the statement was tokenized with
	version        int             // major version passed to ClassifyTokens (used for PG MustNonTx checks)
	stripped       sqltoken.Tokens // cached stripped tokens (nil/empty if none)
	strippedString string          // cached stripped string (empty if none), lower case
	firstLower     string          // cached lowercase first token text (empty if none)
}

// Statements is a slice of Statement.
type Statements []Statement

// Postgres statements that inherently require execution outside a transaction.
// Lower-cased, whitespace-collapsed prefixes.
var pgMustNonTxPrefixes = []string{
	"create index concurrently",
	"create unique index concurrently",
	"drop index concurrently",
	"refresh materialized view concurrently",
	"reindex concurrently",
	"vacuum full",
	"cluster",
	"create database",
	"drop database",
	"create tablespace",
	"drop tablespace",
	"create subscription",
	"alter subscription",
	"drop subscription",
}

// Per-dialect prefix lists of CREATE/DROP statements that support an
// IF [NOT] EXISTS clause and so are "easily fixable" to be idempotent
// (see classifyCreate and classifyDrop). Lower-cased prefixes.
var (
	mysqlCreateEasy = []string{"create table"}
	mysqlDropEasy   = []string{"drop table", "drop database"}
	pgCreateEasy    = []string{"create table", "create index", "create schema", "create sequence", "create extension"}
	pgDropEasy      = []string{"drop table", "drop index", "drop schema", "drop sequence", "drop extension"}
)

// ClassifyTokens tokenizes and classifies a SQL script, returning per-statement flags and original tokens.
// Only error case is invalid dialect.
// strings.Join(ClassifyTokens().TokensList.Strings, ";") should return the original sqlString
func ClassifyTokens(d Dialect, majorVersion int, sqlString string) (Statements, error) {
	if strings.TrimSpace(sqlString) == "" {
		return nil, nil
	}
	var toks sqltoken.Tokens
	switch d {
	case DialectPostgres:
		// TokenizePostgreSQL preserves comments/whitespace in tokens
		toks = sqltoken.TokenizePostgreSQL(sqlString)
	case DialectMySQL:
		// TokenizeMySQL preserves comments/whitespace in tokens
		toks = sqltoken.TokenizeMySQL(sqlString)
	default:
		return nil, errors.New("invalid dialect")
	}
	pieces := toks.CmdSplitUnstripped()
	out := make(Statements, len(pieces))
	for i, piece := range pieces {
		st := &out[i]
		st.Tokens = piece
		st.dialect = d
		st.version = majorVersion
		st.stripped = piece.Strip()
		if len(st.stripped) == 0 {
			// just comments/whitespace
			continue
		}
		st.strippedString = strings.ToLower(st.stripped.String())
		st.firstLower = strings.ToLower(st.stripped[0].Text)
		st.classify()
	}
	// Mark every statement when the script contains multiple real statements.
	if out.CountNonEmpty() > 1 {
		for i := range out {
			out[i].Flags |= IsMultipleStatements
		}
	}
	return out, nil
}

// StripString returns the cached lowercased, comment/whitespace-stripped
// text of the statement (empty for comment-only statements).
func (s Statement) StripString() string {
	return s.strippedString
}

// FirstWord returns the cached lowercase first token of the statement
// (empty for comment-only statements).
func (s Statement) FirstWord() string {
	return s.firstLower
}

// Strip returns the cached comment/whitespace-stripped tokens for the statement.
func (s Statement) Strip() sqltoken.Tokens {
	return s.stripped
}

// classify sets s.Flags based on the statement's leading verb, then ORs in
// IsMustNonTx for Postgres statements that must run outside a transaction.
// Verbs not listed (e.g. SELECT, SET, SHOW) leave Flags at zero.
func (s *Statement) classify() {
	switch s.firstLower {
	case "create":
		s.Flags = s.classifyCreate()
	case "alter":
		s.Flags = s.classifyAlter()
	case "drop":
		s.Flags = s.classifyDrop()
	case "rename", "comment":
		s.Flags = IsDDL | IsNonIdempotent
	case "truncate":
		// TRUNCATE is DDL but idempotent: re-running leaves the table empty.
		s.Flags = IsDDL
	case "insert", "update", "delete", "replace", "call", "do", "load", "handler", "import", "with":
		s.Flags = IsDML
	}
	if s.dialect == DialectPostgres && isPostgresMustNonTxVersion(s.strippedString, s.version) {
		s.Flags |= IsMustNonTx
	}
}

// classifyCreate flags a CREATE statement: always DDL; non-idempotent unless
// it carries an IF [NOT] EXISTS clause; additionally "easily fixable" when the
// created object kind (per dialect) supports IF NOT EXISTS.
func (s Statement) classifyCreate() Flag {
	easy := mysqlCreateEasy
	if s.dialect == DialectPostgres {
		easy = pgCreateEasy
	}
	guarded := ifExistsRE.MatchString(s.strippedString)
	for _, prefix := range easy {
		if !strings.HasPrefix(s.strippedString, prefix) {
			continue
		}
		if guarded {
			return IsDDL
		}
		return IsDDL | IsNonIdempotent | IsEasilyIdempotentFix
	}
	// Generic CREATE without an IF [NOT] EXISTS guard.
	if guarded {
		return IsDDL
	}
	return IsDDL | IsNonIdempotent
}

// classifyAlter flags an ALTER statement: always DDL; idempotent when guarded
// by IF [NOT] EXISTS; otherwise non-idempotent, and easily fixable for
// Postgres ADD COLUMN (which supports IF NOT EXISTS).
func (s Statement) classifyAlter() Flag {
	if ifExistsRE.MatchString(s.strippedString) {
		return IsDDL
	}
	flags := IsDDL | IsNonIdempotent
	if s.dialect == DialectPostgres && strings.Contains(s.strippedString, " add column") {
		flags |= IsEasilyIdempotentFix
	}
	return flags
}

// classifyDrop flags a DROP statement: always DDL; idempotent when guarded by
// IF EXISTS; otherwise non-idempotent, and easily fixable when the dropped
// object kind (per dialect) supports IF EXISTS.
func (s Statement) classifyDrop() Flag {
	if ifExistsRE.MatchString(s.strippedString) {
		return IsDDL
	}
	easy := mysqlDropEasy
	if s.dialect == DialectPostgres {
		easy = pgDropEasy
	}
	flags := IsDDL | IsNonIdempotent
	for _, prefix := range easy {
		if strings.HasPrefix(s.strippedString, prefix) {
			flags |= IsEasilyIdempotentFix
			break
		}
	}
	return flags
}

// Regroup returns groups of statements that are compatible to run together in a single migration.
// Algorithm (single pass):
// Postgres: statements with IsMustNonTx must be isolated; all others can group together.
// MySQL: each DDL (IsDDL) isolated; consecutive/any DML grouped together.
// Mixed order preserved by starting a new group when incompatibility detected.
func (s Statements) Regroup() []Statements {
	if len(s) == 0 {
		return nil
	}
	// NOTE(review): assumes all statements share the first statement's dialect —
	// true for output of ClassifyTokens; confirm for any other construction path.
	d := s[0].dialect
	var groups []Statements
	var current Statements
	// flush appends the accumulated group (if non-empty) and starts a fresh one.
	flush := func() {
		if len(current) > 0 {
			groups = append(groups, current)
			current = nil
		}
	}
	for _, st := range s {
		if d == DialectPostgres {
			if st.Flags&IsMustNonTx != 0 { // isolate
				flush()
				groups = append(groups, Statements{st})
				continue
			}
			current = append(current, st)
			continue
		}
		// MySQL grouping rules
		if st.Flags&IsDDL != 0 { // isolate DDL
			flush()
			groups = append(groups, Statements{st})
			continue
		}
		// DML -> can merge with existing current if it contains only DML
		if len(current) > 0 {
			current = append(current, st)
		} else {
			current = Statements{st}
		}
	}
	flush()
	// Clear Multi flag from any isolated single-statement group; keep in groups with >1 real statements.
	for gi, g := range groups {
		if g.CountNonEmpty() == 1 {
			// isolated
			for si := range g {
				g[si].Flags &^= IsMultipleStatements
			}
			groups[gi] = g
		}
	}
	return groups
}

// CountNonEmpty returns the number of real statements, excluding empty/comment-only and SET-leading statements.
func (s Statements) CountNonEmpty() int {
	count := 0
	for i := range s {
		toks := s[i].Tokens.Strip()
		if len(toks) == 0 {
			// comment/whitespace-only statement
			continue
		}
		// SET statements are excluded from the count.
		if strings.ToLower(toks[0].Text) == "set" {
			continue
		}
		count++
	}
	return count
}

// FirstReal returns a pointer to the first real statement (non-empty, non-SET) or nil if none.
func (s Statements) FirstReal() *Statement {
	for i := range s {
		toks := s[i].Tokens.Strip()
		if len(toks) == 0 {
			// comment/whitespace-only statement
			continue
		}
		if strings.ToLower(toks[0].Text) != "set" {
			return &s[i]
		}
	}
	return nil
}

// TokensList converts the statements back into a sqltoken.TokensList.
// Placeholder entries for empty statements are preserved so the original
// segmentation can be reconstructed.
func (s Statements) TokensList() sqltoken.TokensList {
	list := make(sqltoken.TokensList, len(s))
	for i := range s {
		list[i] = s[i].Tokens
	}
	return list
}

// Summary maps each flag to the first statement Tokens that exhibited it. IsMultipleStatements is synthetic.
type Summary map[Flag]sqltoken.Tokens

// Includes reports whether every one of the given flags appears in the summary.
func (s Summary) Includes(flags ...Flag) bool {
	for _, f := range flags {
		if _, present := s[f]; !present {
			return false
		}
	}
	return true
}

// Summarize builds a Summary from the classified statements, recording for
// each flag the tokens of the first statement that carries it.
func (s Statements) Summarize() Summary {
	summary := make(Summary)
	for _, flag := range flagsInOrder {
		for _, st := range s {
			if st.Flags&flag != 0 {
				summary[flag] = st.Tokens
				break
			}
		}
	}
	return summary
}

// Names returns human-friendly names of set bits in the flag mask,
// in the package's canonical flag order.
func (f Flag) Names() []string {
	var out []string
	for _, bit := range flagsInOrder {
		if f&bit == 0 {
			continue
		}
		out = append(out, flagNameMap[bit])
	}
	return out
}

// isPostgresMustNonTxVersion returns true when a statement must execute outside a transaction for the specified major version.
func isPostgresMustNonTxVersion(txt string, major int) bool {
	// Lower-case and collapse every run of whitespace to a single space so the
	// prefix comparisons below are layout-independent.
	norm := strings.Join(strings.Fields(strings.ToLower(txt)), " ")
	// ALTER TYPE ... ADD VALUE could not run inside a transaction before PG 12
	// (major == 0 means the version is unknown, so the restriction is skipped).
	if major > 0 && major < 12 &&
		strings.HasPrefix(norm, "alter type") &&
		strings.Contains(norm, " add value") {
		return true
	}
	for _, prefix := range pgMustNonTxPrefixes {
		if strings.HasPrefix(norm, prefix) {
			return true
		}
	}
	return false
}

// ifExistsRE matches an IF EXISTS / IF NOT EXISTS clause anywhere in a
// statement, case-insensitively, with any whitespace between the words.
var ifExistsRE = regexp.MustCompile(`(?i)\bIF\s+(?:NOT\s+)?EXISTS\b`)
Loading
Loading