Skip to content

Commit

Permalink
Fix predicate handling to match tree-sitter cli
Browse files Browse the repository at this point in the history
  • Loading branch information
wesen committed Apr 23, 2023
1 parent 9e6836f commit 3eeb0b6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 15 deletions.
42 changes: 30 additions & 12 deletions bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -1065,27 +1065,32 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch
PatternIndex: m.PatternIndex,
}

steps := qc.q.PredicatesForPattern(uint32(qm.PatternIndex))
q := qc.q

steps := q.PredicatesForPattern(uint32(qm.PatternIndex))
if len(steps) == 0 {
qm.Captures = m.Captures
return qm
}

operator := qc.q.StringValueForId(steps[0].ValueId)
operator := q.StringValueForId(steps[0].ValueId)

switch operator {
case "eq?", "not-eq?":
isPositive := operator == "eq?"

expectedCaptureNameLeft := qc.q.CaptureNameForId(steps[1].ValueId)
expectedCaptureNameLeft := q.CaptureNameForId(steps[1].ValueId)

if steps[2].Type == QueryPredicateStepTypeCapture {
expectedCaptureNameRight := qc.q.CaptureNameForId(steps[2].ValueId)
expectedCaptureNameRight := q.CaptureNameForId(steps[2].ValueId)

var nodeLeft, nodeRight *Node

found := false

for _, c := range m.Captures {
captureName := qc.q.CaptureNameForId(c.Index)
captureName := q.CaptureNameForId(c.Index)
qm.Captures = append(qm.Captures, c)

if captureName == expectedCaptureNameLeft {
nodeLeft = c.Node
Expand All @@ -1096,33 +1101,45 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch

if nodeLeft != nil && nodeRight != nil {
if (nodeLeft.Content(input) == nodeRight.Content(input)) == isPositive {
qm.Captures = append(qm.Captures, c)
found = true
}
break
}
}

if !found {
qm.Captures = nil
}
} else {
expectedValueRight := qc.q.StringValueForId(steps[2].ValueId)
expectedValueRight := q.StringValueForId(steps[2].ValueId)

found := false
for _, c := range m.Captures {
captureName := qc.q.CaptureNameForId(c.Index)
captureName := q.CaptureNameForId(c.Index)

qm.Captures = append(qm.Captures, c)
if expectedCaptureNameLeft != captureName {
continue
}

if (c.Node.Content(input) == expectedValueRight) == isPositive {
qm.Captures = append(qm.Captures, c)
found = true
}
}

if !found {
qm.Captures = nil
}
}

case "match?", "not-match?":
isPositive := operator == "match?"

expectedCaptureName := qc.q.CaptureNameForId(steps[1].ValueId)
regex := regexp.MustCompile(qc.q.StringValueForId(steps[2].ValueId))
expectedCaptureName := q.CaptureNameForId(steps[1].ValueId)
regex := regexp.MustCompile(q.StringValueForId(steps[2].ValueId))

for _, c := range m.Captures {
captureName := qc.q.CaptureNameForId(c.Index)
captureName := q.CaptureNameForId(c.Index)
if expectedCaptureName != captureName {
continue
}
Expand All @@ -1134,6 +1151,7 @@ func (qc *QueryCursor) FilterPredicates(m *QueryMatch, input []byte) *QueryMatch
}

return qm

}

// keeps callbacks for parser.parse method
Expand Down
3 changes: 2 additions & 1 deletion bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ func TestQueryError(t *testing.T) {

assert.Nil(q)
assert.NotNil(err)
assert.EqualValues(&QueryError{Offset: 0x02, Type: QueryErrorNodeType}, err)
assert.EqualValues(&QueryError{Offset: 0x02, Type: QueryErrorNodeType,
Message: "invalid node type 'unknown' at line 1 column 0"}, err)
}

func doWorkLifetime(t testing.TB, n *Node) {
Expand Down
13 changes: 11 additions & 2 deletions predicates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func TestFilterPredicates(t *testing.T) {
right: (expression (number) @right))
(#eq? @left @right))`,
expectedBefore: 2,
expectedAfter: 1,
expectedAfter: 2,
},
{
input: `1234 + 4321`,
Expand Down Expand Up @@ -335,7 +335,16 @@ func TestFilterPredicates(t *testing.T) {
right: (expression (number) @right))
(#not-eq? @left @right))`,
expectedBefore: 2,
expectedAfter: 1,
expectedAfter: 2,
},
{
input: `1234 + 4321`,
query: `((sum
left: (expression (number) @left)
right: (expression (number) @right))
(#eq? @left 1234))`,
expectedBefore: 2,
expectedAfter: 2,
},
}

Expand Down

0 comments on commit 3eeb0b6

Please sign in to comment.