Skip to content

Commit

Permalink
bpo-38870: Don't start generated output with newlines in ast.unparse (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
isidentical authored May 3, 2020
1 parent 3dd2157 commit 493bf1c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
14 changes: 10 additions & 4 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,10 +669,16 @@ def items_view(self, traverser, items):
else:
self.interleave(lambda: self.write(", "), traverser, items)

def maybe_newline(self):
"""Adds a newline if it isn't the start of generated source"""
if self._source:
self.write("\n")

def fill(self, text=""):
"""Indent a piece of text and append it, according to the current
indentation level"""
self.write("\n" + " " * self._indent + text)
self.maybe_newline()
self.write(" " * self._indent + text)

def write(self, text):
"""Append a piece of text"""
Expand Down Expand Up @@ -916,7 +922,7 @@ def visit_ExceptHandler(self, node):
self.traverse(node.body)

def visit_ClassDef(self, node):
self.write("\n")
self.maybe_newline()
for deco in node.decorator_list:
self.fill("@")
self.traverse(deco)
Expand Down Expand Up @@ -946,7 +952,7 @@ def visit_AsyncFunctionDef(self, node):
self._function_helper(node, "async def")

def _function_helper(self, node, fill_suffix):
self.write("\n")
self.maybe_newline()
for deco in node.decorator_list:
self.fill("@")
self.traverse(deco)
Expand Down Expand Up @@ -1043,7 +1049,7 @@ def _fstring_FormattedValue(self, node, write):
write("{")
unparser = type(self)()
unparser.set_precedence(_Precedence.TEST.next(), node.value)
expr = unparser.visit(node.value).rstrip("\n")
expr = unparser.visit(node.value)
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
write(expr)
Expand Down
12 changes: 5 additions & 7 deletions Lib/test/test_unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,17 @@ def check_ast_roundtrip(self, code1, **kwargs):
def check_invalid(self, node, raises=ValueError):
self.assertRaises(raises, ast.unparse, node)

def get_source(self, code1, code2=None, strip=True):
def get_source(self, code1, code2=None):
code2 = code2 or code1
code1 = ast.unparse(ast.parse(code1))
if strip:
code1 = code1.strip()
return code1, code2

def check_src_roundtrip(self, code1, code2=None, strip=True):
code1, code2 = self.get_source(code1, code2, strip)
def check_src_roundtrip(self, code1, code2=None):
code1, code2 = self.get_source(code1, code2)
self.assertEqual(code2, code1)

def check_src_dont_roundtrip(self, code1, code2=None, strip=True):
code1, code2 = self.get_source(code1, code2, strip)
def check_src_dont_roundtrip(self, code1, code2=None):
code1, code2 = self.get_source(code1, code2)
self.assertNotEqual(code2, code1)

class UnparseTestCase(ASTTestCase):
Expand Down

0 comments on commit 493bf1c

Please sign in to comment.