Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix non-generic protocols #436

Merged
merged 2 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
- _cattrs_ is now linted with [Ruff](https://beta.ruff.rs/docs/).
- Remove some unused lines in the unstructuring code.
([#416](https://github.com/python-attrs/cattrs/pull/416))
- Fix handling classes inheriting from non-generic protocols.
([#374](https://github.com/python-attrs/cattrs/issues/374))

## 23.1.2 (2023-06-02)

Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ clean-test: ## remove test and coverage artifacts

lint: ## check style with ruff and black
pdm run ruff src/ tests
pdm run isort -c src/ tests
pdm run black --check src tests docs/conf.py

test: ## run tests quickly with the default Python
Expand Down
13 changes: 9 additions & 4 deletions src/cattrs/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,11 @@ def is_counter(type):
)

def is_generic(obj) -> bool:
return isinstance(obj, (_GenericAlias, GenericAlias)) or is_subclass(
obj, Generic
"""Whether obj is a generic type."""
# Inheriting from protocol will inject `Generic` into the MRO
# without `__orig_bases__`.
return isinstance(obj, (_GenericAlias, GenericAlias)) or (
is_subclass(obj, Generic) and hasattr(obj, "__orig_bases__")
)

def copy_with(type, args):
Expand All @@ -343,7 +346,7 @@ def get_full_type_hints(obj, globalns=None, localns=None):
TupleSubscriptable = Tuple

from collections import Counter as ColCounter
from typing import Counter, TypedDict, Union, _GenericAlias
from typing import Counter, Generic, TypedDict, Union, _GenericAlias

from typing_extensions import Annotated, NotRequired, Required
from typing_extensions import get_origin as te_get_origin
Expand Down Expand Up @@ -429,7 +432,9 @@ def is_literal(type) -> bool:
return type.__class__ is _GenericAlias and type.__origin__ is Literal

def is_generic(obj):
return isinstance(obj, _GenericAlias)
return isinstance(obj, _GenericAlias) or (
is_subclass(obj, Generic) and hasattr(obj, "__orig_bases__")
)

def copy_with(type, args):
"""Replace a generic type's arguments."""
Expand Down
2 changes: 1 addition & 1 deletion src/cattrs/preconf/orjson.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Preconfigured converters for orjson."""
from base64 import b85decode, b85encode
from datetime import datetime, date
from datetime import date, datetime
from enum import Enum
from typing import Any, Type, TypeVar, Union

Expand Down
11 changes: 5 additions & 6 deletions tests/test_baseconverter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Test both structuring and unstructuring."""
from typing import Optional, Union

import attr
import pytest
from attr import define, fields, make_class
from attrs import define, fields, make_class
from hypothesis import HealthCheck, assume, given, settings
from hypothesis.strategies import just, one_of

Expand Down Expand Up @@ -90,9 +89,9 @@ def test_union_field_roundtrip(cl_and_vals_a, cl_and_vals_b, strat):
common_names = a_field_names & b_field_names
assume(len(a_field_names) > len(common_names))

@attr.s
@define
class C:
a = attr.ib(type=Union[cl_a, cl_b])
a: Union[cl_a, cl_b]

inst = C(a=cl_a(*vals_a, **kwargs_a))

Expand Down Expand Up @@ -161,9 +160,9 @@ def test_optional_field_roundtrip(cl_and_vals):
converter = BaseConverter()
cl, vals, kwargs = cl_and_vals

@attr.s
@define
class C:
a = attr.ib(type=Optional[cl])
a: Optional[cl]

inst = C(a=cl(*vals, **kwargs))
assert inst == converter.structure(converter.unstructure(inst), C)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import linecache
from traceback import format_exc

from attr import define
from attrs import define

from cattrs import Converter
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn
Expand Down
72 changes: 44 additions & 28 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Deque, Dict, Generic, List, Optional, TypeVar, Union

import pytest
from attr import asdict, attrs, define
from attrs import asdict, define

from cattrs import BaseConverter, Converter
from cattrs._compat import Protocol
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_able_to_structure_deeply_nested_generics_gen(converter):


def test_structure_unions_of_generics(converter):
@attrs(auto_attribs=True)
@define
class TClass2(Generic[T]):
c: T

Expand All @@ -142,7 +142,7 @@ class TClass2(Generic[T]):


def test_structure_list_of_generic_unions(converter):
@attrs(auto_attribs=True)
@define
class TClass2(Generic[T]):
c: T

Expand All @@ -154,7 +154,7 @@ class TClass2(Generic[T]):


def test_structure_deque_of_generic_unions(converter):
@attrs(auto_attribs=True)
@define
class TClass2(Generic[T]):
c: T

Expand All @@ -179,35 +179,31 @@ def test_raises_if_no_generic_params_supplied(
assert exc.value.type_ is T


def test_unstructure_generic_attrs():
c = Converter()

@attrs(auto_attribs=True)
def test_unstructure_generic_attrs(genconverter):
@define
class Inner(Generic[T]):
a: T

@attrs(auto_attribs=True)
@define
class Outer:
inner: Inner[int]

initial = Outer(Inner(1))
raw = c.unstructure(initial)
raw = genconverter.unstructure(initial)

assert raw == {"inner": {"a": 1}}

new = c.structure(raw, Outer)
new = genconverter.structure(raw, Outer)
assert initial == new

@attrs(auto_attribs=True)
@define
class OuterStr:
inner: Inner[str]

assert c.structure(raw, OuterStr) == OuterStr(Inner("1"))

assert genconverter.structure(raw, OuterStr) == OuterStr(Inner("1"))

def test_unstructure_deeply_nested_generics():
c = Converter()

def test_unstructure_deeply_nested_generics(genconverter):
@define
class Inner:
a: int
Expand All @@ -217,16 +213,14 @@ class Outer(Generic[T]):
inner: T

initial = Outer[Inner](Inner(1))
raw = c.unstructure(initial, Outer[Inner])
raw = genconverter.unstructure(initial, Outer[Inner])
assert raw == {"inner": {"a": 1}}

raw = c.unstructure(initial)
raw = genconverter.unstructure(initial)
assert raw == {"inner": {"a": 1}}


def test_unstructure_deeply_nested_generics_list():
c = Converter()

def test_unstructure_deeply_nested_generics_list(genconverter):
@define
class Inner:
a: int
Expand All @@ -236,16 +230,14 @@ class Outer(Generic[T]):
inner: List[T]

initial = Outer[Inner]([Inner(1)])
raw = c.unstructure(initial, Outer[Inner])
raw = genconverter.unstructure(initial, Outer[Inner])
assert raw == {"inner": [{"a": 1}]}

raw = c.unstructure(initial)
raw = genconverter.unstructure(initial)
assert raw == {"inner": [{"a": 1}]}


def test_unstructure_protocol():
c = Converter()

def test_unstructure_protocol(genconverter):
class Proto(Protocol):
a: int

Expand All @@ -258,10 +250,10 @@ class Outer:
inner: Proto

initial = Outer(Inner(1))
raw = c.unstructure(initial, Outer)
raw = genconverter.unstructure(initial, Outer)
assert raw == {"inner": {"a": 1}}

raw = c.unstructure(initial)
raw = genconverter.unstructure(initial)
assert raw == {"inner": {"a": 1}}


Expand Down Expand Up @@ -306,3 +298,27 @@ class B(A[int]):
pass

assert generate_mapping(B, {}) == {T.__name__: int}


def test_nongeneric_protocols(converter):
"""Non-generic protocols work."""

class NongenericProtocol(Protocol):
...

@define
class Entity(NongenericProtocol):
...

assert generate_mapping(Entity) == {}

class GenericProtocol(Protocol[T]):
...

@define
class GenericEntity(GenericProtocol[int]):
a: int

assert generate_mapping(GenericEntity) == {"T": int}

assert converter.structure({"a": 1}, GenericEntity) == GenericEntity(1)
2 changes: 1 addition & 1 deletion tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Tests for the extended validation mode."""
import pickle
from typing import Dict, FrozenSet, List, Set, Tuple

import pytest
import pickle
from attrs import define, field
from attrs.validators import in_
from hypothesis import given
Expand Down