Skip to content

Create a NamedTuple instance for NamedTuples #1306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ Release date: TBA
* On Python versions >= 3.9, ``astroid`` now understands subscripting
builtin classes such as ``enumerate`` or ``staticmethod``.

* Instances of NamedTuples both from typing and collections will now be cast to a
NamedTuple instance. This instance proxies the definition of that NamedTuple.

* Fixed inference of ``Enums`` when they are imported under an alias.

Closes PyCQA/pylint#5776
Expand Down
4 changes: 4 additions & 0 deletions astroid/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,7 @@ def __repr__(self):

def __str__(self):
return f"AsyncGenerator({self._proxied.name})"


class NamedTuple(BaseInstance):
"""Special node representing a NamedTuple instance"""
39 changes: 25 additions & 14 deletions astroid/brain/brain_namedtuple_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,21 +186,30 @@ def infer_named_tuple(
node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
"""Specific inference function for namedtuple Call node"""
tuple_base_name = nodes.Name(name="tuple", parent=node.root())
class_node, name, attributes = infer_func_form(
node, tuple_base_name, context=context
)
# Infer which type of NamedTuple we're dealing with (typing or collections)
inferred_namedtuple_call = next(node.func.infer())
base_names = [
nodes.Name(name="tuple", parent=node.root()),
nodes.Name(
name=inferred_namedtuple_call.name,
parent=inferred_namedtuple_call.root(),
),
]

class_node, name, attributes = infer_func_form(node, base_names, context=context)
call_site = arguments.CallSite.from_call(node, context=context)
node = extract_node("import collections; collections.namedtuple")
try:

func = next(node.infer())
except StopIteration as e:
raise InferenceError(node=node) from e
try:
rename = next(call_site.infer_argument(func, "rename", context)).bool_value()
rename = next(
call_site.infer_argument(inferred_namedtuple_call, "rename", context)
).bool_value()
except (InferenceError, StopIteration):
rename = False
# If inferred_namedtuple_call is the ClassDef of typing.NamedTuple
# infer_argument will raise AttributeError
# TODO: See if this exception can be prevented
except AttributeError:
rename = False

try:
attributes = _check_namedtuple_attributes(name, attributes, rename)
Expand Down Expand Up @@ -331,7 +340,7 @@ def value(self):
__members__ = ['']
"""
)
class_node = infer_func_form(node, enum_meta, context=context, enum=True)[0]
class_node = infer_func_form(node, [enum_meta], context=context, enum=True)[0]
return iter([class_node.instantiate_class()])


Expand Down Expand Up @@ -509,9 +518,11 @@ def infer_typing_namedtuple_function(node, context=None):
def infer_typing_namedtuple(
node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[nodes.ClassDef]:
"""Infer a typing.NamedTuple(...) call."""
# This is essentially a namedtuple with different arguments
# so we extract the args and infer a named tuple.
"""Infer a typing.NamedTuple(...) call.

We do some premature checking of the node to see if we don't run into any unexpected
values.
"""
try:
func = next(node.func.infer())
except (InferenceError, StopIteration) as exc:
Expand Down
8 changes: 8 additions & 0 deletions astroid/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
import sys
from pathlib import Path

if sys.version_info >= (3, 8):
from typing import Final
else:
from typing_extensions import Final

PY38 = sys.version_info[:2] == (3, 8)
PY38_PLUS = sys.version_info >= (3, 8)
PY39_PLUS = sys.version_info >= (3, 9)
Expand Down Expand Up @@ -33,3 +38,6 @@ class Context(enum.Enum):

ASTROID_INSTALL_DIRECTORY = Path(__file__).parent
BRAIN_MODULES_DIRECTORY = ASTROID_INSTALL_DIRECTORY / "brain"

NAMEDTUPLE_BASENAMES: Final[frozenset[str]] = frozenset(("namedtuple", "NamedTuple"))
"""Const used to identify namedtuples in the basenames of subclasses"""
26 changes: 13 additions & 13 deletions astroid/filter_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,27 @@ def _filter_stmts(base_node: nodes.NodeNG, stmts, frame, offset):
#
# def test(b=1):
# ...
#
# If the frame is already a Module we don't need to go up anymore
if (
base_node.parent
not isinstance(myframe, nodes.Module)
and base_node.parent
and base_node.statement(future=True) is myframe
and myframe.parent
):
myframe = myframe.parent.frame()

# We can use line filtering if we are in the same frame.
# mylineno is 0 by default to skip if we can't determine lineno
# or if we are at the module level. lineno information is (for example)
# missing for nodes inserted for living objects.
mylineno = 0
mystmt: nodes.Statement | None = None
if base_node.parent:
if base_node.parent and not isinstance(base_node.parent, nodes.Module):
mystmt = base_node.statement(future=True)

# line filtering if we are in the same frame
#
# take care node may be missing lineno information (this is the case for
# nodes inserted for living objects)
if myframe is frame and mystmt and mystmt.fromlineno is not None:
assert mystmt.fromlineno is not None, mystmt
mylineno = mystmt.fromlineno + offset
else:
# disabling lineno filtering
mylineno = 0
if myframe is frame and mystmt.fromlineno is not None:
assert mystmt.fromlineno is not None, mystmt
mylineno = mystmt.fromlineno + offset

_stmts = []
_stmt_parents = []
Expand Down
4 changes: 3 additions & 1 deletion astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from astroid import bases
from astroid import decorators as decorators_mod
from astroid import util
from astroid.const import IS_PYPY, PY38, PY38_PLUS, PY39_PLUS
from astroid.const import IS_PYPY, NAMEDTUPLE_BASENAMES, PY38, PY38_PLUS, PY39_PLUS
from astroid.context import (
CallContext,
InferenceContext,
Expand Down Expand Up @@ -2521,6 +2521,8 @@ def instantiate_class(self):
return objects.ExceptionInstance(self)
except MroError:
pass
if any(i in NAMEDTUPLE_BASENAMES for i in self.basenames):
return bases.NamedTuple(self)
return bases.Instance(self)

def getattr(self, name, context=None, class_context=True):
Expand Down
9 changes: 5 additions & 4 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,7 +1502,8 @@ class X(NamedTuple("X", [("a", int), ("b", str), ("c", bytes)])):
"""
)
self.assertEqual(
[anc.name for anc in klass.ancestors()], ["X", "tuple", "object"]
[anc.name for anc in klass.ancestors()],
["X", "tuple", "object", "NamedTuple"],
)
for anc in klass.ancestors():
self.assertFalse(anc.parent is None)
Expand Down Expand Up @@ -1611,7 +1612,7 @@ class Example(NamedTuple):
"""
)
inferred = next(result.infer())
self.assertIsInstance(inferred, astroid.Instance)
self.assertIsInstance(inferred, bases.NamedTuple)

class_attr = inferred.getattr("CLASS_ATTR")[0]
self.assertIsInstance(class_attr, astroid.AssignName)
Expand Down Expand Up @@ -1784,7 +1785,7 @@ def test_typing_namedtuple_dont_crash_on_no_fields(self) -> None:
"""
)
inferred = next(node.infer())
self.assertIsInstance(inferred, astroid.Instance)
self.assertIsInstance(inferred, bases.NamedTuple)

@test_utils.require_version("3.8")
def test_typed_dict(self):
Expand Down Expand Up @@ -3131,7 +3132,7 @@ def test_http_client_brain() -> None:
"""
)
inferred = next(node.infer())
assert isinstance(inferred, astroid.Instance)
assert isinstance(inferred, bases.NamedTuple)


def test_http_status_brain() -> None:
Expand Down
8 changes: 4 additions & 4 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest

from astroid import Slice, arguments
from astroid import Slice, arguments, bases
from astroid import decorators as decoratorsmod
from astroid import helpers, nodes, objects, test_utils, util
from astroid.arguments import CallSite
Expand Down Expand Up @@ -2193,8 +2193,8 @@ def collections(self):

"""
ast = parse(code, __name__)
bases = ast["Second"].bases[0]
inferred = next(bases.infer())
base_classes = ast["Second"].bases[0]
inferred = next(base_classes.infer())
self.assertTrue(inferred)
self.assertIsInstance(inferred, nodes.ClassDef)
self.assertEqual(inferred.qname(), "collections.Counter")
Expand Down Expand Up @@ -6249,7 +6249,7 @@ def test_inferaugassign_picking_parent_instead_of_stmt() -> None:
# as a string.
node = extract_node(code)
inferred = next(node.infer())
assert isinstance(inferred, Instance)
assert isinstance(inferred, bases.NamedTuple)
assert inferred.name == "SomeClass"


Expand Down
10 changes: 5 additions & 5 deletions tests/unittest_object_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

import astroid
from astroid import builder, nodes, objects, test_utils, util
from astroid import bases, builder, nodes, objects, test_utils, util
from astroid.const import PY311_PLUS
from astroid.exceptions import InferenceError

Expand Down Expand Up @@ -203,9 +203,9 @@ class C(A): pass
called_mro = next(ast_nodes[5].infer())
self.assertEqual(called_mro.elts, mro.elts)

bases = next(ast_nodes[6].infer())
self.assertIsInstance(bases, astroid.Tuple)
self.assertEqual([cls.name for cls in bases.elts], ["object"])
bases_classes = next(ast_nodes[6].infer())
self.assertIsInstance(bases_classes, astroid.Tuple)
self.assertEqual([cls.name for cls in bases_classes.elts], ["object"])

cls = next(ast_nodes[7].infer())
self.assertIsInstance(cls, astroid.ClassDef)
Expand Down Expand Up @@ -694,7 +694,7 @@ def foo():
self.assertIsInstance(wrapped, astroid.FunctionDef)
self.assertEqual(wrapped.name, "foo")
cache_info = next(ast_nodes[2].infer())
self.assertIsInstance(cache_info, astroid.Instance)
self.assertIsInstance(cache_info, bases.NamedTuple)


if __name__ == "__main__":
Expand Down