Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ func (*Load) iStatement() {}
func (*Savepoint) iStatement() {}
func (*RollbackSavepoint) iStatement() {}
func (*ReleaseSavepoint) iStatement() {}
func (*LockTables) iStatement() {}
func (*UnlockTables) iStatement() {}

// ParenSelect can actually not be a top level statement,
// but we have to allow it because it's a requirement
Expand Down Expand Up @@ -5348,6 +5350,86 @@ mustEscape:
buf.WriteByte('`')
}

// LockType is an enum for Lock Types
type LockType string

const (
LockRead LockType = "read"
LockWrite LockType = "write"
LockReadLocal LockType = "read local"
LockLowPriorityWrite LockType = "low_priority write"
)

// TableAndLockType contains table and lock association
type TableAndLockType struct {
Table TableExpr
Lock LockType
SQLNode
}

func (node *TableAndLockType) Format(buf *TrackedBuffer) {
buf.Myprintf("%v %s", node.Table, string(node.Lock))
}

func (node *TableAndLockType) walkSubtree(visit Visit) error {
if node == nil {
return nil
}

return Walk(
visit,
node.Table)
}

type TableAndLockTypes []*TableAndLockType

// LockTables represents the lock statement
type LockTables struct {
Tables TableAndLockTypes
SQLNode
}

func (node *LockTables) Format(buf *TrackedBuffer) {
buf.WriteString("lock tables")
for i, lt := range node.Tables {
if i == 0 {
buf.Myprintf(" %v", lt)
} else {
buf.Myprintf(", %v", lt)
}
}
}

func (node *LockTables) walkSubtree(visit Visit) error {
if node == nil {
return nil
}

for _, t := range node.Tables {
err := Walk(visit, t)
if err != nil {
return err
}
}

return nil
}

// UnlockTables represents the unlock statement
type UnlockTables struct{}

func (node *UnlockTables) Format(buf *TrackedBuffer) {
buf.WriteString("unlock tables")
}

func (node *UnlockTables) walkSubtree(visit Visit) error {
if node == nil {
return nil
}

return nil
}

func compliantName(in string) string {
var buf strings.Builder
for i, c := range in {
Expand Down
44 changes: 38 additions & 6 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1647,12 +1647,6 @@ var (
}, {
input: "optimize foo",
output: "otheradmin",
}, {
input: "lock tables foo",
output: "otheradmin",
}, {
input: "unlock tables foo",
output: "otheradmin",
}, {
input: "select /* EQ true */ 1 from t where a = true",
}, {
Expand Down Expand Up @@ -3533,6 +3527,44 @@ func TestCreateTableSelect(t *testing.T) {
}
}

func TestLocks(t *testing.T) {
testCases := []struct {
input string
output string
}{{
input: "lock tables foo read",
output: "lock tables foo read",
}, {
input: "LOCK TABLES `t1` READ",
output: "lock tables t1 read",
}, {
input: "LOCK TABLES `mytable` as `t` WRITE",
output: "lock tables mytable as t write",
}, {
input: "LOCK TABLES t1 WRITE, t2 READ",
output: "lock tables t1 write, t2 read",
}, {
input: "LOCK TABLES t1 LOW_PRIORITY WRITE, t2 READ LOCAL",
output: "lock tables t1 low_priority write, t2 read local",
}, {
input: "LOCK TABLES t1 as table1 LOW_PRIORITY WRITE, t2 as table2 READ LOCAL",
output: "lock tables t1 as table1 low_priority write, t2 as table2 read local",
}, {
input: "UNLOCK TABLES",
output: "unlock tables",
}, {
input: "LOCK TABLES `people` READ /*!32311 LOCAL */",
output: "lock tables people read local",
}}
for _, tcase := range testCases {
p, err := Parse(tcase.input)
require.NoError(t, err)
if got, want := String(p), tcase.output; got != want {
t.Errorf("Parse(%s):\n%s, want\n%s", tcase.input, got, want)
}
}
}

var (
invalidSQL = []struct {
input string
Expand Down
Loading