diff --git a/astroid/bases.py b/astroid/bases.py index 4a0d152656..7611a77c0d 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -718,6 +718,10 @@ def __str__(self) -> str: class AsyncGenerator(Generator): """Special node representing an async generator.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + AsyncGenerator.special_attributes = objectmodel.AsyncGeneratorModel() + def pytype(self) -> Literal["builtins.async_generator"]: return "builtins.async_generator" diff --git a/astroid/interpreter/objectmodel.py b/astroid/interpreter/objectmodel.py index 1fcff9e7f3..7532db5edb 100644 --- a/astroid/interpreter/objectmodel.py +++ b/astroid/interpreter/objectmodel.py @@ -695,19 +695,19 @@ def attr___self__(self): class GeneratorModel(FunctionModel, ContextManagerModel): - def __new__(cls, *args, **kwargs): + def __init__(self): # Append the values from the GeneratorType unto this object. - ret = super().__new__(cls, *args, **kwargs) + super().__init__() generator = AstroidManager().builtins_module["generator"] for name, values in generator.locals.items(): method = values[0] + if isinstance(method, nodes.FunctionDef): + method = bases.BoundMethod(method, _get_bound_node(self)) def patched(cls, meth=method): return meth - setattr(type(ret), IMPL_PREFIX + name, property(patched)) - - return ret + setattr(type(self), IMPL_PREFIX + name, property(patched)) @property def attr___name__(self): @@ -724,24 +724,20 @@ def attr___doc__(self): class AsyncGeneratorModel(GeneratorModel): - def __new__(cls, *args, **kwargs): + def __init__(self): # Append the values from the AGeneratorType unto this object. - ret = super().__new__(cls, *args, **kwargs) + super().__init__() astroid_builtins = AstroidManager().builtins_module - generator = astroid_builtins.get("async_generator") - if generator is None: - # Make it backward compatible. - generator = astroid_builtins.get("generator") - + generator = astroid_builtins["async_generator"] for name, values in generator.locals.items(): method = values[0] + if isinstance(method, nodes.FunctionDef): + method = bases.BoundMethod(method, _get_bound_node(self)) def patched(cls, meth=method): return meth - setattr(type(ret), IMPL_PREFIX + name, property(patched)) - - return ret + setattr(type(self), IMPL_PREFIX + name, property(patched)) class InstanceModel(ObjectModel): diff --git a/astroid/raw_building.py b/astroid/raw_building.py index 1306838064..6343d51cc4 100644 --- a/astroid/raw_building.py +++ b/astroid/raw_building.py @@ -627,9 +627,8 @@ def _astroid_bootstrapping() -> None: col_offset=0, end_lineno=0, end_col_offset=0, - parent=nodes.Unknown(), + parent=astroid_builtin, ) - _GeneratorType.parent = astroid_builtin generator_doc_node = ( nodes.Const(value=types.GeneratorType.__doc__) if types.GeneratorType.__doc__ @@ -651,9 +650,8 @@ def _astroid_bootstrapping() -> None: col_offset=0, end_lineno=0, end_col_offset=0, - parent=nodes.Unknown(), + parent=astroid_builtin, ) - _AsyncGeneratorType.parent = astroid_builtin async_generator_doc_node = ( nodes.Const(value=types.AsyncGeneratorType.__doc__) if types.AsyncGeneratorType.__doc__ diff --git a/tests/test_lookup.py b/tests/test_lookup.py index b452d62894..bcee8f6746 100644 --- a/tests/test_lookup.py +++ b/tests/test_lookup.py @@ -322,24 +322,6 @@ class _Inner: self.assertEqual(len(name.lookup("x")[1]), 1, repr(name)) self.assertEqual(name.lookup("x")[1][0].lineno, 3, repr(name)) - def test_generator_attributes(self) -> None: - tree = builder.parse( - """ - def count(): - "test" - yield 0 - - iterer = count() - num = iterer.next() - """ - ) - next_node = tree.body[2].value.func - gener = next_node.expr.inferred()[0] - self.assertIsInstance(gener.getattr("__next__")[0], nodes.FunctionDef) - self.assertIsInstance(gener.getattr("send")[0], nodes.FunctionDef) - self.assertIsInstance(gener.getattr("throw")[0], nodes.FunctionDef) - self.assertIsInstance(gener.getattr("close")[0], nodes.FunctionDef) - def test_explicit___name__(self) -> None: code = """ class Pouet: diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 64cae2f676..cc25893963 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1454,7 +1454,7 @@ def test(self): assert bool(inferred.is_generator()) -class AsyncGeneratorTest: +class AsyncGeneratorTest(unittest.TestCase): def test_async_generator(self): node = astroid.extract_node( """ @@ -1472,23 +1472,6 @@ async def a_iter(n): assert inferred.pytype() == "builtins.async_generator" assert inferred.display_type() == "AsyncGenerator" - def test_async_generator_is_generator_on_older_python(self): - node = astroid.extract_node( - """ - async def a_iter(n): - for i in range(1, n + 1): - yield i - await asyncio.sleep(1) - a_iter(2) #@ - """ - ) - inferred = next(node.infer()) - assert isinstance(inferred, bases.Generator) - assert inferred.getattr("__iter__") - assert inferred.getattr("__next__") - assert inferred.pytype() == "builtins.generator" - assert inferred.display_type() == "Generator" - def test_f_string_correct_line_numbering() -> None: """Test that we generate correct line numbers for f-strings."""