Skip to content

Commit

Permalink
better pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgenus committed Aug 21, 2022
1 parent 1903c28 commit 9f8666b
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 85 deletions.
169 changes: 87 additions & 82 deletions asty/visitors.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from collections import (
defaultdict,
deque,
)
from functools import wraps
from typing import (
Iterable,
Optional,
Sequence,
)

from devtools import sformat

from asty.nodes import (
BaseNode,
MatchRuleNode,
Node,
)


Expand Down Expand Up @@ -72,68 +76,80 @@ def generic_visit(self, node):
return node


class Matcher:
sub_matchers: dict[str, list['Matcher']]

def __init__(self, pattern: BaseNode):
class MatchingResult:
def __init__(self, pattern: BaseNode, node: BaseNode, context: str = None):
self.pattern = pattern
self.sub_matchers = defaultdict(list)
self.node = node
self.context = context
self.matches: list['MatchingResult'] = []

def attach(self, *sub_match: 'MatchingResult'):
self.matches.extend(sub_match)
return self

def __pretty__(self, fmt, **_kwargs):
yield sformat('MatchingResult', sformat.bold) + '('
yield 1
yield 'context='
yield sformat(repr(self.context), sformat.green)
yield ','
yield 0
yield 'pattern='
yield sformat(repr(self.pattern.node_type), sformat.blue, sformat.italic)
yield ','
yield 0
yield 'node='
yield sformat(repr(self.node.node_type), sformat.yellow, sformat.italic)
yield ','
yield 0
if self.matches:
yield 'matchers='
yield fmt(self.matches)
yield ','
yield 0
yield -1
yield ')'


def make_matcher(self, context: str, pattern: BaseNode) -> 'Matcher':
matcher = Matcher(pattern)
Result = Iterable[MatchingResult]

def decorator(func):
@wraps(func)
def wrapper(node):
result = func(node)
if result:
self.sub_matchers[context].append(matcher)
return result
return wrapper

matcher.match = decorator(matcher.match)
class Matcher:
def __init__(self, pattern: BaseNode, context: Optional[str] = None):
self.pattern = pattern
self.context = context

def make_matcher(self, pattern: BaseNode, context: str) -> 'Matcher':
return Matcher(pattern, context)

return matcher
def make_result(self, node: BaseNode) -> MatchingResult:
return MatchingResult(self.pattern, node, self.context)

def match(self, node: BaseNode):
def match(self, node: BaseNode) -> Result:
method = 'match_' + self.pattern.__class__.__name__
matcher = getattr(self, method, self.generic_match)
return matcher(node)
return list(matcher(node))

def single_match(self, node: BaseNode):
def _single_match(self, node: BaseNode) -> Result:
assert isinstance(self.pattern, MatchRuleNode)
for rule in self.pattern.rules:
sub_matcher = self.make_matcher(self.pattern.name, rule)
if sub_matcher.match(node):
return True
return False

def search_match(self, tree: BaseNode):
search_matches = [
self.single_match(node)
for node in walk(tree)
]
return any(search_matches)

def match_MatchRuleNode(self, node: BaseNode):
sub_matcher = self.make_matcher(rule, self.pattern.name)
if sub_match := sub_matcher.match(node):
yield self.make_result(node).attach(*sub_match)

def _search_match(self, tree: BaseNode) -> Result:
for node in walk(tree):
yield from self._single_match(node)

def match_MatchRuleNode(self, node: BaseNode) -> Result:
assert isinstance(self.pattern, MatchRuleNode)
if self.pattern.exact:
return self.single_match(node)
yield from self._single_match(node)
else:
return self.search_match(node)

def field_match(self, node, name):
value = getattr(node, name, None)
if value is None:
return False
elif isinstance(value, Sequence):
items_matches = [self.match(item) for item in value]
return any(items_matches)
else:
return self.match(value)
yield from self._search_match(node)

def generic_match(self, node):
complex_fields = {
def generic_match(self, node: BaseNode) -> Result:
complex_fields: dict[str, Node] = {
name: value
for name, value in iter_fields(self.pattern)
}
Expand All @@ -144,39 +160,28 @@ def generic_match(self, node):
pattern_value = getattr(self.pattern, name)
node_value = getattr(node, name, None)
if pattern_value != node_value:
return False
return

for name, value in complex_fields.items():
if isinstance(value, Sequence):
for item in value:
sub_matcher = self.make_matcher(name, item)
if not sub_matcher.field_match(node, name):
return False
else:
sub_matcher = self.make_matcher(name, value)
if not sub_matcher.field_match(node, name):
return False
sub_matches: list[MatchingResult] = []

return True
for name, pattern_value in complex_fields.items():
items = pattern_value if isinstance(pattern_value, Sequence) else [pattern_value]
for item in items:
sub_matcher = self.make_matcher(item, name)
node_value = getattr(node, name, None)
if node_value is None:
continue

def __pretty__(self, fmt, **kwargs):
yield 'Matcher('
yield 1
yield 'type='
yield fmt(self.pattern.node_type)
yield ','
yield 0
if self.sub_matchers:
yield 'matchers='
yield fmt(dict(self.sub_matchers))
yield ','
yield 0
# for name, value in self.sub_matchers.items():
# for item in value:
# yield name
# yield '='
# yield fmt(item)
# yield ','
# yield 0
yield -1
yield ')'
sub_match = []
if isinstance(node_value, Sequence):
for node_item in node_value:
sub_match.extend(sub_matcher.match(node_item))
if isinstance(node_value, BaseNode):
sub_match.extend(sub_matcher.match(node_value))

if not sub_match:
return

sub_matches.extend(sub_match)

yield self.make_result(node).attach(*sub_matches)
8 changes: 5 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from devtools import pprint
from devtools import (
pprint,
)

from asty.nodes import (
BasicLitNode,
Expand Down Expand Up @@ -62,8 +64,8 @@ def visit_CallExprNode(self, node):
# pprint(tree)

matcher = Matcher(pattern)
matcher.match(tree)
pprint(matcher)
match = matcher.match(tree)
pprint(match)

# output = "/Users/evgenus/tfc/asty/output-processed.json"
# data = tree.json(
Expand Down

0 comments on commit 9f8666b

Please sign in to comment.