Skip to content

Commit

Permalink
support parse comprehensions in function.
Browse files Browse the repository at this point in the history
  • Loading branch information
cocolato committed Feb 6, 2024
1 parent 2815589 commit 7840ad3
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
20 changes: 20 additions & 0 deletions mako/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ def visit_FunctionDef(self, node):
self._add_declared(node.name)
self._visit_function(node, False)

def visit_ListComp(self, node):
if self.in_function:
if not isinstance(node.elt, _ast.Name):
self.visit(node.elt)
for comp in node.generators:
self.visit(comp.iter)
else:
self.generic_visit(node)

visit_SetComp = visit_GeneratorExp = visit_ListComp

def visit_DictComp(self, node):
if self.in_function:
if not isinstance(node.key, _ast.Name):
self.visit(node.elt)
for comp in node.generators:
self.visit(comp.iter)
else:
self.generic_visit(node)

def _expand_tuples(self, args):
for arg in args:
if isinstance(arg, _ast.Tuple):
Expand Down
36 changes: 36 additions & 0 deletions test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,42 @@ def test_locate_identifiers_17(self):
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.undeclared_identifiers, {"x", "y", "Foo", "Bar"})

def test_locate_identifiers_18(self):
code = """
def func():
return [i for i in range(10)]
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_19(self):
code = """
def func():
return (i for i in range(10))
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_20(self):
code = """
def func():
return {i for i in range(10)}
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_21(self):
code = """
def func():
return {i: i**2 for i in range(10)}
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_no_global_imports(self):
code = """
from foo import *
Expand Down
59 changes: 59 additions & 0 deletions test/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,3 +1717,62 @@ def test_inline_percent(self):
"% foo",
"bar %% baz",
]

def test_lsitcomp_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return [ k for k in x.keys() ]
%>
${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_setcomp_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return { k for k in x.keys() }
%>
${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_generator_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return ( k for k in x.keys())
%>
${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_dictcomp_in_func_strict(self):
t = Template(
"""
<%
def square():
return {i: i**2 for i in range(10)}
%>
${ square()[3] }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["9"]

0 comments on commit 7840ad3

Please sign in to comment.