From 49fcf9f2e6ffd10253a44c1db5a52cd8bd0ff2f9 Mon Sep 17 00:00:00 2001
From: Syrus Akbary
Date: Sun, 12 Apr 2020 17:45:46 -0700
Subject: [PATCH] Allow fast ObjectType creation based on dataclasses (#1157)

* Allow fast ObjectType creation based on dataclasses

* Fixed Python 3.8 integration

* Added repr and eq methods to ObjectType containers

* Reformatted code

* Fixed mypy issue

* Removed unused __init__ for ObjectType containers

* Use black in dataclasses

* Use latest black version on precommit
---
 graphene/pyutils/dataclasses.py         | 1260 +++++++++++++++++++++++
 graphene/types/base.py                  |    5 +-
 graphene/types/objecttype.py            |   77 +-
 graphene/types/tests/test_objecttype.py |   26 +-
 4 files changed, 1321 insertions(+), 47 deletions(-)
 create mode 100644 graphene/pyutils/dataclasses.py

diff --git a/graphene/pyutils/dataclasses.py b/graphene/pyutils/dataclasses.py
new file mode 100644
index 000000000..61f0ea388
--- /dev/null
+++ b/graphene/pyutils/dataclasses.py
@@ -0,0 +1,1260 @@
+# This is a polyfill for dataclasses
+# https://docs.python.org/3/library/dataclasses.html
+# Original PEP proposal: PEP 557
+# https://www.python.org/dev/peps/pep-0557/
+import re
+import sys
+import copy
+import types
+import inspect
+import keyword
+
+__all__ = [
+    "dataclass",
+    "field",
+    "Field",
+    "FrozenInstanceError",
+    "InitVar",
+    "MISSING",
+    # Helper functions.
+    "fields",
+    "asdict",
+    "astuple",
+    "make_dataclass",
+    "replace",
+    "is_dataclass",
+]
+
+# Conditions for adding methods.  The boxes indicate what action the
+# dataclass decorator takes.  For all of these tables, when I talk
+# about init=, repr=, eq=, order=, unsafe_hash=, or frozen=, I'm
+# referring to the arguments to the @dataclass decorator.  When
+# checking if a dunder method already exists, I mean check for an
+# entry in the class's __dict__.  I never check to see if an attribute
+# is defined in a base class.
+
+# Key:
+# +=========+=========================================+
+# + Value   | Meaning                                 |
+# +=========+=========================================+
+# |         | No action: no method is added.          |
+# +---------+-----------------------------------------+
+# | add     | Generated method is added.              |
+# +---------+-----------------------------------------+
+# | raise   | TypeError is raised.                    |
+# +---------+-----------------------------------------+
+# | None    | Attribute is set to None.               |
+# +=========+=========================================+
+
+# __init__
+#
+#  +--- init= parameter
+#  |
+#  v    |       |       |
+#       |  no   |  yes  |  <--- class has __init__ in __dict__?
+# +=======+=======+=======+
+# | False |       |       |
+# +-------+-------+-------+
+# | True  |  add  |       |  <- the default
+# +=======+=======+=======+
+
+# __repr__
+#
+#  +--- repr= parameter
+#  |
+#  v    |       |       |
+#       |  no   |  yes  |  <--- class has __repr__ in __dict__?
+# +=======+=======+=======+
+# | False |       |       |
+# +-------+-------+-------+
+# | True  |  add  |       |  <- the default
+# +=======+=======+=======+
+
+
+# __setattr__
+# __delattr__
+#
+#  +--- frozen= parameter
+#  |
+#  v    |       |       |
+#       |  no   |  yes  |  <--- class has __setattr__ or __delattr__ in __dict__?
+# +=======+=======+=======+
+# | False |       |       |  <- the default
+# +-------+-------+-------+
+# | True  |  add  | raise |
+# +=======+=======+=======+
+# Raise because not adding these methods would break the "frozen-ness"
+# of the class.
+
+# __eq__
+#
+#  +--- eq= parameter
+#  |
+#  v    |       |       |
+#       |  no   |  yes  |  <--- class has __eq__ in __dict__?
+# +=======+=======+=======+
+# | False |       |       |
+# +-------+-------+-------+
+# | True  |  add  |       |  <- the default
+# +=======+=======+=======+
+
+# __lt__
+# __le__
+# __gt__
+# __ge__
+#
+#  +--- order= parameter
+#  |
+#  v    |       |       |
+#       |  no   |  yes  |  <--- class has any comparison method in __dict__?
+# +=======+=======+=======+
+# | False |       |       |  <- the default
+# +-------+-------+-------+
+# | True  |  add  | raise |
+# +=======+=======+=======+
+# Raise because to allow this case would interfere with using
+# functools.total_ordering.
+
+# __hash__
+
+#  +------------------- unsafe_hash= parameter
+#  |       +----------- eq= parameter
+#  |       |       +--- frozen= parameter
+#  |       |       |
+#  v       v       v     |        |        |
+#                        |   no   |  yes   |  <--- class has explicitly defined __hash__
+# +=======+=======+=======+========+========+
+# | False | False | False |        |        | No __eq__, use the base class __hash__
+# +-------+-------+-------+--------+--------+
+# | False | False | True  |        |        | No __eq__, use the base class __hash__
+# +-------+-------+-------+--------+--------+
+# | False | True  | False | None   |        | <-- the default, not hashable
+# +-------+-------+-------+--------+--------+
+# | False | True  | True  | add    |        | Frozen, so hashable, allows override
+# +-------+-------+-------+--------+--------+
+# | True  | False | False | add    | raise  | Has no __eq__, but hashable
+# +-------+-------+-------+--------+--------+
+# | True  | False | True  | add    | raise  | Has no __eq__, but hashable
+# +-------+-------+-------+--------+--------+
+# | True  | True  | False | add    | raise  | Not frozen, but hashable
+# +-------+-------+-------+--------+--------+
+# | True  | True  | True  | add    | raise  | Frozen, so hashable
+# +=======+=======+=======+========+========+
+# For boxes that are blank, __hash__ is untouched and therefore
+# inherited from the base class.  If the base is object, then
+# id-based hashing is used.
+#
+# Note that a class may already have __hash__=None if it specified an
+# __eq__ method in the class body (not one that was created by
+# @dataclass).
+#
+# See _hash_action (below) for a coded version of this table.
+
+
+# Raised when an attempt is made to modify a frozen class.
+class FrozenInstanceError(AttributeError):
+    pass
+
+
+# A sentinel object for default values to signal that a default
+# factory will be used.  This is given a nice repr() which will appear
+# in the function signature of dataclasses' constructors.
+class _HAS_DEFAULT_FACTORY_CLASS:
+    def __repr__(self):
+        return "<factory>"
+
+
+_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
+
+# A sentinel object to detect if a parameter is supplied or not.  Use
+# a class to give it a better repr.
+class _MISSING_TYPE:
+    pass
+
+
+MISSING = _MISSING_TYPE()
+
+# Since most per-field metadata will be unused, create an empty
+# read-only proxy that can be shared among all fields.
+_EMPTY_METADATA = types.MappingProxyType({})
+
+# Markers for the various kinds of fields and pseudo-fields.
+class _FIELD_BASE:
+    def __init__(self, name):
+        self.name = name
+
+    def __repr__(self):
+        return self.name
+
+
+_FIELD = _FIELD_BASE("_FIELD")
+_FIELD_CLASSVAR = _FIELD_BASE("_FIELD_CLASSVAR")
+_FIELD_INITVAR = _FIELD_BASE("_FIELD_INITVAR")
+
+# The name of an attribute on the class where we store the Field
+# objects.  Also used to check if a class is a Data Class.
+_FIELDS = "__dataclass_fields__"
+
+# The name of an attribute on the class that stores the parameters to
+# @dataclass.
+_PARAMS = "__dataclass_params__"
+
+# The name of the function that, if it exists, is called at the end of
+# __init__.
+_POST_INIT_NAME = "__post_init__"
+
+# String regex that string annotations for ClassVar or InitVar must match.
+# Allows "identifier.identifier[" or "identifier[".
+# https://bugs.python.org/issue33453 for details.
+_MODULE_IDENTIFIER_RE = re.compile(r"^(?:\s*(\w+)\s*\.)?\s*(\w+)")
+
+
+class _InitVarMeta(type):
+    def __getitem__(self, params):
+        return self
+
+
+class InitVar(metaclass=_InitVarMeta):
+    pass
+
+
+# Instances of Field are only ever created from within this module,
+# and only from the field() function, although Field instances are
+# exposed externally as (conceptually) read-only objects.
+#
+# name and type are filled in after the fact, not in __init__.
+# They're not known at the time this class is instantiated, but it's
+# convenient if they're available later.
+#
+# When cls._FIELDS is filled in with a list of Field objects, the name
+# and type fields will have been populated.
+class Field:
+    __slots__ = (
+        "name",
+        "type",
+        "default",
+        "default_factory",
+        "repr",
+        "hash",
+        "init",
+        "compare",
+        "metadata",
+        "_field_type",  # Private: not to be used by user code.
+    )
+
+    def __init__(self, default, default_factory, init, repr, hash, compare, metadata):
+        self.name = None
+        self.type = None
+        self.default = default
+        self.default_factory = default_factory
+        self.init = init
+        self.repr = repr
+        self.hash = hash
+        self.compare = compare
+        self.metadata = (
+            _EMPTY_METADATA
+            if metadata is None or len(metadata) == 0
+            else types.MappingProxyType(metadata)
+        )
+        self._field_type = None
+
+    def __repr__(self):
+        return (
+            "Field("
+            f"name={self.name!r},"
+            f"type={self.type!r},"
+            f"default={self.default!r},"
+            f"default_factory={self.default_factory!r},"
+            f"init={self.init!r},"
+            f"repr={self.repr!r},"
+            f"hash={self.hash!r},"
+            f"compare={self.compare!r},"
+            f"metadata={self.metadata!r},"
+            f"_field_type={self._field_type}"
+            ")"
+        )
+
+    # This is used to support the PEP 487 __set_name__ protocol in the
+    # case where we're using a field that contains a descriptor as a
+    # default value.  For details on __set_name__, see
+    # https://www.python.org/dev/peps/pep-0487/#implementation-details.
+    #
+    # Note that in _process_class, this Field object is overwritten
+    # with the default value, so the end result is a descriptor that
+    # had __set_name__ called on it at the right time.
+    def __set_name__(self, owner, name):
+        func = getattr(type(self.default), "__set_name__", None)
+        if func:
+            # There is a __set_name__ method on the descriptor, call
+            # it.
+            func(self.default, owner, name)
+
+
+class _DataclassParams:
+    __slots__ = (
+        "init",
+        "repr",
+        "eq",
+        "order",
+        "unsafe_hash",
+        "frozen",
+    )
+
+    def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
+        self.init = init
+        self.repr = repr
+        self.eq = eq
+        self.order = order
+        self.unsafe_hash = unsafe_hash
+        self.frozen = frozen
+
+    def __repr__(self):
+        return (
+            "_DataclassParams("
+            f"init={self.init!r},"
+            f"repr={self.repr!r},"
+            f"eq={self.eq!r},"
+            f"order={self.order!r},"
+            f"unsafe_hash={self.unsafe_hash!r},"
+            f"frozen={self.frozen!r}"
+            ")"
+        )
+
+
+# This function is used instead of exposing Field creation directly,
+# so that a type checker can be told (via overloads) that this is a
+# function whose type depends on its parameters.
+def field(
+    *,
+    default=MISSING,
+    default_factory=MISSING,
+    init=True,
+    repr=True,
+    hash=None,
+    compare=True,
+    metadata=None,
+):
+    """Return an object to identify dataclass fields.
+
+    default is the default value of the field.
default_factory is a + 0-argument function called to initialize a field's value. If init + is True, the field will be a parameter to the class's __init__() + function. If repr is True, the field will be included in the + object's repr(). If hash is True, the field will be included in + the object's hash(). If compare is True, the field will be used + in comparison functions. metadata, if specified, must be a + mapping which is stored but not otherwise examined by dataclass. + + It is an error to specify both default and default_factory. + """ + + if default is not MISSING and default_factory is not MISSING: + raise ValueError("cannot specify both default and default_factory") + return Field(default, default_factory, init, repr, hash, compare, metadata) + + +def _tuple_str(obj_name, fields): + # Return a string representing each field of obj_name as a tuple + # member. So, if fields is ['x', 'y'] and obj_name is "self", + # return "(self.x,self.y)". + + # Special case for the 0-tuple. + if not fields: + return "()" + # Note the trailing comma, needed if this turns out to be a 1-tuple. + return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' + + +def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING): + # Note that we mutate locals when exec() is called. Caller + # beware! The only callers are internal to this module, so no + # worries about external callers. + if locals is None: + locals = {} + return_annotation = "" + if return_type is not MISSING: + locals["_return_type"] = return_type + return_annotation = "->_return_type" + args = ",".join(args) + body = "\n".join(f" {b}" for b in body) + + # Compute the text of the entire function. + txt = f"def {name}({args}){return_annotation}:\n{body}" + + exec(txt, globals, locals) + return locals[name] + + +def _field_assign(frozen, name, value, self_name): + # If we're a frozen class, then assign to our fields in __init__ + # via object.__setattr__. Otherwise, just use a simple + # assignment. + # + # self_name is what "self" is called in this function: don't + # hard-code "self", since that might be a field name. + if frozen: + return f"object.__setattr__({self_name},{name!r},{value})" + return f"{self_name}.{name}={value}" + + +def _field_init(f, frozen, globals, self_name): + # Return the text of the line in the body of __init__ that will + # initialize this field. + + default_name = f"_dflt_{f.name}" + if f.default_factory is not MISSING: + if f.init: + # This field has a default factory. If a parameter is + # given, use it. If not, call the factory. + globals[default_name] = f.default_factory + value = ( + f"{default_name}() " + f"if {f.name} is _HAS_DEFAULT_FACTORY " + f"else {f.name}" + ) + else: + # This is a field that's not in the __init__ params, but + # has a default factory function. It needs to be + # initialized here by calling the factory function, + # because there's no other way to initialize it. + + # For a field initialized with a default=defaultvalue, the + # class dict just has the default value + # (cls.fieldname=defaultvalue). But that won't work for a + # default factory, the factory must be called in __init__ + # and we must assign that to self.fieldname. We can't + # fall back to the class dict's value, both because it's + # not set, and because it might be different per-class + # (which, after all, is why we have a factory function!). + + globals[default_name] = f.default_factory + value = f"{default_name}()" + else: + # No default factory. 
+ if f.init: + if f.default is MISSING: + # There's no default, just do an assignment. + value = f.name + elif f.default is not MISSING: + globals[default_name] = f.default + value = f.name + else: + # This field does not need initialization. Signify that + # to the caller by returning None. + return None + + # Only test this now, so that we can create variables for the + # default. However, return None to signify that we're not going + # to actually do the assignment statement for InitVars. + if f._field_type == _FIELD_INITVAR: + return None + + # Now, actually generate the field assignment. + return _field_assign(frozen, f.name, value, self_name) + + +def _init_param(f): + # Return the __init__ parameter string for this field. For + # example, the equivalent of 'x:int=3' (except instead of 'int', + # reference a variable set to int, and instead of '3', reference a + # variable set to 3). + if f.default is MISSING and f.default_factory is MISSING: + # There's no default, and no default_factory, just output the + # variable name and type. + default = "" + elif f.default is not MISSING: + # There's a default, this will be the name that's used to look + # it up. + default = f"=_dflt_{f.name}" + elif f.default_factory is not MISSING: + # There's a factory function. Set a marker. + default = "=_HAS_DEFAULT_FACTORY" + return f"{f.name}:_type_{f.name}{default}" + + +def _init_fn(fields, frozen, has_post_init, self_name): + # fields contains both real fields and InitVar pseudo-fields. + + # Make sure we don't have fields without defaults following fields + # with defaults. This actually would be caught when exec-ing the + # function source code, but catching it here gives a better error + # message, and future-proofs us in case we build up the function + # using ast. + seen_default = False + for f in fields: + # Only consider fields in the __init__ call. + if f.init: + if not (f.default is MISSING and f.default_factory is MISSING): + seen_default = True + elif seen_default: + raise TypeError( + f"non-default argument {f.name!r} " "follows default argument" + ) + + globals = {"MISSING": MISSING, "_HAS_DEFAULT_FACTORY": _HAS_DEFAULT_FACTORY} + + body_lines = [] + for f in fields: + line = _field_init(f, frozen, globals, self_name) + # line is None means that this field doesn't require + # initialization (it's a pseudo-field). Just skip it. + if line: + body_lines.append(line) + + # Does this class have a post-init function? + if has_post_init: + params_str = ",".join(f.name for f in fields if f._field_type is _FIELD_INITVAR) + body_lines.append(f"{self_name}.{_POST_INIT_NAME}({params_str})") + + # If no body lines, use 'pass'. + if not body_lines: + body_lines = ["pass"] + + locals = {f"_type_{f.name}": f.type for f in fields} + return _create_fn( + "__init__", + [self_name] + [_init_param(f) for f in fields if f.init], + body_lines, + locals=locals, + globals=globals, + return_type=None, + ) + + +def _repr_fn(fields): + return _create_fn( + "__repr__", + ("self",), + [ + 'return self.__class__.__qualname__ + f"(' + + ", ".join([f"{f.name}={{self.{f.name}!r}}" for f in fields]) + + ')"' + ], + ) + + +def _frozen_get_del_attr(cls, fields): + # XXX: globals is modified on the first call to _create_fn, then + # the modified version is used in the second call. Is this okay? + globals = {"cls": cls, "FrozenInstanceError": FrozenInstanceError} + if fields: + fields_str = "(" + ",".join(repr(f.name) for f in fields) + ",)" + else: + # Special case for the zero-length tuple. 
+        fields_str = "()"
+    return (
+        _create_fn(
+            "__setattr__",
+            ("self", "name", "value"),
+            (
+                f"if type(self) is cls or name in {fields_str}:",
+                ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
+                f"super(cls, self).__setattr__(name, value)",
+            ),
+            globals=globals,
+        ),
+        _create_fn(
+            "__delattr__",
+            ("self", "name"),
+            (
+                f"if type(self) is cls or name in {fields_str}:",
+                ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
+                f"super(cls, self).__delattr__(name)",
+            ),
+            globals=globals,
+        ),
+    )
+
+
+def _cmp_fn(name, op, self_tuple, other_tuple):
+    # Create a comparison function.  If the fields in the object are
+    # named 'x' and 'y', then self_tuple is the string
+    # '(self.x,self.y)' and other_tuple is the string
+    # '(other.x,other.y)'.
+
+    return _create_fn(
+        name,
+        ("self", "other"),
+        [
+            "if other.__class__ is self.__class__:",
+            f" return {self_tuple}{op}{other_tuple}",
+            "return NotImplemented",
+        ],
+    )
+
+
+def _hash_fn(fields):
+    self_tuple = _tuple_str("self", fields)
+    return _create_fn("__hash__", ("self",), [f"return hash({self_tuple})"])
+
+
+def _is_classvar(a_type, typing):
+    # This test uses a typing internal class, but it's the best way to
+    # test if this is a ClassVar.
+    return type(a_type) is typing._ClassVar
+
+
+def _is_initvar(a_type, dataclasses):
+    # The module we're checking against is the module we're
+    # currently in (dataclasses.py).
+    return a_type is dataclasses.InitVar
+
+
+def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
+    # Given a type annotation string, does it refer to a_type in
+    # a_module?  For example, when checking that annotation denotes a
+    # ClassVar, then a_module is typing, and a_type is
+    # typing.ClassVar.
+
+    # It's possible to look up a_module given a_type, but it involves
+    # looking in sys.modules (again!), and seems like a waste since
+    # the caller already knows a_module.
+
+    # - annotation is a string type annotation
+    # - cls is the class that this annotation was found in
+    # - a_module is the module we want to match
+    # - a_type is the type in that module we want to match
+    # - is_type_predicate is a function called with (obj, a_module)
+    #   that determines if obj is of the desired type.
+
+    # Since this test does not do a local namespace lookup (and
+    # instead only a module (global) lookup), there are some things it
+    # gets wrong.
+
+    # With string annotations, cv0 will be detected as a ClassVar:
+    #   CV = ClassVar
+    #   @dataclass
+    #   class C0:
+    #     cv0: CV
+
+    # But in this example cv1 will not be detected as a ClassVar:
+    #   @dataclass
+    #   class C1:
+    #     CV = ClassVar
+    #     cv1: CV
+
+    # In C1, the code in this function (_is_type) will look up "CV" in
+    # the module and not find it, so it will not consider cv1 as a
+    # ClassVar.  This is a fairly obscure corner case, and the best
+    # way to fix it would be to eval() the string "CV" with the
+    # correct global and local namespaces.  However that would involve
+    # an eval() penalty for every single field of every dataclass
+    # that's defined.  It was judged not worth it.
+
+    match = _MODULE_IDENTIFIER_RE.match(annotation)
+    if match:
+        ns = None
+        module_name = match.group(1)
+        if not module_name:
+            # No module name, assume the class's module did
+            # "from dataclasses import InitVar".
+            ns = sys.modules.get(cls.__module__).__dict__
+        else:
+            # Look up module_name in the class's module.
+            module = sys.modules.get(cls.__module__)
+            if module and module.__dict__.get(module_name) is a_module:
+                ns = sys.modules.get(a_type.__module__).__dict__
+        if ns and is_type_predicate(ns.get(match.group(2)), a_module):
+            return True
+    return False
+
+
+def _get_field(cls, a_name, a_type):
+    # Return a Field object for this field name and type.  ClassVars
+    # and InitVars are also returned, but marked as such (see
+    # f._field_type).
+
+    # If the default value isn't derived from Field, then it's only a
+    # normal default value.  Convert it to a Field().
+    default = getattr(cls, a_name, MISSING)
+    if isinstance(default, Field):
+        f = default
+    else:
+        if isinstance(default, types.MemberDescriptorType):
+            # This is a field in __slots__, so it has no default value.
+            default = MISSING
+        f = field(default=default)
+
+    # Only at this point do we know the name and the type.  Set them.
+    f.name = a_name
+    f.type = a_type
+
+    # Assume it's a normal field until proven otherwise.  We're next
+    # going to decide if it's a ClassVar or InitVar, everything else
+    # is just a normal field.
+    f._field_type = _FIELD
+
+    # In addition to checking for actual types here, also check for
+    # string annotations.  get_type_hints() won't always work for us
+    # (see https://github.com/python/typing/issues/508 for example),
+    # plus it's expensive and would require an eval for every string
+    # annotation.  So, make a best effort to see if this is a ClassVar
+    # or InitVar using regexes and checking that the thing referenced
+    # is actually of the correct type.
+
+    # For the complete discussion, see https://bugs.python.org/issue33453
+
+    # If typing has not been imported, then it's impossible for any
+    # annotation to be a ClassVar.  So, only look for ClassVar if
+    # typing has been imported by any module (not necessarily cls's
+    # module).
+    typing = sys.modules.get("typing")
+    if typing:
+        if _is_classvar(a_type, typing) or (
+            isinstance(f.type, str)
+            and _is_type(f.type, cls, typing, typing.ClassVar, _is_classvar)
+        ):
+            f._field_type = _FIELD_CLASSVAR
+
+    # If the type is InitVar, or if it's a matching string annotation,
+    # then it's an InitVar.
+    if f._field_type is _FIELD:
+        # The module we're checking against is the module we're
+        # currently in (dataclasses.py).
+        dataclasses = sys.modules[__name__]
+        if _is_initvar(a_type, dataclasses) or (
+            isinstance(f.type, str)
+            and _is_type(f.type, cls, dataclasses, dataclasses.InitVar, _is_initvar)
+        ):
+            f._field_type = _FIELD_INITVAR
+
+    # Validations for individual fields.  This is delayed until now,
+    # instead of in the Field() constructor, since only here do we
+    # know the field name, which allows for better error reporting.
+
+    # Special restrictions for ClassVar and InitVar.
+    if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
+        if f.default_factory is not MISSING:
+            raise TypeError(f"field {f.name} cannot have a " "default factory")
+        # Should I check for other field settings? default_factory
+        # seems the most serious to check for.  Maybe add others.  For
+        # example, how about init=False (or really,
+        # init=<not-the-default-init-value>)?  It makes no sense for
+        # ClassVar and InitVar to specify init=<anything>.
+
+    # For real fields, disallow mutable defaults for known types.
+    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
+        raise ValueError(
+            f"mutable default {type(f.default)} for field "
+            f"{f.name} is not allowed: use default_factory"
+        )
+
+    return f
+
+
+def _set_new_attribute(cls, name, value):
+    # Never overwrites an existing attribute.
Returns True if the + # attribute already exists. + if name in cls.__dict__: + return True + setattr(cls, name, value) + return False + + +# Decide if/how we're going to create a hash function. Key is +# (unsafe_hash, eq, frozen, does-hash-exist). Value is the action to +# take. The common case is to do nothing, so instead of providing a +# function that is a no-op, use None to signify that. + + +def _hash_set_none(cls, fields): + return None + + +def _hash_add(cls, fields): + flds = [f for f in fields if (f.compare if f.hash is None else f.hash)] + return _hash_fn(flds) + + +def _hash_exception(cls, fields): + # Raise an exception. + raise TypeError(f"Cannot overwrite attribute __hash__ " f"in class {cls.__name__}") + + +# +# +-------------------------------------- unsafe_hash? +# | +------------------------------- eq? +# | | +------------------------ frozen? +# | | | +---------------- has-explicit-hash? +# | | | | +# | | | | +------- action +# | | | | | +# v v v v v +_hash_action = { + (False, False, False, False): None, + (False, False, False, True): None, + (False, False, True, False): None, + (False, False, True, True): None, + (False, True, False, False): _hash_set_none, + (False, True, False, True): None, + (False, True, True, False): _hash_add, + (False, True, True, True): None, + (True, False, False, False): _hash_add, + (True, False, False, True): _hash_exception, + (True, False, True, False): _hash_add, + (True, False, True, True): _hash_exception, + (True, True, False, False): _hash_add, + (True, True, False, True): _hash_exception, + (True, True, True, False): _hash_add, + (True, True, True, True): _hash_exception, +} +# See https://bugs.python.org/issue32929#msg312829 for an if-statement +# version of this table. + + +def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): + # Now that dicts retain insertion order, there's no reason to use + # an ordered dict. I am leveraging that ordering here, because + # derived class fields overwrite base class fields, but the order + # is defined by the base class, which is found first. + fields = {} + + setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, unsafe_hash, frozen)) + + # Find our base classes in reverse MRO order, and exclude + # ourselves. In reversed order so that more derived classes + # override earlier field definitions in base classes. As long as + # we're iterating over them, see if any are frozen. + any_frozen_base = False + has_dataclass_bases = False + for b in cls.__mro__[-1:0:-1]: + # Only process classes that have been processed by our + # decorator. That is, they have a _FIELDS attribute. + base_fields = getattr(b, _FIELDS, None) + if base_fields: + has_dataclass_bases = True + for f in base_fields.values(): + fields[f.name] = f + if getattr(b, _PARAMS).frozen: + any_frozen_base = True + + # Annotations that are defined in this class (not in base + # classes). If __annotations__ isn't present, then this class + # adds no new annotations. We use this to compute fields that are + # added by this class. + # + # Fields are found from cls_annotations, which is guaranteed to be + # ordered. Default values are from class attributes, if a field + # has a default. If the default value is a Field(), then it + # contains additional info beyond (and possibly including) the + # actual default value. Pseudo-fields ClassVars and InitVars are + # included, despite the fact that they're not real fields. That's + # dealt with later. 
+    cls_annotations = cls.__dict__.get("__annotations__", {})
+
+    # Now find fields in our class.  While doing so, validate some
+    # things, and set the default values (as class attributes) where
+    # we can.
+    cls_fields = [_get_field(cls, name, type) for name, type in cls_annotations.items()]
+    for f in cls_fields:
+        fields[f.name] = f
+
+        # If the class attribute (which is the default value for this
+        # field) exists and is of type 'Field', replace it with the
+        # real default.  This is so that normal class introspection
+        # sees a real default value, not a Field.
+        if isinstance(getattr(cls, f.name, None), Field):
+            if f.default is MISSING:
+                # If there's no default, delete the class attribute.
+                # This happens if we specify field(repr=False), for
+                # example (that is, we specified a field object, but
+                # no default value).  Also if we're using a default
+                # factory.  The class attribute should not be set at
+                # all in the post-processed class.
+                delattr(cls, f.name)
+            else:
+                setattr(cls, f.name, f.default)
+
+    # Do we have any Field members that don't also have annotations?
+    for name, value in cls.__dict__.items():
+        if isinstance(value, Field) and not name in cls_annotations:
+            raise TypeError(f"{name!r} is a field but has no type annotation")
+
+    # Check rules that apply if we are derived from any dataclasses.
+    if has_dataclass_bases:
+        # Raise an exception if any of our bases are frozen, but we're not.
+        if any_frozen_base and not frozen:
+            raise TypeError("cannot inherit non-frozen dataclass from a " "frozen one")
+
+        # Raise an exception if we're frozen, but none of our bases are.
+        if not any_frozen_base and frozen:
+            raise TypeError("cannot inherit frozen dataclass from a " "non-frozen one")
+
+    # Remember all of the fields on our class (including bases).  This
+    # also marks this class as being a dataclass.
+    setattr(cls, _FIELDS, fields)
+
+    # Was this class defined with an explicit __hash__?  Note that if
+    # __eq__ is defined in this class, then python will automatically
+    # set __hash__ to None.  This is a heuristic, as it's possible
+    # that such a __hash__ == None was not auto-generated, but it's
+    # close enough.
+    class_hash = cls.__dict__.get("__hash__", MISSING)
+    has_explicit_hash = not (
+        class_hash is MISSING or (class_hash is None and "__eq__" in cls.__dict__)
+    )
+
+    # If we're generating ordering methods, we must be generating the
+    # eq methods.
+    if order and not eq:
+        raise ValueError("eq must be true if order is true")
+
+    if init:
+        # Does this class have a post-init function?
+        has_post_init = hasattr(cls, _POST_INIT_NAME)
+
+        # Include InitVars and regular fields (so, not ClassVars).
+        flds = [f for f in fields.values() if f._field_type in (_FIELD, _FIELD_INITVAR)]
+        _set_new_attribute(
+            cls,
+            "__init__",
+            _init_fn(
+                flds,
+                frozen,
+                has_post_init,
+                # The name to use for the "self"
+                # param in __init__.  Use "self"
+                # if possible.
+                "__dataclass_self__" if "self" in fields else "self",
+            ),
+        )
+
+    # Get the fields as a list, and include only real fields.  This is
+    # used in all of the following methods.
+    field_list = [f for f in fields.values() if f._field_type is _FIELD]
+
+    if repr:
+        flds = [f for f in field_list if f.repr]
+        _set_new_attribute(cls, "__repr__", _repr_fn(flds))
+
+    if eq:
+        # Create __eq__ method.  There's no need for a __ne__ method,
+        # since python will call __eq__ and negate it.
+ flds = [f for f in field_list if f.compare] + self_tuple = _tuple_str("self", flds) + other_tuple = _tuple_str("other", flds) + _set_new_attribute( + cls, "__eq__", _cmp_fn("__eq__", "==", self_tuple, other_tuple) + ) + + if order: + # Create and set the ordering methods. + flds = [f for f in field_list if f.compare] + self_tuple = _tuple_str("self", flds) + other_tuple = _tuple_str("other", flds) + for name, op in [ + ("__lt__", "<"), + ("__le__", "<="), + ("__gt__", ">"), + ("__ge__", ">="), + ]: + if _set_new_attribute( + cls, name, _cmp_fn(name, op, self_tuple, other_tuple) + ): + raise TypeError( + f"Cannot overwrite attribute {name} " + f"in class {cls.__name__}. Consider using " + "functools.total_ordering" + ) + + if frozen: + for fn in _frozen_get_del_attr(cls, field_list): + if _set_new_attribute(cls, fn.__name__, fn): + raise TypeError( + f"Cannot overwrite attribute {fn.__name__} " + f"in class {cls.__name__}" + ) + + # Decide if/how we're going to create a hash function. + hash_action = _hash_action[ + bool(unsafe_hash), bool(eq), bool(frozen), has_explicit_hash + ] + if hash_action: + # No need to call _set_new_attribute here, since by the time + # we're here the overwriting is unconditional. + cls.__hash__ = hash_action(cls, field_list) + + if not getattr(cls, "__doc__"): + # Create a class doc-string. + cls.__doc__ = cls.__name__ + str(inspect.signature(cls)).replace(" -> None", "") + + return cls + + +# _cls should never be specified by keyword, so start it with an +# underscore. The presence of _cls is used to detect if this +# decorator is being called with parameters or not. +def dataclass( + _cls=None, + *, + init=True, + repr=True, + eq=True, + order=False, + unsafe_hash=False, + frozen=False, +): + """Returns the same class as was passed in, with dunder methods + added based on the fields defined in the class. + + Examines PEP 526 __annotations__ to determine fields. + + If init is true, an __init__() method is added to the class. If + repr is true, a __repr__() method is added. If order is true, rich + comparison dunder methods are added. If unsafe_hash is true, a + __hash__() method function is added. If frozen is true, fields may + not be assigned to after instance creation. + """ + + def wrap(cls): + return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen) + + # See if we're being called as @dataclass or @dataclass(). + if _cls is None: + # We're called with parens. + return wrap + + # We're called as @dataclass without parens. + return wrap(_cls) + + +def fields(class_or_instance): + """Return a tuple describing the fields of this dataclass. + + Accepts a dataclass or an instance of one. Tuple elements are of + type Field. + """ + + # Might it be worth caching this, per class? + try: + fields = getattr(class_or_instance, _FIELDS) + except AttributeError: + raise TypeError("must be called with a dataclass type or instance") + + # Exclude pseudo-fields. Note that fields is sorted by insertion + # order, so the order of the tuple is as the fields were defined. 
+ return tuple(f for f in fields.values() if f._field_type is _FIELD) + + +def _is_dataclass_instance(obj): + """Returns True if obj is an instance of a dataclass.""" + return not isinstance(obj, type) and hasattr(obj, _FIELDS) + + +def is_dataclass(obj): + """Returns True if obj is a dataclass or an instance of a + dataclass.""" + return hasattr(obj, _FIELDS) + + +def asdict(obj, *, dict_factory=dict): + """Return the fields of a dataclass instance as a new dictionary mapping + field names to field values. + + Example usage: + + @dataclass + class C: + x: int + y: int + + c = C(1, 2) + assert asdict(c) == {'x': 1, 'y': 2} + + If given, 'dict_factory' will be used instead of built-in dict. + The function applies recursively to field values that are + dataclass instances. This will also look into built-in containers: + tuples, lists, and dicts. + """ + if not _is_dataclass_instance(obj): + raise TypeError("asdict() should be called on dataclass instances") + return _asdict_inner(obj, dict_factory) + + +def _asdict_inner(obj, dict_factory): + if _is_dataclass_instance(obj): + result = [] + for f in fields(obj): + value = _asdict_inner(getattr(obj, f.name), dict_factory) + result.append((f.name, value)) + return dict_factory(result) + elif isinstance(obj, (list, tuple)): + return type(obj)(_asdict_inner(v, dict_factory) for v in obj) + elif isinstance(obj, dict): + return type(obj)( + (_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory)) + for k, v in obj.items() + ) + else: + return copy.deepcopy(obj) + + +def astuple(obj, *, tuple_factory=tuple): + """Return the fields of a dataclass instance as a new tuple of field values. + + Example usage:: + + @dataclass + class C: + x: int + y: int + + c = C(1, 2) + assert astuple(c) == (1, 2) + + If given, 'tuple_factory' will be used instead of built-in tuple. + The function applies recursively to field values that are + dataclass instances. This will also look into built-in containers: + tuples, lists, and dicts. + """ + + if not _is_dataclass_instance(obj): + raise TypeError("astuple() should be called on dataclass instances") + return _astuple_inner(obj, tuple_factory) + + +def _astuple_inner(obj, tuple_factory): + if _is_dataclass_instance(obj): + result = [] + for f in fields(obj): + value = _astuple_inner(getattr(obj, f.name), tuple_factory) + result.append(value) + return tuple_factory(result) + elif isinstance(obj, (list, tuple)): + return type(obj)(_astuple_inner(v, tuple_factory) for v in obj) + elif isinstance(obj, dict): + return type(obj)( + (_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory)) + for k, v in obj.items() + ) + else: + return copy.deepcopy(obj) + + +def make_dataclass( + cls_name, + fields, + *, + bases=(), + namespace=None, + init=True, + repr=True, + eq=True, + order=False, + unsafe_hash=False, + frozen=False, +): + """Return a new dynamically created dataclass. + + The dataclass name will be 'cls_name'. 'fields' is an iterable + of either (name), (name, type) or (name, type, Field) objects. If type is + omitted, use the string 'typing.Any'. Field objects are created by + the equivalent of calling 'field(name, type [, Field-info])'. + + C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,)) + + is equivalent to: + + @dataclass + class C(Base): + x: 'typing.Any' + y: int + z: int = field(init=False) + + For the bases and namespace parameters, see the builtin type() function. 
+
+    The parameters init, repr, eq, order, unsafe_hash, and frozen are passed to
+    dataclass().
+    """
+
+    if namespace is None:
+        namespace = {}
+    else:
+        # Copy namespace since we're going to mutate it.
+        namespace = namespace.copy()
+
+    # While we're looking through the field names, validate that they
+    # are identifiers, are not keywords, and not duplicates.
+    seen = set()
+    anns = {}
+    for item in fields:
+        if isinstance(item, str):
+            name = item
+            tp = "typing.Any"
+        elif len(item) == 2:
+            name, tp, = item
+        elif len(item) == 3:
+            name, tp, spec = item
+            namespace[name] = spec
+        else:
+            raise TypeError(f"Invalid field: {item!r}")
+
+        if not isinstance(name, str) or not name.isidentifier():
+            raise TypeError(f"Field names must be valid identifiers: {name!r}")
+        if keyword.iskeyword(name):
+            raise TypeError(f"Field names must not be keywords: {name!r}")
+        if name in seen:
+            raise TypeError(f"Field name duplicated: {name!r}")
+
+        seen.add(name)
+        anns[name] = tp
+
+    namespace["__annotations__"] = anns
+    # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
+    # of generic dataclasses.
+    cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace))
+    return dataclass(
+        cls,
+        init=init,
+        repr=repr,
+        eq=eq,
+        order=order,
+        unsafe_hash=unsafe_hash,
+        frozen=frozen,
+    )
+
+
+def replace(obj, **changes):
+    """Return a new object replacing specified fields with new values.
+
+    This is especially useful for frozen classes.  Example usage:
+
+      @dataclass(frozen=True)
+      class C:
+          x: int
+          y: int
+
+      c = C(1, 2)
+      c1 = replace(c, x=3)
+      assert c1.x == 3 and c1.y == 2
+    """
+
+    # We're going to mutate 'changes', but that's okay because it's a
+    # new dict, even if called with 'replace(obj, **my_changes)'.
+
+    if not _is_dataclass_instance(obj):
+        raise TypeError("replace() should be called on dataclass instances")
+
+    # It's an error to have init=False fields in 'changes'.
+    # If a field is not in 'changes', read its value from the provided obj.
+
+    for f in getattr(obj, _FIELDS).values():
+        if not f.init:
+            # Error if this field is specified in changes.
+            if f.name in changes:
+                raise ValueError(
+                    f"field {f.name} is declared with "
+                    "init=False, it cannot be specified with "
+                    "replace()"
+                )
+            continue
+
+        if f.name not in changes:
+            changes[f.name] = getattr(obj, f.name)
+
+    # Create the new object, which calls __init__() and
+    # __post_init__() (if defined), using all of the init fields we've
+    # added and/or left in 'changes'.  If there are values supplied in
+    # changes that aren't fields, this will correctly raise a
+    # TypeError.
+ return obj.__class__(**changes) diff --git a/graphene/types/base.py b/graphene/types/base.py index 29d60fef1..129834b0d 100644 --- a/graphene/types/base.py +++ b/graphene/types/base.py @@ -1,6 +1,6 @@ from typing import Type -from ..utils.subclass_with_meta import SubclassWithMeta +from ..utils.subclass_with_meta import SubclassWithMeta, SubclassWithMeta_Meta from ..utils.trim_docstring import trim_docstring @@ -26,6 +26,9 @@ def __repr__(self): return f"<{self.__class__.__name__} name={repr(self.name)}>" +BaseTypeMeta = SubclassWithMeta_Meta + + class BaseType(SubclassWithMeta): @classmethod def create_type(cls, class_name, **options): diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 9b6c7d7df..c16f50b8f 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -1,8 +1,13 @@ -from .base import BaseOptions, BaseType +from .base import BaseOptions, BaseType, BaseTypeMeta from .field import Field from .interface import Interface from .utils import yank_fields_from_attrs +try: + from dataclasses import make_dataclass, field +except ImportError: + from ..pyutils.dataclasses import make_dataclass, field # type: ignore + # For static type checking with Mypy MYPY = False if MYPY: @@ -14,7 +19,34 @@ class ObjectTypeOptions(BaseOptions): interfaces = () # type: Iterable[Type[Interface]] -class ObjectType(BaseType): +class ObjectTypeMeta(BaseTypeMeta): + def __new__(cls, name, bases, namespace): + # We create this type, to then overload it with the dataclass attrs + class InterObjectType: + pass + + base_cls = super().__new__(cls, name, (InterObjectType,) + bases, namespace) + if base_cls._meta: + fields = [ + ( + key, + "typing.Any", + field( + default=field_value.default_value + if isinstance(field_value, Field) + else None + ), + ) + for key, field_value in base_cls._meta.fields.items() + ] + dataclass = make_dataclass(name, fields, bases=()) + InterObjectType.__init__ = dataclass.__init__ + InterObjectType.__eq__ = dataclass.__eq__ + InterObjectType.__repr__ = dataclass.__repr__ + return base_cls + + +class ObjectType(BaseType, metaclass=ObjectTypeMeta): """ Object Type Definition @@ -127,44 +159,3 @@ def __init_subclass_with_meta__( super(ObjectType, cls).__init_subclass_with_meta__(_meta=_meta, **options) is_type_of = None - - def __init__(self, *args, **kwargs): - # ObjectType acting as container - args_len = len(args) - fields = self._meta.fields.items() - if args_len > len(fields): - # Daft, but matches old exception sans the err msg. 
- raise IndexError("Number of args exceeds number of fields") - fields_iter = iter(fields) - - if not kwargs: - for val, (name, field) in zip(args, fields_iter): - setattr(self, name, val) - else: - for val, (name, field) in zip(args, fields_iter): - setattr(self, name, val) - kwargs.pop(name, None) - - for name, field in fields_iter: - try: - val = kwargs.pop( - name, field.default_value if isinstance(field, Field) else None - ) - setattr(self, name, val) - except KeyError: - pass - - if kwargs: - for prop in list(kwargs): - try: - if isinstance( - getattr(self.__class__, prop), property - ) or prop.startswith("_"): - setattr(self, prop, kwargs.pop(prop)) - except AttributeError: - pass - if kwargs: - raise TypeError( - f"'{list(kwargs)[0]}' is an invalid keyword argument" - f" for {self.__class__.__name__}" - ) diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index 25025e4d1..85a46867e 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -83,6 +83,10 @@ class MyObjectType(ObjectType): def test_generate_objecttype_with_private_attributes(): class MyObjectType(ObjectType): + def __init__(self, _private_state=None, **kwargs): + self._private_state = _private_state + super().__init__(**kwargs) + _private_state = None assert "_private_state" not in MyObjectType._meta.fields @@ -155,6 +159,20 @@ def test_objecttype_as_container_only_args(): assert container.field2 == "2" +def test_objecttype_repr(): + container = Container("1", "2") + assert repr(container) == "Container(field1='1', field2='2')" + + +def test_objecttype_eq(): + container1 = Container("1", "2") + container2 = Container("1", "2") + container3 = Container("2", "3") + assert container1 == container1 + assert container1 == container2 + assert container2 != container3 + + def test_objecttype_as_container_args_kwargs(): container = Container("1", field2="2") assert container.field1 == "1" @@ -173,17 +191,19 @@ def test_objecttype_as_container_all_kwargs(): def test_objecttype_as_container_extra_args(): - with raises(IndexError) as excinfo: + with raises(TypeError) as excinfo: Container("1", "2", "3") - assert "Number of args exceeds number of fields" == str(excinfo.value) + assert "__init__() takes from 1 to 3 positional arguments but 4 were given" == str( + excinfo.value + ) def test_objecttype_as_container_invalid_kwargs(): with raises(TypeError) as excinfo: Container(unexisting_field="3") - assert "'unexisting_field' is an invalid keyword argument for Container" == str( + assert "__init__() got an unexpected keyword argument 'unexisting_field'" == str( excinfo.value )
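
---
Usage note (editor's addendum, not part of the patch): with this change, ObjectType
containers get their __init__, __repr__, and __eq__ from a dataclass generated out of
the declared fields, instead of the old hand-rolled __init__. A minimal sketch of the
resulting behavior, mirroring the updated tests; the Person type below is a
hypothetical example, not something defined in this diff:

    from graphene import ObjectType, String

    class Person(ObjectType):
        # Hypothetical example type; each Field/Scalar attribute becomes a
        # dataclass-style field on the generated container, in declaration order.
        first_name = String()
        last_name = String()

    # Positional and keyword arguments map onto the declared fields.
    p1 = Person("Jane", last_name="Doe")
    p2 = Person(first_name="Jane", last_name="Doe")

    # repr and eq now come from the generated dataclass methods.
    assert repr(p1) == "Person(first_name='Jane', last_name='Doe')"
    assert p1 == p2

    # Error behavior changes accordingly: extra positional arguments and unknown
    # keyword arguments now raise the standard dataclass __init__ TypeError
    # (previously an IndexError or a custom TypeError message), as exercised by
    # test_objecttype_as_container_extra_args and
    # test_objecttype_as_container_invalid_kwargs above.

One consequence worth noting: because the generated __init__ only accepts declared
fields, private constructor state (such as the _private_state example in the tests)
now needs an explicit __init__ that pops the extra argument and delegates the rest
via super().__init__(**kwargs).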