Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support comprehensions inside functions when use strict_undefined flag. #386

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/build/unreleased/320.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. change::
:tags: bug, parser
:tickets: 320

Fixed unexpected syntax error in strict_undefined mode that occurred
when using comprehensions within a function in a Mako Python code block.
Now, the local variable in comprehensions won't be added to the checklist
when using strict_undefined mode.
Pull request courtesy Hai Zhu.
20 changes: 20 additions & 0 deletions mako/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ def visit_FunctionDef(self, node):
self._add_declared(node.name)
self._visit_function(node, False)

def visit_ListComp(self, node):
if self.in_function:
if not isinstance(node.elt, _ast.Name):
self.visit(node.elt)
for comp in node.generators:
self.visit(comp.iter)
else:
self.generic_visit(node)

visit_SetComp = visit_GeneratorExp = visit_ListComp

def visit_DictComp(self, node):
if self.in_function:
if not isinstance(node.key, _ast.Name):
self.visit(node.elt)
for comp in node.generators:
self.visit(comp.iter)
else:
self.generic_visit(node)

def _expand_tuples(self, args):
for arg in args:
if isinstance(arg, _ast.Tuple):
Expand Down
36 changes: 36 additions & 0 deletions test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,42 @@ def test_locate_identifiers_17(self):
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.undeclared_identifiers, {"x", "y", "Foo", "Bar"})

def test_locate_identifiers_18(self):
code = """
def func():
return [i for i in range(10)]
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_19(self):
code = """
def func():
return (i for i in range(10))
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_20(self):
code = """
def func():
return {i for i in range(10)}
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_21(self):
code = """
def func():
return {i: i**2 for i in range(10)}
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_no_global_imports(self):
code = """
from foo import *
Expand Down
59 changes: 59 additions & 0 deletions test/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,3 +1717,62 @@ def test_inline_percent(self):
"% foo",
"bar %% baz",
]

def test_listcomp_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return [ k for k in x.keys() ]
%>

${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_setcomp_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return { k for k in x.keys() }
%>

${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_generator_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return ( k for k in x.keys())
%>

${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_dictcomp_in_func_strict(self):
t = Template(
"""
<%
def square():
return {i: i**2 for i in range(10)}
%>

${ square()[3] }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["9"]
Loading