Skip to content

gh-89687: fix get_type_hints with dataclasses __init__ generation #29158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
23 changes: 20 additions & 3 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,25 @@ def _field_assign(frozen, name, value, self_name):
return f' {self_name}.{name}={value}'


def _field_init(f, frozen, globals, self_name, slots):
def _field_init(f, frozen, globals, self_name, slots, module):
# Return the text of the line in the body of __init__ that will
# initialize this field.

if f.init and isinstance(f.type, str):
from typing import ForwardRef # `typing` is a heavy import
# We need to resolve this string type into a real `ForwardRef` object,
# because otherwise we might end up with unsolvable annotations.
# For example:
# def __init__(self, d: collections.OrderedDict) -> None:
# We won't be able to resolve `collections.OrderedDict`
# with wrong `module` param, when placed in a different module. #45524
try:
f.type = ForwardRef(f.type, module=module, is_class=True)
except SyntaxError:
# We don't want to fail class creation
# when `ForwardRef` cannot be constructed.
pass

default_name = f'__dataclass_dflt_{f.name}__'
if f.default_factory is not MISSING:
if f.init:
Expand Down Expand Up @@ -612,7 +627,7 @@ def _init_param(f):


def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
self_name, func_builder, slots):
self_name, func_builder, slots, module):
# fields contains both real fields and InitVar pseudo-fields.

# Make sure we don't have fields without defaults following fields
Expand All @@ -639,7 +654,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,

body_lines = []
for f in fields:
line = _field_init(f, frozen, locals, self_name, slots)
line = _field_init(f, frozen, locals, self_name, slots, module)
# line is None means that this field doesn't require
# initialization (it's a pseudo-field). Just skip it.
if line:
Expand Down Expand Up @@ -959,6 +974,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
# where some dataclasses does not have any bases with `_FIELDS`
all_frozen_bases = None
has_dataclass_bases = False

for b in cls.__mro__[-1:0:-1]:
# Only process classes that have been processed by our
# decorator. That is, they have a _FIELDS attribute.
Expand Down Expand Up @@ -1089,6 +1105,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
else 'self',
func_builder,
slots,
cls.__module__,
)

_set_new_attribute(cls, '__replace__', _replace)
Expand Down
125 changes: 125 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4113,6 +4113,131 @@ def test_text_annotations(self):
{'foo': dataclass_textanno.Foo,
'return': type(None)})

def test_dataclass_from_another_module(self):
# see bpo-45524
from test.test_dataclasses import dataclass_textanno
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno.Bar):
pass

@dataclass(init=False)
class WithInitFalse(dataclass_textanno.Bar):
pass

@dataclass(init=False)
class CustomInit(dataclass_textanno.Bar):
def __init__(self, foo: dataclass_textanno.Foo) -> None:
pass

@dataclass
class FutureInitChild(dataclass_textanno.WithFutureInit):
pass

classes = [
Default,
WithInitFalse,
CustomInit,
dataclass_textanno.WithFutureInit,
FutureInitChild,
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{'foo': dataclass_textanno.Foo},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{'foo': dataclass_textanno.Foo, 'return': type(None)},
)

def test_dataclass_from_proxy_module(self):
# see bpo-45524
from test.test_dataclasses import dataclass_textanno
from test.test_dataclasses import dataclass_textanno2
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno2.Child):
pass

@dataclass(init=False)
class WithInitFalse(dataclass_textanno2.Child):
pass

@dataclass(init=False)
class CustomInit(dataclass_textanno2.Child):
def __init__(
self,
foo: dataclass_textanno.Foo,
custom: dataclass_textanno2.Custom,
) -> None:
pass

@dataclass
class FutureInitChild(dataclass_textanno2.WithFutureInit):
pass

classes = [
Default,
WithInitFalse,
CustomInit,
dataclass_textanno2.WithFutureInit,
FutureInitChild,
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{
'foo': dataclass_textanno.Foo,
'custom': dataclass_textanno2.Custom,
},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{
'foo': dataclass_textanno.Foo,
'custom': dataclass_textanno2.Custom,
'return': type(None),
},
)

def test_dataclass_proxy_modules_matching_name_override(self):
# see bpo-45524
from test.test_dataclasses import dataclass_textanno2
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno2.WithMatchingNameOverride):
pass

classes = [
Default,
dataclass_textanno2.WithMatchingNameOverride
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{
'foo': dataclass_textanno2.Foo,
},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{
'foo': dataclass_textanno2.Foo,
'return': type(None),
},
)



ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)])
ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass',
Expand Down
6 changes: 6 additions & 0 deletions Lib/test/test_dataclasses/dataclass_textanno.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ class Foo:
@dataclasses.dataclass
class Bar:
foo: Foo


@dataclasses.dataclass(init=False)
class WithFutureInit(Bar):
def __init__(self, foo: Foo) -> None:
pass
30 changes: 30 additions & 0 deletions Lib/test/test_dataclasses/dataclass_textanno2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

import dataclasses

# We need to be sure that `Foo` is not in scope
from test.test_dataclasses import dataclass_textanno


class Custom:
pass


@dataclasses.dataclass
class Child(dataclass_textanno.Bar):
custom: Custom


class Foo: # matching name with `dataclass_testanno.Foo`
pass


@dataclasses.dataclass
class WithMatchingNameOverride(dataclass_textanno.Bar):
foo: Foo # Existing `foo` annotation should be overridden


@dataclasses.dataclass(init=False)
class WithFutureInit(Child):
def __init__(self, foo: dataclass_textanno.Foo, custom: Custom) -> None:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``get_type_hints()`` failure on ``@dataclass`` hierarchies in different
modules.
Loading