Skip to content

Commit

Permalink
Infer starred expressions in tuple, list, set and dict literals (pyli…
Browse files Browse the repository at this point in the history
  • Loading branch information
rogalski authored and brycepg committed Feb 27, 2018
1 parent f8a4367 commit 775013a
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 9 deletions.
5 changes: 4 additions & 1 deletion ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,12 @@ Change log for the astroid package (used to be astng)

* Fix metaclass detection, when multiple keyword arguments
are used in class definition.

* Add support for annotated variable assignments (PEP 526)

* Starred expressions are now inferred correctly for tuple,
list, set, and dictionary literals.


2015-11-29 -- 1.4.1

Expand Down
73 changes: 69 additions & 4 deletions astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,78 @@ def infer_end(self, context=None):
nodes.FunctionDef._infer = infer_end
nodes.Lambda._infer = infer_end
nodes.Const._infer = infer_end
nodes.List._infer = infer_end
nodes.Tuple._infer = infer_end
nodes.Dict._infer = infer_end
nodes.Set._infer = infer_end
nodes.Slice._infer = infer_end


def infer_seq(self, context=None):
if not any(isinstance(e, nodes.Starred) for e in self.elts):
yield self
else:
values = _infer_seq(self, context)
new_seq = type(self)(self.lineno, self.col_offset, self.parent)
new_seq.postinit(values)
yield new_seq


def _infer_seq(node, context=None):
"""Infer all values based on _BaseContainer.elts"""
values = []

for elt in node.elts:
if isinstance(elt, nodes.Starred):
starred = helpers.safe_infer(elt.value, context)
if starred in (None, util.Uninferable):
raise exceptions.InferenceError(node=node,
context=context)
if not hasattr(starred, 'elts'):
raise exceptions.InferenceError(node=node,
context=context)
values.extend(_infer_seq(starred))
else:
values.append(elt)
return values


nodes.List._infer = infer_seq
nodes.Tuple._infer = infer_seq
nodes.Set._infer = infer_seq


def infer_map(self, context=None):
if not any(isinstance(k, nodes.DictUnpack) for k, _ in self.items):
yield self
else:
items = _infer_map(self, context)
new_seq = type(self)(self.lineno, self.col_offset, self.parent)
new_seq.postinit(list(items.items()))
yield new_seq


def _infer_map(node, context):
"""Infer all values based on Dict.items"""
values = {}
for name, value in node.items:
if isinstance(name, nodes.DictUnpack):
double_starred = helpers.safe_infer(value, context)
if double_starred in (None, util.Uninferable):
raise exceptions.InferenceError
if not isinstance(double_starred, nodes.Dict):
raise exceptions.InferenceError(node=node,
context=context)
values.update(_infer_map(double_starred, context))
else:
key = helpers.safe_infer(name, context=context)
value = helpers.safe_infer(value, context=context)
if key is None or value is None:
raise exceptions.InferenceError(node=node,
context=context)
values[key] = value
return values


nodes.Dict._infer = infer_map


def _higher_function_scope(node):
""" Search for the first function which encloses the given
scope. This can be used for looking up in that function's
Expand Down
119 changes: 115 additions & 4 deletions astroid/tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,6 +1581,117 @@ def test_tuple_builtin_inference(self):
self.assertIsInstance(inferred, Instance)
self.assertEqual(inferred.qname(), "{}.tuple".format(BUILTINS))

@test_utils.require_version('3.5')
def test_starred_in_tuple_literal(self):
code = """
var = (1, 2, 3)
bar = (5, 6, 7)
foo = [999, 1000, 1001]
(0, *var) #@
(0, *var, 4) #@
(0, *var, 4, *bar) #@
(0, *var, 4, *(*bar, 8)) #@
(0, *var, 4, *(*bar, *foo)) #@
"""
ast = extract_node(code, __name__)
self.assertInferTuple(ast[0], [0, 1, 2, 3])
self.assertInferTuple(ast[1], [0, 1, 2, 3, 4])
self.assertInferTuple(ast[2], [0, 1, 2, 3, 4, 5, 6, 7])
self.assertInferTuple(ast[3], [0, 1, 2, 3, 4, 5, 6, 7, 8])
self.assertInferTuple(ast[4], [0, 1, 2, 3, 4, 5, 6, 7, 999, 1000, 1001])

@test_utils.require_version('3.5')
def test_starred_in_list_literal(self):
code = """
var = (1, 2, 3)
bar = (5, 6, 7)
foo = [999, 1000, 1001]
[0, *var] #@
[0, *var, 4] #@
[0, *var, 4, *bar] #@
[0, *var, 4, *[*bar, 8]] #@
[0, *var, 4, *[*bar, *foo]] #@
"""
ast = extract_node(code, __name__)
self.assertInferList(ast[0], [0, 1, 2, 3])
self.assertInferList(ast[1], [0, 1, 2, 3, 4])
self.assertInferList(ast[2], [0, 1, 2, 3, 4, 5, 6, 7])
self.assertInferList(ast[3], [0, 1, 2, 3, 4, 5, 6, 7, 8])
self.assertInferList(ast[4], [0, 1, 2, 3, 4, 5, 6, 7, 999, 1000, 1001])

@test_utils.require_version('3.5')
def test_starred_in_set_literal(self):
code = """
var = (1, 2, 3)
bar = (5, 6, 7)
foo = [999, 1000, 1001]
{0, *var} #@
{0, *var, 4} #@
{0, *var, 4, *bar} #@
{0, *var, 4, *{*bar, 8}} #@
{0, *var, 4, *{*bar, *foo}} #@
"""
ast = extract_node(code, __name__)
self.assertInferSet(ast[0], [0, 1, 2, 3])
self.assertInferSet(ast[1], [0, 1, 2, 3, 4])
self.assertInferSet(ast[2], [0, 1, 2, 3, 4, 5, 6, 7])
self.assertInferSet(ast[3], [0, 1, 2, 3, 4, 5, 6, 7, 8])
self.assertInferSet(ast[4], [0, 1, 2, 3, 4, 5, 6, 7, 999, 1000, 1001])

@test_utils.require_version('3.5')
def test_starred_in_literals_inference_issues(self):
code = """
{0, *var} #@
{0, *var, 4} #@
{0, *var, 4, *bar} #@
{0, *var, 4, *{*bar, 8}} #@
{0, *var, 4, *{*bar, *foo}} #@
"""
ast = extract_node(code, __name__)
for node in ast:
with self.assertRaises(InferenceError):
next(node.infer())

@test_utils.require_version('3.5')
def test_starred_in_mapping_literal(self):
code = """
var = {1: 'b', 2: 'c'}
bar = {4: 'e', 5: 'f'}
{0: 'a', **var} #@
{0: 'a', **var, 3: 'd'} #@
{0: 'a', **var, 3: 'd', **{**bar, 6: 'g'}} #@
"""
ast = extract_node(code, __name__)
self.assertInferDict(ast[0], {0: 'a', 1: 'b', 2: 'c'})
self.assertInferDict(ast[1], {0: 'a', 1: 'b', 2: 'c', 3: 'd'})
self.assertInferDict(ast[2], {0: 'a', 1: 'b', 2: 'c', 3: 'd',
4: 'e', 5: 'f', 6: 'g'})

@test_utils.require_version('3.5')
def test_starred_in_mapping_inference_issues(self):
code = """
{0: 'a', **var} #@
{0: 'a', **var, 3: 'd'} #@
{0: 'a', **var, 3: 'd', **{**bar, 6: 'g'}} #@
"""
ast = extract_node(code, __name__)
for node in ast:
with self.assertRaises(InferenceError):
next(node.infer())

@test_utils.require_version('3.5')
def test_starred_in_mapping_literal_non_const_keys_values(self):
code = """
a, b, c, d, e, f, g, h, i, j = "ABCDEFGHIJ"
var = {c: d, e: f}
bar = {i: j}
{a: b, **var} #@
{a: b, **var, **{g: h, **bar}} #@
"""
ast = extract_node(code, __name__)
self.assertInferDict(ast[0], {"A": "B", "C": "D", "E": "F"})
self.assertInferDict(ast[1], {"A": "B", "C": "D", "E": "F", "G": "H", "I": "J"})

def test_frozenset_builtin_inference(self):
code = """
var = (1, 2)
Expand Down Expand Up @@ -2036,7 +2147,7 @@ class LambdaInstance(object):
__pos__ = lambda self: self.lala
__neg__ = lambda self: self.lala + 1
@property
def lala(self): return 24
def lala(self): return 24
instance = GoodInstance()
lambda_instance = LambdaInstance()
+instance #@
Expand Down Expand Up @@ -2807,7 +2918,7 @@ class NonIndex(object):
a = [1, 2, 3, 4]
a[Index()] #@
a[LambdaIndex()] #@
a[NonIndex()] #@
a[NonIndex()] #@
''')
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, nodes.Const)
Expand Down Expand Up @@ -3206,7 +3317,7 @@ def test_metaclass_subclasses_arguments_are_classes_not_instances(self):
ast_node = extract_node('''
class A(type):
def test(cls):
return cls
return cls
import six
@six.add_metaclass(A)
class B(object):
Expand All @@ -3225,7 +3336,7 @@ def __call__(cls):
cls #@
class B(object):
def __call__(cls):
cls #@
cls #@
''')
first = next(ast_nodes[0].infer())
self.assertIsInstance(first, nodes.ClassDef)
Expand Down

0 comments on commit 775013a

Please sign in to comment.