Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishka17 committed Aug 12, 2019
2 parents 78979c1 + c036b91 commit c6d0ac8
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 30 deletions.
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ serialized = factory.dump(book)
* [Schemas](#Schemas)
* [Common schemas](#common-schemas)
* [Name styles](#name-styles)
* [Generic classes](#generic-classes)
* [Structure flattening](#structure-flattening)
* [Additional steps](#additional-steps)
* [Schema inheritance](#schema-inheritance)
Expand Down Expand Up @@ -223,6 +224,32 @@ Following name styles are supported:
* `camel_snake` (Camel_Snake)
* `dot` (dot.case)

#### Generic classes

It is possible to dump and load instances of generic dataclasses.
You can set a schema for generic or concrete types, with one limitation:
it is not possible to detect the concrete type of a dataclass when dumping. So if you need different schemas for different concrete types, you should explicitly set the type when dumping your data.

```python
T = TypeVar("T")


@dataclass
class FakeFoo(Generic[T]):
value: T


factory = Factory(schemas={
FakeFoo[str]: Schema(name_mapping={"value": "s"}),
FakeFoo: Schema(name_mapping={"value": "i"}),
})
data = {"i": 42, "s": "Hello"}
assert factory.load(data, FakeFoo[str]) == FakeFoo("Hello")
assert factory.load(data, FakeFoo[int]) == FakeFoo(42)
assert factory.dump(FakeFoo("hello"), FakeFoo[str]) == {"s": "hello"} # concrete type is set explicitly
assert factory.dump(FakeFoo("hello")) == {"i": "hello"} # generic type is detected automatically
```

#### Structure flattening

Since version 2.2 you can flatten hierarchy of data when parsing.
Expand Down
11 changes: 9 additions & 2 deletions dataclass_factory/factory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import astuple
from typing import Dict, Type, Any, Optional, TypeVar

from .common import Serializer, Parser
from .parsers import create_parser, get_lazy_parser
from .schema import Schema, merge_schema
from .serializers import create_serializer, get_lazy_serializer
from .type_detection import is_generic_concrete

DEFAULT_SCHEMA = Schema[Any](
trim_trailing_underscore=True,
Expand Down Expand Up @@ -59,9 +59,16 @@ def __init__(self,
})

def schema(self, class_: Type[T]) -> Schema[T]:
    """Return the effective Schema for *class_*, building and caching it on first use.

    Lookup order: a schema registered for the exact type, then (for a
    parameterized generic such as ``Foo[int]``) one registered for its
    unsubscripted origin class, finally the factory default schema.
    """
    # For a concrete generic alias, remember the plain origin class
    # (e.g. Foo for Foo[int]) so we can fall back to its schema.
    if is_generic_concrete(class_):
        base_class = class_.__origin__  # type: ignore
    else:
        base_class = None

    schema = self.schemas.get(class_)
    if not schema:
        # NOTE(review): this text comes from a diff view with +/- markers
        # stripped; the next line appears to be the *removed* side of the
        # hunk, superseded by the base_class fallback below — confirm
        # against the actual commit before relying on it.
        schema = merge_schema(None, self.default_schema)
        if base_class:
            schema = self.schemas.get(base_class)
        # Merge whatever was found (possibly None) with the defaults.
        schema = merge_schema(schema, self.default_schema)
        # Cache the resolved schema so later lookups are O(1).
        self.schemas[class_] = schema
    return schema

Expand Down
7 changes: 4 additions & 3 deletions dataclass_factory/parsers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import decimal
import inspect
import itertools
from collections import deque
from dataclasses import fields, is_dataclass

import itertools
from typing import (
List, Set, FrozenSet, Deque, Any, Callable,
Dict, Collection, Type, get_type_hints,
Expand All @@ -15,7 +16,7 @@
from .type_detection import (
is_tuple, is_collection, is_any, hasargs, is_optional,
is_none, is_union, is_dict, is_enum,
is_generic, fill_type_args, args_unspecified,
is_generic_concrete, fill_type_args, args_unspecified,
)


Expand Down Expand Up @@ -253,7 +254,7 @@ def create_parser_impl(factory, schema: Schema, debug_path: bool, cls: Type) ->
return get_collection_parser(collection_factory, item_parser, debug_path)
if is_union(cls):
return get_union_parser(tuple(factory.parser(x) for x in cls.__args__))
if is_generic(cls) and is_dataclass(cls.__origin__):
if is_generic_concrete(cls) and is_dataclass(cls.__origin__):
args = dict(zip(cls.__origin__.__parameters__, cls.__args__))
parsers = {
field.name: factory.parser(fill_type_args(args, field.type))
Expand Down
4 changes: 2 additions & 2 deletions dataclass_factory/schema.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from copy import copy
from typing import List, Dict, Callable, Tuple, Type, Sequence, Optional, Generic, Union

from dataclasses import fields

from typing import List, Dict, Callable, Tuple, Type, Sequence, Optional, Generic, Union

from .common import Serializer, Parser, T, InnerConverter
from .naming import NameStyle, NAMING_FUNC
from .path_utils import Path
Expand Down
26 changes: 14 additions & 12 deletions dataclass_factory/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
# -*- coding: utf-8 -*-
from dataclasses import is_dataclass, fields
from marshal import loads, dumps

from typing import Any, Type, get_type_hints, List, Dict, Optional, Union

from .common import Serializer, T
from .path_utils import init_structure, Path
from .schema import Schema, get_dataclass_fields
from .type_detection import (
is_collection, is_tuple, hasargs, is_dict, is_optional,
is_union, is_any, is_generic, is_type_var,
is_union, is_any, is_generic_concrete, is_type_var,
fill_type_args,
)


Expand Down Expand Up @@ -105,6 +107,17 @@ def serializer_with_steps(data):
def create_serializer_impl(factory, schema: Schema, debug_path: bool, class_: Type) -> Serializer:
if is_type_var(class_):
return get_lazy_serializer(factory)
if is_generic_concrete(class_) and is_dataclass(class_.__origin__):
args = dict(zip(class_.__origin__.__parameters__, class_.__args__))
serializers = {
field.name: factory.serializer(fill_type_args(args, field.type))
for field in fields(class_.__origin__)
}
return get_dataclass_serializer(
class_.__origin__,
serializers,
schema,
)
if is_dataclass(class_):
resolved_hints = get_type_hints(class_)
return get_dataclass_serializer(
Expand Down Expand Up @@ -143,16 +156,5 @@ def create_serializer_impl(factory, schema: Schema, debug_path: bool, class_: Ty
if is_collection(class_):
item_serializer = factory.serializer(class_.__args__[0] if class_.__args__ else Any)
return get_collection_serializer(item_serializer)
if is_generic(class_) and is_dataclass(class_.__origin__):
args = dict(zip(class_.__origin__.__parameters__, class_.__args__))
serializers = {
field.name: factory.serializer(args.get(field.type, field.type))
for field in fields(class_.__origin__)
}
return get_dataclass_serializer(
class_.__origin__,
serializers,
schema,
)
else:
return stub_serializer
10 changes: 7 additions & 3 deletions dataclass_factory/type_detection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from enum import Enum

from typing import Collection, Tuple, Optional, Any, Dict, Union, Type, TypeVar
from typing import Collection, Tuple, Optional, Any, Dict, Union, Type, TypeVar, Generic


def hasargs(type_, *args) -> bool:
Expand Down Expand Up @@ -55,10 +55,14 @@ def is_any(type_: Type) -> bool:
return type_ in (Any, inspect.Parameter.empty)


def is_generic(type_: Type) -> bool:
def is_generic_concrete(type_: Type) -> bool:
    """Return True if *type_* is a parameterized generic alias.

    Subscripted typing constructs such as ``List[int]`` or ``Foo[str]``
    carry an ``__origin__`` attribute; plain classes do not.
    """
    try:
        type_.__origin__
    except AttributeError:
        return False
    return True


def is_generic(type_: Type) -> bool:
    """Return True if *type_* derives from ``typing.Generic`` (safe on non-classes)."""
    result = issubclass_safe(type_, Generic)
    return result


def is_none(type_: Type) -> bool:
    """Return True when *type_* is exactly ``NoneType`` (the None branch of Optional)."""
    return type_ is None.__class__

Expand Down Expand Up @@ -88,7 +92,7 @@ def is_type_var(type_: Type) -> bool:

def fill_type_args(args: Dict[Type, Type], type_: Type) -> Type:
type_ = args.get(type_, type_)
if is_generic(type_):
if is_generic_concrete(type_):
type_args = tuple(
args.get(a, a) for a in type_.__args__
)
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
author_email='17@itishka.org',
license='Apache2',
classifiers=[
'Development Status :: 3 - Alpha',
'Operating System :: OS Independent',
'Intended Audience :: Developers',
'License :: OSI Approved :: Apache Software License',
Expand Down
3 changes: 1 addition & 2 deletions tests/test_custom_type_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from unittest import TestCase

from dataclasses import dataclass, asdict
from unittest import TestCase

from dataclass_factory import parse, dict_factory

Expand Down
3 changes: 2 additions & 1 deletion tests/test_dataclass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
# -*- coding: utf-8 -*-
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, List, Union, Tuple, Dict, FrozenSet, Set, Any
from unittest import TestCase

from typing import Optional, List, Union, Tuple, Dict, FrozenSet, Set, Any

from dataclass_factory import parse


Expand Down
31 changes: 30 additions & 1 deletion tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import TypeVar, Generic

from dataclass_factory import Factory
from dataclass_factory import Factory, Schema

T = TypeVar('T')
V = TypeVar('V')
Expand All @@ -16,6 +16,11 @@ class Foo(Generic[T]):
value: T


@dataclass
class FakeFoo(Generic[T]):
    # NOTE(review): the field is annotated as plain ``str`` even though the
    # class is parameterized by T — presumably intentional ("Fake"), so the
    # tests can check that schema lookup for generic aliases, not the field
    # annotation, drives parsing. Confirm against the test expectations.
    value: str


@dataclass
class FooBar(Generic[T, V]):
value: T
Expand Down Expand Up @@ -68,3 +73,27 @@ def test_inner2(self):
self.assertEqual(self.factory.load(baz_serial, Foo[FooBaz[int]]), baz)
self.assertEqual(self.factory.dump(baz, Foo[FooBaz[int]]), baz_serial)
self.assertEqual(self.factory.dump(baz), baz_serial)

def test_schema_load(self):
    """Loading uses the schema of the exact concrete alias when registered,
    and falls back to the generic class schema otherwise."""
    parsing_factory = Factory(schemas={
        FakeFoo[str]: Schema(name_mapping={"value": "s"}),
        FakeFoo: Schema(name_mapping={"value": "v"}),
    })
    payload = {"v": "hello", "i": 42, "s": "SSS"}
    # FakeFoo[str] has its own schema, so the "s" key is read.
    self.assertEqual(FakeFoo("SSS"), parsing_factory.load(payload, FakeFoo[str]))
    # FakeFoo[int] has no dedicated schema: the generic one ("v") applies.
    self.assertEqual(FakeFoo("hello"), parsing_factory.load(payload, FakeFoo[int]))

def test_schema_dump(self):
    """Dumping without an explicit type resolves to the generic class schema."""
    dumping_factory = Factory(schemas={
        FakeFoo[str]: Schema(name_mapping={"value": "s"}),
        FakeFoo: Schema(name_mapping={"value": "v"}),
    })
    # Dumping with an explicit concrete alias is deliberately not asserted;
    # the original test kept this expectation disabled:
    # self.assertEqual(factory.dump(FakeFoo("hello"), FakeFoo[str]), {"s": "hello"})
    self.assertEqual({"v": "hello"}, dumping_factory.dump(FakeFoo("hello")))

def test_schema_dump_inner(self):
    """Schemas of nested concrete generics are applied while dumping."""
    nested_factory = Factory(schemas={
        FooBaz[int]: Schema(name_mapping={"foo": "bar"}),
        Foo[int]: Schema(name_mapping={"value": "v"}),
    })
    dumped = nested_factory.dump(FooBaz(Foo(1)), FooBaz[int])
    # Outer FooBaz[int] renames "foo" -> "bar"; inner Foo[int] renames "value" -> "v".
    self.assertEqual({"bar": {"v": 1}}, dumped)
2 changes: 1 addition & 1 deletion tests/test_invalid_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_should_raise_when_invalid_int_field_provided(self):
ParserFactory(debug_path=True).get_parser(Foo)({"a": "20x", "b": 20})
self.assertTrue(False, "ValueError exception expected")
except InvalidFieldError as exc:
self.assertEqual(['a',], exc.field_path)
self.assertEqual(['a', ], exc.field_path)

def test_should_provide_failed_key_hierarchy_when_invalid_nested_data_parsed(self):
try:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pre_post.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Any
from dataclasses import dataclass
from unittest import TestCase

from dataclasses import dataclass
from typing import List

from dataclass_factory import Factory, Schema

Expand Down

0 comments on commit c6d0ac8

Please sign in to comment.