diff --git a/AutoRest/Generators/Python/Azure.Python/Templates/AzureServiceClientTemplate.cshtml b/AutoRest/Generators/Python/Azure.Python/Templates/AzureServiceClientTemplate.cshtml index 383db9e007..713d36e132 100644 --- a/AutoRest/Generators/Python/Azure.Python/Templates/AzureServiceClientTemplate.cshtml +++ b/AutoRest/Generators/Python/Azure.Python/Templates/AzureServiceClientTemplate.cshtml @@ -132,7 +132,7 @@ else { @:client_models = {} } - self._serialize = Serializer() + self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) @EmptyLine @foreach (var methodGroup in Model.MethodGroupModels) diff --git a/AutoRest/Generators/Python/Python/Templates/ServiceClientTemplate.cshtml b/AutoRest/Generators/Python/Python/Templates/ServiceClientTemplate.cshtml index e6ba317b1e..5a93943f28 100644 --- a/AutoRest/Generators/Python/Python/Templates/ServiceClientTemplate.cshtml +++ b/AutoRest/Generators/Python/Python/Templates/ServiceClientTemplate.cshtml @@ -125,7 +125,7 @@ else { @:client_models = {} } - self._serialize = Serializer() + self._serialize = Serializer(client_models) self._deserialize = Deserializer(client_models) @EmptyLine @foreach (var methodGroup in Model.MethodGroupModels) diff --git a/ClientRuntimes/Python/msrest/msrest/serialization.py b/ClientRuntimes/Python/msrest/msrest/serialization.py index a39ef1679b..c207aaa063 100644 --- a/ClientRuntimes/Python/msrest/msrest/serialization.py +++ b/ClientRuntimes/Python/msrest/msrest/serialization.py @@ -29,7 +29,6 @@ import datetime import decimal from enum import Enum -import importlib import json import logging import re @@ -95,7 +94,7 @@ def __init__(self, *args, **kwargs): def __eq__(self, other): """Compare objects by comparing all attributes.""" if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ + return self.__class__.__dict__ == other.__class__.__dict__ return False def __ne__(self, other): @@ -140,33 +139,37 @@ def _classify(cls, response, objects): raise TypeError("Object cannot be classified futher.") -def _convert_to_datatype(params, localtype): - """Convert a dict-like object to the given datatype - """ - return _recursive_convert_to_datatype( - params, - localtype.__name__, - importlib.import_module('..', localtype.__module__)) - - -def _recursive_convert_to_datatype(params, str_localtype, models_module): - """Convert a dict-like object to the given datatype - """ - if isinstance(params, list): - return [_recursive_convert_to_datatype( - data, - str_localtype[1:-1], - models_module) for data in params] - localtype = getattr(models_module, str_localtype, None) - if not localtype: - return params - result = { - key: _recursive_convert_to_datatype( - params[key], - localtype._attribute_map[key]['type'], - models_module) for key in params - } - return localtype(**result) +def _convert_to_datatype(data, data_type, localtypes): + if data is None: + return data + data_obj = localtypes.get(data_type.strip('{[]}')) + if data_obj: + if data_type.startswith('['): + data = [ + _convert_to_datatype( + param, data_type[1:-1], localtypes) for param in data + ] + elif data_type.startswith('{'): + data = { + key: _convert_to_datatype( + data[key], data_type[1:-1], localtypes) for key in data + } + elif not isinstance(data, data_obj): + result = { + key: _convert_to_datatype( + data[key], + data_obj._attribute_map[key]['type'], + localtypes) for key in data + } + data = data_obj(**result) + else: + try: + for attr, map in data._attribute_map.items(): + setattr(data, attr, _convert_to_datatype( + getattr(data, attr), map['type'], localtypes)) + except AttributeError: + pass + return data class Serializer(object): @@ -192,7 +195,7 @@ class Serializer(object): } flattten = re.compile(r"(? 9999 or utc.tm_year < 1: raise OverflowError("Hit max or min date") @@ -631,6 +628,9 @@ def serialize_iso(attr, **kwargs): except (ValueError, OverflowError) as err: msg = "Unable to serialize datetime object." raise_with_traceback(SerializationError, msg, err) + except AttributeError as err: + msg = "ISO-8601 object must be valid Datetime object." + raise_with_traceback(TypeError, msg, err) @staticmethod def serialize_unix(attr, **kwargs):