Skip to content

gh-109870: Combine exec calls in dataclass #110186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 73 additions & 61 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,33 +446,39 @@ def _tuple_str(obj_name, fields):
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'


def _create_fn(name, args, body, *, globals=None, locals=None,
return_type=MISSING):
def _create_fn_def(name, args, body, *, locals=None, return_type=MISSING):
# Note that we may mutate locals. Callers beware!
# The only callers are internal to this module, so no
# worries about external callers.
if locals is None:
locals = {}
return_annotation = ''
if return_type is not MISSING:
locals['__dataclass_return_type__'] = return_type
return_annotation = '->__dataclass_return_type__'
fn_name = name.replace("__", "")
locals[f'__dataclass_{fn_name}_return_type__'] = return_type
return_annotation = f'->__dataclass_{fn_name}_return_type__'
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)

# Compute the text of the entire function.
txt = f' def {name}({args}){return_annotation}:\n{body}'
txt = f'def {name}({args}){return_annotation}:\n{body}'

return (name, txt, locals)

def _exec_fn_defs(fn_defs, globals=None):
# Free variables in exec are resolved in the global namespace.
# The global namespace we have is user-provided, so we can't modify it for
# our purposes. So we put the things we need into locals and introduce a
# scope to allow the function we're creating to close over them.
local_vars = ', '.join(locals.keys())
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
locals_dict = {k: v for _, _, locals_ in fn_defs
for k, v in locals_.items()}
local_vars = ', '.join(locals_dict.keys())
fn_names = ", ".join(name for name, _, _ in fn_defs)
txt = "\n".join(f" {txt}" for _, txt, _ in fn_defs)
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {fn_names}"
ns = {}
exec(txt, globals, ns)
return ns['__create_fn__'](**locals)

return ns['__create_fn__'](**locals_dict)

def _field_assign(frozen, name, value, self_name):
# If we're a frozen class, then assign to our fields in __init__
Expand Down Expand Up @@ -566,7 +572,7 @@ def _init_param(f):


def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
self_name, globals, slots):
self_name, slots):
# fields contains both real fields and InitVar pseudo-fields.

# Make sure we don't have fields without defaults following fields
Expand Down Expand Up @@ -616,68 +622,61 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
# (instead of just concatenting the lists together).
_init_params += ['*']
_init_params += [_init_param(f) for f in kw_only_fields]
return _create_fn('__init__',
return _create_fn_def('__init__',
[self_name] + _init_params,
body_lines,
locals=locals,
globals=globals,
return_type=None)


def _repr_fn(fields, globals):
fn = _create_fn('__repr__',
def _repr_fn(fields):
return _create_fn_def('__repr__',
('self',),
['return f"{self.__class__.__qualname__}(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in fields]) +
')"'],
globals=globals)
return _recursive_repr(fn)
')"'],)


def _frozen_get_del_attr(cls, fields, globals):
def _frozen_get_del_attr(cls, fields):
locals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError}
condition = 'type(self) is cls'
if fields:
condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}'
return (_create_fn('__setattr__',
return (_create_fn_def('__setattr__',
('self', 'name', 'value'),
(f'if {condition}:',
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
f'super(cls, self).__setattr__(name, value)'),
locals=locals,
globals=globals),
_create_fn('__delattr__',
locals=locals),
_create_fn_def('__delattr__',
('self', 'name'),
(f'if {condition}:',
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
f'super(cls, self).__delattr__(name)'),
locals=locals,
globals=globals),
locals=locals),
)


def _cmp_fn(name, op, self_tuple, other_tuple, globals):
def _cmp_fn(name, op, self_tuple, other_tuple):
# Create a comparison function. If the fields in the object are
# named 'x' and 'y', then self_tuple is the string
# '(self.x,self.y)' and other_tuple is the string
# '(other.x,other.y)'.

return _create_fn(name,
return _create_fn_def(name,
('self', 'other'),
[ 'if other.__class__ is self.__class__:',
f' return {self_tuple}{op}{other_tuple}',
'return NotImplemented'],
globals=globals)
'return NotImplemented'],)


def _hash_fn(fields, globals):
def _hash_fn(fields):
self_tuple = _tuple_str('self', fields)
return _create_fn('__hash__',
return _create_fn_def('__hash__',
('self',),
[f'return hash({self_tuple})'],
globals=globals)
[f'return hash({self_tuple})'],)


def _is_classvar(a_type, typing):
Expand Down Expand Up @@ -855,7 +854,7 @@ def _get_field(cls, a_name, a_type, default_kw_only):
return f

def _set_qualname(cls, value):
# Ensure that the functions returned from _create_fn uses the proper
# Ensure that the functions returned from _exec_fn_defs uses the proper
# __qualname__ (the class they belong to).
if isinstance(value, FunctionType):
value.__qualname__ = f"{cls.__qualname__}.{value.__name__}"
Expand All @@ -879,9 +878,9 @@ def _set_new_attribute(cls, name, value):
def _hash_set_none(cls, fields, globals):
return None

def _hash_add(cls, fields, globals):
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
return _set_qualname(cls, _hash_fn(flds, globals))
class _HASH_ADD:
pass
_hash_add = _HASH_ADD

def _hash_exception(cls, fields, globals):
# Raise an exception.
Expand Down Expand Up @@ -925,6 +924,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
# derived class fields overwrite base class fields, but the order
# is defined by the base class, which is found first.
fields = {}
fn_defs = [] # store txt defs to exec combined

if cls.__module__ in sys.modules:
globals = sys.modules[cls.__module__].__dict__
Expand Down Expand Up @@ -1059,8 +1059,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
# Does this class have a post-init function?
has_post_init = hasattr(cls, _POST_INIT_NAME)

_set_new_attribute(cls, '__init__',
_init_fn(all_init_fields,
fn_defs.append(_init_fn(all_init_fields,
std_init_fields,
kw_only_init_fields,
frozen,
Expand All @@ -1070,7 +1069,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
# if possible.
'__dataclass_self__' if 'self' in fields
else 'self',
globals,
slots,
))
_set_new_attribute(cls, '__replace__', _replace)
Expand All @@ -1081,7 +1079,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,

if repr:
flds = [f for f in field_list if f.repr]
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
fn_defs.append(_repr_fn(flds))

if eq:
# Create __eq__ method. There's no need for a __ne__ method,
Expand All @@ -1092,41 +1090,55 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
body = [f'if other.__class__ is self.__class__:',
f' return {field_comparisons}',
f'return NotImplemented']
func = _create_fn('__eq__',
('self', 'other'),
body,
globals=globals)
_set_new_attribute(cls, '__eq__', func)
fn_defs.append(_create_fn_def('__eq__',('self', 'other'), body,))

if order:
# Create and set the ordering methods.
flds = [f for f in field_list if f.compare]
self_tuple = _tuple_str('self', flds)
other_tuple = _tuple_str('other', flds)
for name, op in [('__lt__', '<'),
('__le__', '<='),
('__gt__', '>'),
('__ge__', '>='),
]:
if _set_new_attribute(cls, name,
_cmp_fn(name, op, self_tuple, other_tuple,
globals=globals)):
raise TypeError(f'Cannot overwrite attribute {name} '
f'in class {cls.__name__}. Consider using '
'functools.total_ordering')
order_flds = {'__lt__' : '<',
'__le__' : '<=',
'__gt__' : '>',
'__ge__' : '>=',
}
for name, op in order_flds.items():
fn_defs.append(_cmp_fn(name, op, self_tuple, other_tuple))

if frozen:
for fn in _frozen_get_del_attr(cls, field_list, globals):
if _set_new_attribute(cls, fn.__name__, fn):
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')
fn_defs.extend(_frozen_get_del_attr(cls, field_list))

# Decide if/how we're going to create a hash function.
hash_action = _hash_action[bool(unsafe_hash),
bool(eq),
bool(frozen),
has_explicit_hash]
if hash_action:

if hash_action == _hash_add:
flds = [f for f in field_list if (f.compare if f.hash is None else f.hash)]
fn_defs.append(_hash_fn(field_list))
hash_action = None # assign when iterating

# exec functions and assign
functions_objects = _exec_fn_defs(fn_defs, globals=globals)
for fn in functions_objects:
name = fn.__name__
if name == '__repr__':
fn = _recursive_repr(fn)

if name == '__hash__':
cls.__hash__ = _set_qualname(cls, fn)
else:
if _set_new_attribute(cls, name, fn):
if order and name in order_flds:
raise TypeError(f'Cannot overwrite attribute {name} '
f'in class {cls.__name__}. Consider using '
'functools.total_ordering')
elif frozen and name in ['__setattr__','__delattr__']:
raise TypeError(f'Cannot overwrite attribute {name} '
f'in class {cls.__name__}')

if hash_action: # for _hash_set_none, _hash_exception
# No need to call _set_new_attribute here, since by the time
# we're here the overwriting is unconditional.
cls.__hash__ = hash_action(cls, field_list, globals)
Expand Down