diff --git a/serializable/__init__.py b/serializable/__init__.py index a29f2cc..3b02cef 100644 --- a/serializable/__init__.py +++ b/serializable/__init__.py @@ -28,7 +28,20 @@ from io import StringIO, TextIOWrapper from json import JSONEncoder from sys import version_info -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) from xml.etree import ElementTree if version_info >= (3, 8): @@ -314,8 +327,10 @@ def _as_xml(self: _T, as_string: bool = True, element_name: Optional[str] = None nested_e = this_e for j in v: - if not prop_info.is_primitive_type(): + if not prop_info.is_primitive_type() and not prop_info.is_enum: nested_e.append(j.as_xml(as_string=False, element_name=nested_key, xmlns=xmlns)) + elif prop_info.is_enum: + ElementTree.SubElement(nested_e, nested_key).text = str(j.value) elif prop_info.concrete_type in (float, int): ElementTree.SubElement(nested_e, nested_key).text = str(j) elif prop_info.concrete_type is bool: @@ -494,6 +509,7 @@ class ObjectMetadataLibrary: _klass_property_names: Dict[str, Dict[SerializationType, str]] = {} _klass_property_types: Dict[str, Type[Any]] = {} _klass_property_xml_sequence: Dict[str, int] = {} + custom_enum_klasses: Set[Type[Any]] = set() klass_mappings: Dict[str, 'ObjectMetadataLibrary.SerializableClass'] = {} klass_property_mappings: Dict[str, Dict[str, 'ObjectMetadataLibrary.SerializableProperty']] = {} @@ -654,6 +670,14 @@ def _parse_type(self, type_: Any) -> None: if _oml_sc.name == results.get("array_of"): _k = _oml_sc.klass + if _k is None: + # Perhaps a custom ENUM? + if results.get("array_of") in ObjectMetadataLibrary.custom_enum_klasses: + prrint(f'CUSTOM ENUM') + for _enum_klass in ObjectMetadataLibrary.custom_enum_klasses: + if _enum_klass.__name__ == results.get("array_of"): + _k = _enum_klass + if _k is None: self._type_ = type_ # type: ignore self._deferred_type_parsing = True @@ -734,6 +758,10 @@ def is_klass_serializable(cls, klass: _T) -> bool: def is_property(cls, o: object) -> bool: return isinstance(o, property) + @classmethod + def register_enum(cls, klass: _T) -> _T: + cls.custom_enum_klasses.add(klass) + @classmethod def register_klass(cls, klass: _T, custom_name: Optional[str], serialization_types: Iterable[SerializationType], @@ -816,6 +844,21 @@ def register_property_type_mapping(cls, qual_name: str, mapped_type: Any) -> Non cls._klass_property_types.update({qual_name: mapped_type}) +def serializable_enum(cls: Optional[Type[_T]] = None) -> Union[Callable[[Any], Type[_T]], Type[_T]]: + + def wrap(kls: Type[_T]) -> Type[_T]: + ObjectMetadataLibrary.register_enum(klass=kls) + return kls + + # See if we're being called as @enum or @enum(). + if cls is None: + # We're called with parens. + return wrap + + # We're called as @register_klass without parens. + return wrap(cls) + + def serializable_class(cls: Optional[Type[_T]] = None, *, name: Optional[str] = None, serialization_types: Optional[Iterable[SerializationType]] = None, ignore_during_deserialization: Optional[Iterable[str]] = None