Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic tree typing #967

Merged
merged 11 commits into from
Jan 21, 2022
2 changes: 1 addition & 1 deletion lark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .utils import logger
from .tree import Tree
from .tree import Tree, ParseTree
from .visitors import Transformer, Visitor, v_args, Discard, Transformer_NonRecursive
from .exceptions import (ParseError, LexError, GrammarError, UnexpectedToken,
UnexpectedInput, UnexpectedCharacters, UnexpectedEOF, LarkError)
Expand Down
3 changes: 2 additions & 1 deletion lark/lark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
if TYPE_CHECKING:
from .parsers.lalr_interactive_parser import InteractiveParser
from .tree import ParseTree
from .visitors import Transformer
if sys.version_info >= (3, 8):
from typing import Literal
Expand Down Expand Up @@ -591,7 +592,7 @@ def parse_interactive(self, text: Optional[str]=None, start: Optional[str]=None)
"""
return self.parser.parse_interactive(text, start=start)

def parse(self, text: str, start: Optional[str]=None, on_error: 'Optional[Callable[[UnexpectedInput], bool]]'=None) -> Tree:
def parse(self, text: str, start: Optional[str]=None, on_error: 'Optional[Callable[[UnexpectedInput], bool]]'=None) -> 'ParseTree':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you change this to 'Tree[Token]', there is no need for a ParseTree variable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you talking about it not being needed in this instance or not being needed in general?

  • For this specific use case, yes. That would not require the added import of ParseTree when type checking.
  • In general, ParseTree isn't strictly required at all. However, I think it is a nice ease-of-use feature to have a specific name for trees returned by parse(), since it is a core type that users of the library will be interacting with a lot.

"""Parse the given text, according to the options provided.

Parameters:
Expand Down
4 changes: 2 additions & 2 deletions lark/reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unicodedata

from .lark import Lark
from .tree import Tree
from .tree import Tree, ParseTree
from .visitors import Transformer_InPlace
from .lexer import Token, PatternStr, TerminalDef
from .grammar import Terminal, NonTerminal, Symbol
Expand Down Expand Up @@ -94,7 +94,7 @@ def _reconstruct(self, tree):
else:
yield item

def reconstruct(self, tree: Tree, postproc: Optional[Callable[[Iterable[str]], Iterable[str]]]=None, insert_spaces: bool=True) -> str:
def reconstruct(self, tree: ParseTree, postproc: Optional[Callable[[Iterable[str]], Iterable[str]]]=None, insert_spaces: bool=True) -> str:
x = self._reconstruct(tree)
if postproc:
x = postproc(x)
Expand Down
36 changes: 21 additions & 15 deletions lark/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import sys
from copy import deepcopy

from typing import List, Callable, Iterator, Union, Optional, TYPE_CHECKING
from typing import List, Callable, Iterator, Union, Optional, Generic, TypeVar, Any, TYPE_CHECKING

if TYPE_CHECKING:
from .lexer import TerminalDef
from .lexer import TerminalDef, Token
if sys.version_info >= (3, 8):
from typing import Literal
else:
Expand All @@ -35,7 +35,10 @@ def __init__(self):
self.empty = True


class Tree(object):
_Leaf_T = TypeVar("_Leaf_T")


class Tree(Generic[_Leaf_T]):
erezsh marked this conversation as resolved.
Show resolved Hide resolved
"""The main tree class.

Creates a new tree, and stores "data" and "children" in attributes of the same name.
Expand All @@ -49,9 +52,9 @@ class Tree(object):
"""

data: str
children: 'List[Union[str, Tree]]'
children: 'List[Union[_Leaf_T, Tree[_Leaf_T]]]'

def __init__(self, data: str, children: 'List[Union[str, Tree]]', meta: Optional[Meta]=None) -> None:
def __init__(self, data: str, children: 'List[Union[_Leaf_T, Tree[_Leaf_T]]]', meta: Optional[Meta]=None) -> None:
self.data = data
self.children = children
self._meta = meta
Expand Down Expand Up @@ -100,7 +103,7 @@ def __ne__(self, other):
def __hash__(self) -> int:
return hash((self.data, tuple(self.children)))

def iter_subtrees(self) -> 'Iterator[Tree]':
def iter_subtrees(self) -> 'Iterator[Tree[_Leaf_T]]':
"""Depth-first iteration.

Iterates over all the subtrees, never returning to the same node twice (Lark's parse-tree is actually a DAG).
Expand All @@ -109,17 +112,18 @@ def iter_subtrees(self) -> 'Iterator[Tree]':
subtrees = OrderedDict()
for subtree in queue:
subtrees[id(subtree)] = subtree
queue += [c for c in reversed(subtree.children)
# Reason for type ignore https://github.com/python/mypy/issues/10999
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the current type system it isn't possible to convey that "_Leaf_T can be anything except a Tree". So this will likely always be a thing mypy won't be able to figure out on its own. The referenced issue has more context and might provide some insights on how to better handle it.

queue += [c for c in reversed(subtree.children) # type: ignore
if isinstance(c, Tree) and id(c) not in subtrees]

del queue
return reversed(list(subtrees.values()))

def find_pred(self, pred: 'Callable[[Tree], bool]') -> 'Iterator[Tree]':
def find_pred(self, pred: 'Callable[[Tree[_Leaf_T]], bool]') -> 'Iterator[Tree[_Leaf_T]]':
"""Returns all nodes of the tree that evaluate pred(node) as true."""
return filter(pred, self.iter_subtrees())

def find_data(self, data: str) -> 'Iterator[Tree]':
def find_data(self, data: str) -> 'Iterator[Tree[_Leaf_T]]':
"""Returns all nodes of the tree whose data equals the given data."""
return self.find_pred(lambda t: t.data == data)

Expand All @@ -131,7 +135,7 @@ def expand_kids_by_index(self, *indices: int) -> None:
kid = self.children[i]
self.children[i:i+1] = kid.children

def scan_values(self, pred: 'Callable[[Union[str, Tree]], bool]') -> Iterator[str]:
def scan_values(self, pred: 'Callable[[Union[_Leaf_T, Tree[_Leaf_T]]], bool]') -> Iterator[_Leaf_T]:
"""Return all values in the tree that evaluate pred(value) as true.

This can be used to find all the tokens in the tree.
Expand Down Expand Up @@ -164,30 +168,32 @@ def iter_subtrees_topdown(self):
def __deepcopy__(self, memo):
return type(self)(self.data, deepcopy(self.children, memo), meta=self._meta)

def copy(self) -> 'Tree':
def copy(self) -> 'Tree[_Leaf_T]':
return type(self)(self.data, self.children)

def set(self, data: str, children: 'List[Union[str, Tree]]') -> None:
def set(self, data: str, children: 'List[Union[_Leaf_T, Tree[_Leaf_T]]]') -> None:
self.data = data
self.children = children


ParseTree = Tree['Token']


class SlottedTree(Tree):
__slots__ = 'data', 'children', 'rule', '_meta'


def pydot__tree_to_png(tree: Tree, filename: str, rankdir: 'Literal["TB", "LR", "BT", "RL"]'="LR", **kwargs) -> None:
def pydot__tree_to_png(tree: Tree[Any], filename: str, rankdir: 'Literal["TB", "LR", "BT", "RL"]'="LR", **kwargs) -> None:
plannigan marked this conversation as resolved.
Show resolved Hide resolved
graph = pydot__tree_to_graph(tree, rankdir, **kwargs)
graph.write_png(filename)


def pydot__tree_to_dot(tree, filename, rankdir="LR", **kwargs):
def pydot__tree_to_dot(tree: Tree[Any], filename, rankdir="LR", **kwargs):
graph = pydot__tree_to_graph(tree, rankdir, **kwargs)
graph.write(filename)


def pydot__tree_to_graph(tree, rankdir="LR", **kwargs):
def pydot__tree_to_graph(tree: Tree[Any], rankdir="LR", **kwargs):
"""Creates a colorful image that represents the tree (data+children, without meta)

Possible values for `rankdir` are "TB", "LR", "BT", "RL", corresponding to
Expand Down
55 changes: 28 additions & 27 deletions lark/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
###{standalone
from inspect import getmembers, getmro

_T = TypeVar('_T')
_Return_T = TypeVar('_Return_T')
_Leaf_T = TypeVar('_Leaf_T')
_R = TypeVar('_R')
_FUNC = Callable[..., _T]
_FUNC = Callable[..., _Return_T]
_DECORATED = Union[_FUNC, type]

class Discard(Exception):
Expand Down Expand Up @@ -52,7 +53,7 @@ def __class_getitem__(cls, _):
return cls


class Transformer(_Decoratable, ABC, Generic[_T]):
class Transformer(_Decoratable, ABC, Generic[_Return_T, _Leaf_T]):
plannigan marked this conversation as resolved.
Show resolved Hide resolved
"""Transformers visit each node of the tree, and run the appropriate method on it according to the node's data.

Methods are provided by the user via inheritance, and called according to ``tree.data``.
Expand Down Expand Up @@ -131,11 +132,11 @@ def _transform_tree(self, tree):
children = list(self._transform_children(tree.children))
return self._call_userfunc(tree, children)

def transform(self, tree: Tree) -> _T:
def transform(self, tree: Tree[_Leaf_T]) -> _Return_T:
"Transform the given tree, and return the final result"
return self._transform_tree(tree)

def __mul__(self, other: 'Transformer[_T]') -> 'TransformerChain[_T]':
def __mul__(self, other: 'Transformer[_Return_T, _Leaf_T]') -> 'TransformerChain[_Return_T, _Leaf_T]':
"""Chain two transformers together, returning a new transformer.
"""
return TransformerChain(self, other)
Expand All @@ -155,19 +156,19 @@ def __default_token__(self, token):
return token


class TransformerChain(Generic[_T]):
class TransformerChain(Generic[_Return_T, _Leaf_T]):

transformers: Tuple[Transformer[_T], ...]
transformers: Tuple[Transformer[_Return_T, _Leaf_T], ...]

def __init__(self, *transformers: Transformer[_T]) -> None:
def __init__(self, *transformers: Transformer[_Return_T, _Leaf_T]) -> None:
self.transformers = transformers

def transform(self, tree: Tree) -> _T:
def transform(self, tree: Tree[_Leaf_T]) -> _Return_T:
for t in self.transformers:
tree = t.transform(tree)
return tree

def __mul__(self, other: Transformer[_T]) -> 'TransformerChain[_T]':
def __mul__(self, other: Transformer[_Return_T, _Leaf_T]) -> 'TransformerChain[_Return_T, _Leaf_T]':
return TransformerChain(*self.transformers + (other,))


Expand All @@ -179,7 +180,7 @@ class Transformer_InPlace(Transformer):
def _transform_tree(self, tree): # Cancel recursion
return self._call_userfunc(tree)

def transform(self, tree):
def transform(self, tree: Tree[_Leaf_T]) -> _Return_T:
for subtree in tree.iter_subtrees():
subtree.children = list(self._transform_children(subtree.children))

Expand All @@ -194,18 +195,18 @@ class Transformer_NonRecursive(Transformer):
Useful for huge trees.
"""

def transform(self, tree):
def transform(self, tree: Tree[_Leaf_T]) -> _Return_T:
# Tree to postfix
rev_postfix = []
q = [tree]
rev_postfix: List[Union[_Leaf_T, Tree[_Leaf_T]]] = []
plannigan marked this conversation as resolved.
Show resolved Hide resolved
q: List[Union[_Leaf_T, Tree[_Leaf_T]]] = [tree]
while q:
t = q.pop()
rev_postfix.append(t)
if isinstance(t, Tree):
q += t.children

# Postfix to tree
stack = []
stack: List[Union[_Leaf_T, _Return_T]] = []
for x in reversed(rev_postfix):
if isinstance(x, Tree):
size = len(x.children)
Expand All @@ -220,8 +221,8 @@ def transform(self, tree):
else:
stack.append(x)

t ,= stack # We should have only one tree remaining
return t
result, = stack # We should have only one tree remaining
return result
plannigan marked this conversation as resolved.
Show resolved Hide resolved


class Transformer_InPlaceRecursive(Transformer):
Expand All @@ -248,34 +249,34 @@ def __class_getitem__(cls, _):
return cls


class Visitor(VisitorBase, ABC, Generic[_T]):
class Visitor(VisitorBase, ABC, Generic[_Leaf_T]):
"""Tree visitor, non-recursive (can handle huge trees).

Visiting a node calls its methods (provided by the user via inheritance) according to ``tree.data``
"""

def visit(self, tree: Tree) -> Tree:
def visit(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]:
"Visits the tree, starting with the leaves and finally the root (bottom-up)"
for subtree in tree.iter_subtrees():
self._call_userfunc(subtree)
return tree

def visit_topdown(self, tree: Tree) -> Tree:
def visit_topdown(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]:
"Visit the tree, starting at the root, and ending at the leaves (top-down)"
for subtree in tree.iter_subtrees_topdown():
self._call_userfunc(subtree)
return tree


class Visitor_Recursive(VisitorBase):
class Visitor_Recursive(VisitorBase, Generic[_Leaf_T]):
"""Bottom-up visitor, recursive.

Visiting a node calls its methods (provided by the user via inheritance) according to ``tree.data``

Slightly faster than the non-recursive version.
"""

def visit(self, tree: Tree) -> Tree:
def visit(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]:
"Visits the tree, starting with the leaves and finally the root (bottom-up)"
for child in tree.children:
if isinstance(child, Tree):
Expand All @@ -284,7 +285,7 @@ def visit(self, tree: Tree) -> Tree:
self._call_userfunc(tree)
return tree

def visit_topdown(self,tree: Tree) -> Tree:
def visit_topdown(self,tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]:
"Visit the tree, starting at the root, and ending at the leaves (top-down)"
self._call_userfunc(tree)

Expand All @@ -295,7 +296,7 @@ def visit_topdown(self,tree: Tree) -> Tree:
return tree


class Interpreter(_Decoratable, ABC, Generic[_T]):
class Interpreter(_Decoratable, ABC, Generic[_Return_T, _Leaf_T]):
"""Interpreter walks the tree starting at the root.

Visits the tree, starting with the root and finally the leaves (top-down)
Expand All @@ -307,15 +308,15 @@ class Interpreter(_Decoratable, ABC, Generic[_T]):
This allows the user to implement branching and loops.
"""

def visit(self, tree: Tree) -> _T:
def visit(self, tree: Tree[_Leaf_T]) -> _Return_T:
f = getattr(self, tree.data)
wrapper = getattr(f, 'visit_wrapper', None)
if wrapper is not None:
return f.visit_wrapper(f, tree.data, tree.children, tree.meta)
else:
return f(tree)

def visit_children(self, tree: Tree) -> List[_T]:
def visit_children(self, tree: Tree[_Leaf_T]) -> List[_Return_T]:
return [self.visit(child) if isinstance(child, Tree) else child
for child in tree.children]

Expand All @@ -326,7 +327,7 @@ def __default__(self, tree):
return self.visit_children(tree)


_InterMethod = Callable[[Type[Interpreter], _T], _R]
_InterMethod = Callable[[Type[Interpreter], _Return_T], _R]

def visit_children_decor(func: _InterMethod) -> _InterMethod:
"See Interpreter"
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import re
from setuptools import find_packages, setup

__version__ ,= re.findall('__version__ = "(.*)"', open('lark/__init__.py').read())
__version__ ,= re.findall('__version__: str = "(.*)"', open('lark/__init__.py').read())

setup(
name = "lark-parser",
version = __version__,
packages = ['lark', 'lark.parsers', 'lark.tools', 'lark.grammars', 'lark.__pyinstaller', 'lark-stubs'],
packages = ['lark', 'lark.parsers', 'lark.tools', 'lark.grammars', 'lark.__pyinstaller'],

requires = [],
install_requires = [],
Expand All @@ -20,7 +20,7 @@
"atomic_cache": ["atomicwrites"],
},

package_data = {'': ['*.md', '*.lark'], 'lark-stubs': ['*.pyi']},
package_data = {'': ['*.md', '*.lark']},

test_suite = 'tests.__main__',

Expand Down