diff --git a/.gitignore b/.gitignore index e69de29b..723ef36f 100644 --- a/.gitignore +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/bindings.go b/bindings.go index 1c241895..7fe9d690 100644 --- a/bindings.go +++ b/bindings.go @@ -1,6 +1,6 @@ package sitter -//#include "bindings.h" +// #include "bindings.h" import "C" import ( @@ -801,44 +801,48 @@ func NewQuery(pattern []byte, lang *Language) (*Query, error) { q := &Query{c: c} + // Copied from: https://github.com/klothoplatform/go-tree-sitter/commit/e351b20167b26d515627a4a1a884528ede5fef79 + // this is just used for syntax validation - it does not actually filter anything for i := uint32(0); i < q.PatternCount(); i++ { - steps := q.PredicatesForPattern(i) - if len(steps) == 0 { - continue - } - - if steps[0].Type != QueryPredicateStepTypeString { - return nil, errors.New("predicate must begin with a literal value") - } - - operator := q.StringValueForId(steps[0].ValueId) - switch operator { - case "eq?", "not-eq?": - if len(steps) != 4 { - return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2) - } - if steps[1].Type != QueryPredicateStepTypeCapture { - return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId)) - } - case "match?", "not-match?": - if len(steps) != 4 { - return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2) - } - if steps[1].Type != QueryPredicateStepTypeCapture { - return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId)) - } - if steps[2].Type != QueryPredicateStepTypeString { - return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId)) - } - case "set!", "is?", "is-not?": - if len(steps) < 3 || len(steps) > 4 { - return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2) + predicates := q.PredicatesForPattern(i) + for _, steps := range predicates { + if len(steps) == 0 { + continue } - if steps[1].Type != QueryPredicateStepTypeString { - return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId)) + + if steps[0].Type != QueryPredicateStepTypeString { + return nil, errors.New("predicate must begin with a literal value") } - if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString { - return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId)) + + operator := q.StringValueForId(steps[0].ValueId) + switch operator { + case "eq?", "not-eq?": + if len(steps) != 4 { + return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2) + } + if steps[1].Type != QueryPredicateStepTypeCapture { + return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId)) + } + case "match?", "not-match?": + if len(steps) != 4 { + return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 2, got %d", operator, len(steps)-2) + } + if steps[1].Type != QueryPredicateStepTypeCapture { + return nil, fmt.Errorf("first argument of `#%s` predicate must be a capture. Got %s", operator, q.StringValueForId(steps[1].ValueId)) + } + if steps[2].Type != QueryPredicateStepTypeString { + return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId)) + } + case "set!", "is?", "is-not?": + if len(steps) < 3 || len(steps) > 4 { + return nil, fmt.Errorf("wrong number of arguments to `#%s` predicate. Expected 1 or 2, got %d", operator, len(steps)-2) + } + if steps[1].Type != QueryPredicateStepTypeString { + return nil, fmt.Errorf("first argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[1].ValueId)) + } + if len(steps) > 2 && steps[2].Type != QueryPredicateStepTypeString { + return nil, fmt.Errorf("second argument of `#%s` predicate must be a string. Got %s", operator, q.StringValueForId(steps[2].ValueId)) + } } } } @@ -885,7 +889,7 @@ type QueryPredicateStep struct { ValueId uint32 } -func (q *Query) PredicatesForPattern(patternIndex uint32) []QueryPredicateStep { +func (q *Query) PredicatesForPattern(patternIndex uint32) [][]QueryPredicateStep { var ( length C.uint32_t cPredicateSteps []C.TSQueryPredicateStep @@ -905,7 +909,7 @@ func (q *Query) PredicatesForPattern(patternIndex uint32) []QueryPredicateStep { predicateSteps = append(predicateSteps, QueryPredicateStep{stepType, valueId}) } - return predicateSteps + return splitPredicates(predicateSteps) } func (q *Query) CaptureNameForId(id uint32) string { @@ -1059,6 +1063,21 @@ func (qc *QueryCursor) NextCapture() (*QueryMatch, uint32, bool) { return qm, uint32(captureIndex), true } +// Copied From: https://github.com/klothoplatform/go-tree-sitter/commit/e351b20167b26d515627a4a1a884528ede5fef79 + +func splitPredicates(steps []QueryPredicateStep) [][]QueryPredicateStep { + var predicateSteps [][]QueryPredicateStep + var currentSteps []QueryPredicateStep + for _, step := range steps { + currentSteps = append(currentSteps, step) + if step.Type == QueryPredicateStepTypeDone { + predicateSteps = append(predicateSteps, currentSteps) + currentSteps = []QueryPredicateStep{} + } + } + return predicateSteps +} + func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch { qm := &QueryMatch{ ID: m.ID, @@ -1067,87 +1086,90 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch q := qc.q - steps := q.PredicatesForPattern(uint32(qm.PatternIndex)) - if len(steps) == 0 { + predicates := q.PredicatesForPattern(uint32(qm.PatternIndex)) + if len(predicates) == 0 { qm.Captures = m.Captures return qm } - operator := q.StringValueForId(steps[0].ValueId) + // track if we matched all predicates globally + matchedAll := true - switch operator { - case "eq?", "not-eq?": - isPositive := operator == "eq?" + // check each predicate against the match + for _, steps := range predicates { + operator := q.StringValueForId(steps[0].ValueId) - expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId) + switch operator { + case "eq?", "not-eq?": + isPositive := operator == "eq?" - if steps[2].Type == QueryPredicateStepTypeCapture { - expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId) + expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId) - var nodeLeft, nodeRight *Node + if steps[2].Type == QueryPredicateStepTypeCapture { + expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId) - found := false + var nodeLeft, nodeRight *Node - for _, c := range m.Captures { - captureName := q.CaptureNameForId(c.Index) - qm.Captures = append(qm.Captures, c) + for _, c := range m.Captures { + captureName := q.CaptureNameForId(c.Index) - if captureName == expectedCaptureNameLeft { - nodeLeft = c.Node - } - if captureName == expectedCaptureNameRight { - nodeRight = c.Node + if captureName == expectedCaptureNameLeft { + nodeLeft = c.Node + } + if captureName == expectedCaptureNameRight { + nodeRight = c.Node + } + + if nodeLeft != nil && nodeRight != nil { + if (nodeLeft.Content(input) == nodeRight.Content(input)) != isPositive { + matchedAll = false + } + break + } } + } else { + expectedValueRight := q.StringValueForId(steps[2].ValueId) + + for _, c := range m.Captures { + captureName := q.CaptureNameForId(c.Index) - if nodeLeft != nil && nodeRight != nil { - if (nodeLeft.Content(input) == nodeRight.Content(input)) == isPositive { - found = true + if expectedCaptureNameLeft != captureName { + continue + } + + if (c.Node.Content(input) == expectedValueRight) != isPositive { + matchedAll = false + break } - break } } - if !found { - qm.Captures = nil + if matchedAll == false { + break } - } else { - expectedValueRight := q.StringValueForId(steps[2].ValueId) - found := false + case "match?", "not-match?": + isPositive := operator == "match?" + + expectedCaptureName := q.CaptureNameForId(steps[1].ValueId) + regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId)) + for _, c := range m.Captures { captureName := q.CaptureNameForId(c.Index) - - qm.Captures = append(qm.Captures, c) - if expectedCaptureNameLeft != captureName { + if expectedCaptureName != captureName { continue } - if (c.Node.Content(input) == expectedValueRight) == isPositive { - found = true + if regex.Match([]byte(c.Node.Content(input))) != isPositive { + matchedAll = false + break } } - - if !found { - qm.Captures = nil - } } + } - case "match?", "not-match?": - isPositive := operator == "match?" - - expectedCaptureName := q.CaptureNameForId(steps[1].ValueId) - regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId)) - - for _, c := range m.Captures { - captureName := q.CaptureNameForId(c.Index) - if expectedCaptureName != captureName { - continue - } - - if regex.Match([]byte(c.Node.Content(input))) == isPositive { - qm.Captures = append(qm.Captures, c) - } - } + if matchedAll { + qm.Captures = append(qm.Captures, m.Captures...) } return qm diff --git a/predicates_test.go b/predicates_test.go index cabe7e84..c985b677 100644 --- a/predicates_test.go +++ b/predicates_test.go @@ -96,6 +96,12 @@ func TestQueryWithPredicates(t *testing.T) { msg: "#eq?: success test", pattern: `((expression) @capture (#eq? @capture "this"))`, + }, + { + success: true, + msg: "#eq?: success double predicate test", + pattern: `((expression) @capture + (#eq? @capture @capture) (#eq? @capture "this"))`, }, { success: true, @@ -287,6 +293,13 @@ func TestFilterPredicates(t *testing.T) { expectedBefore: 1, expectedAfter: 0, }, + { + input: `// foo`, + query: `((comment) @capture + (#eq? @capture "// foo") (#eq? @capture "// bar"))`, + expectedBefore: 1, + expectedAfter: 0, + }, { input: `1234 + 1234`, query: `((sum @@ -346,6 +359,24 @@ func TestFilterPredicates(t *testing.T) { expectedBefore: 2, expectedAfter: 2, }, + { + input: `1234 + 4321`, + query: `((sum + left: (expression (number) @left) + right: (expression (number) @right)) + (#eq? @left 1234) (#not-eq? @left @right))`, + expectedBefore: 2, + expectedAfter: 2, + }, + { + input: `1234 + 4321`, + query: `((sum + left: (expression (number) @left) + right: (expression (number) @right)) + (#eq? @left 1234) (#eq? @left 4321))`, + expectedBefore: 2, + expectedAfter: 0, + }, } parser := NewParser()