From 05723b3dfdfcecbaa248e41242ed1bcf3869cd6d Mon Sep 17 00:00:00 2001 From: foodszhang Date: Wed, 29 Aug 2018 11:23:22 +0800 Subject: [PATCH 1/2] add --- swagger_py_codegen/parser.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/swagger_py_codegen/parser.py b/swagger_py_codegen/parser.py index 7d57b07..93bc1b3 100644 --- a/swagger_py_codegen/parser.py +++ b/swagger_py_codegen/parser.py @@ -19,20 +19,15 @@ def __init__(self, data, ref): self.ref = ref self._data = data + def __getitem__(self, key): return self._data.__getitem__(key) - def __settiem__(self, key, value): - return self._data.__settiem__(key, value) - - def get(self, key, default=None): - return self._data.get(key, default) - - def has(self, key, default=None): - return self._data.has(key) + def __setitem__(self, key, value): + return self._data.__setitem__(key, value) - def keys(self): - return self._data.keys() + def __getattr__(self, key): + return self._data.__getattribute__(key) def __iter__(self): return self._data.__iter__() @@ -46,6 +41,8 @@ def __eq__(self, other): else: return object.__eq__(other) + def copy(self): + return RefNode(self._data, self.ref) class Swagger(object): From 9972ae54123f365baef25409395077c5ff702fdb Mon Sep 17 00:00:00 2001 From: foodszhang Date: Wed, 29 Aug 2018 11:49:57 +0800 Subject: [PATCH 2/2] Revert "Revert "resolver"" This reverts commit c309463b239c015b85e41a8c69597106754d8c88. Conflicts: swagger_py_codegen/parser.py --- swagger_py_codegen/jsonschema.py | 30 ++++++++++++------- .../templates/falcon/validators.tpl | 12 ++++---- .../templates/flask/validators.tpl | 10 +++---- .../templates/jsonschema/schemas.tpl | 8 +++-- .../templates/sanic/validators.tpl | 10 +++---- .../templates/tornado/validators.tpl | 12 ++++---- tests/test_jsonschema.py | 24 +++++++-------- tests/test_parser.py | 21 ++++++------- 8 files changed, 68 insertions(+), 59 deletions(-) diff --git a/swagger_py_codegen/jsonschema.py b/swagger_py_codegen/jsonschema.py index 8a467cf..5ff6da7 100644 --- a/swagger_py_codegen/jsonschema.py +++ b/swagger_py_codegen/jsonschema.py @@ -5,7 +5,7 @@ from inspect import getsource from .base import Code, CodeGenerator -from .parser import schema_var_name +from .parser import RefNode class Schema(Code): @@ -89,10 +89,8 @@ def build_data(swagger): scopes[(endpoint, method)] = list(security.values()).pop() break - schemas = OrderedDict([(schema_var_name(path), swagger.get(path)) for path in swagger.definitions]) - data = dict( - schemas=schemas, + definitions={'definitions':swagger.origin_data.get('definitions', {})}, validators=validators, filters=filters, scopes=scopes, @@ -109,7 +107,7 @@ def _process(self): yield Schema(build_data(self.swagger)) -def merge_default(schema, value, get_first=True): +def merge_default(schema, value, get_first=True, resolver=None): # TODO: more types support type_defaults = { 'integer': 9573, @@ -119,17 +117,17 @@ def merge_default(schema, value, get_first=True): 'boolean': False } - results = normalize(schema, value, type_defaults) + results = normalize(schema, value, type_defaults, resolver=resolver) if get_first: return results[0] return results -def build_default(schema): - return merge_default(schema, None) +def build_default(schema, resolver=None): + return merge_default(schema, None, resolver=resolver) -def normalize(schema, data, required_defaults=None): +def normalize(schema, data, required_defaults=None, resolver=None): if required_defaults is None: required_defaults = {} errors = [] @@ -217,7 +215,7 @@ def _normalize_dict(schema, data): def _normalize_list(schema, data): result = [] - if hasattr(data, '__iter__') and not isinstance(data, dict): + if hasattr(data, '__iter__') and not isinstance(data, (dict, RefNode)): for item in data: result.append(_normalize(schema.get('items'), item)) elif 'default' in schema: @@ -230,6 +228,15 @@ def _normalize_default(schema, data): else: return data + def _normalize_ref(schema, data): + if resolver == None: + raise TypeError("resolver must be provided") + ref = schema.get(u"$ref") + scope, resolved = resolver.resolve(ref) + return _normalize(resolved, data) + + + def _normalize(schema, data): if schema is True or schema == {}: return data @@ -239,10 +246,13 @@ def _normalize(schema, data): 'object': _normalize_dict, 'array': _normalize_list, 'default': _normalize_default, + 'ref': _normalize_ref } type_ = schema.get('type', 'object') if type_ not in funcs: type_ = 'default' + if schema.get(u'$ref', None): + type_ = 'ref' return funcs[type_](schema, data) diff --git a/swagger_py_codegen/templates/falcon/validators.tpl b/swagger_py_codegen/templates/falcon/validators.tpl index 0ed7f3c..7114232 100644 --- a/swagger_py_codegen/templates/falcon/validators.tpl +++ b/swagger_py_codegen/templates/falcon/validators.tpl @@ -14,7 +14,7 @@ from werkzeug.datastructures import MultiDict, Headers from jsonschema import Draft4Validator from .schemas import ( - validators, filters, scopes, security, base_path, normalize) + validators, filters, scopes, resolver, security, base_path, normalize) if six.PY3: @@ -44,7 +44,7 @@ class JSONEncoder(json.JSONEncoder): class FalconValidatorAdaptor(object): def __init__(self, schema): - self.validator = Draft4Validator(schema) + self.validator = Draft4Validator(schema, resolver=resolver) def validate_number(self, type_, value): try: @@ -87,7 +87,7 @@ class FalconValidatorAdaptor(object): def validate(self, value): value = self.type_convert(value) errors = {e.path[0]: e.message for e in self.validator.iter_errors(value)} - return normalize(self.validator.schema, value)[0], errors + return normalize(self.validator.schema, value, resolver=resolver)[0], errors def request_validate(req, resp, resource, params): @@ -154,10 +154,10 @@ def response_filter(req, resp, resource): 'Not defined', description='`%d` is not a defined status code.' % status) - _resp, errors = normalize(schemas['schema'], req.context['result']) + _resp, errors = normalize(schemas['schema'], req.context['result'], resolver=resolver) if schemas['headers']: headers, header_errors = normalize( - {'properties': schemas['headers']}, headers) + {'properties': schemas['headers']}, headers, resolver=resolver) errors.extend(header_errors) if errors: raise falcon.HTTPInternalServerError(title='Expectation Failed', @@ -165,4 +165,4 @@ def response_filter(req, resp, resource): if 'result' not in req.context: return - resp.body = json.dumps(_resp) \ No newline at end of file + resp.body = json.dumps(_resp) diff --git a/swagger_py_codegen/templates/flask/validators.tpl b/swagger_py_codegen/templates/flask/validators.tpl index c642ddb..a604e47 100644 --- a/swagger_py_codegen/templates/flask/validators.tpl +++ b/swagger_py_codegen/templates/flask/validators.tpl @@ -15,7 +15,7 @@ from flask_restful.utils import unpack from jsonschema import Draft4Validator from .schemas import ( - validators, filters, scopes, security, merge_default, normalize) + validators, filters, scopes, resolver, security, merge_default, normalize) class JSONEncoder(json.JSONEncoder): @@ -29,7 +29,7 @@ class JSONEncoder(json.JSONEncoder): class FlaskValidatorAdaptor(object): def __init__(self, schema): - self.validator = Draft4Validator(schema) + self.validator = Draft4Validator(schema, resolver=resolver) def validate_number(self, type_, value): try: @@ -72,7 +72,7 @@ class FlaskValidatorAdaptor(object): def validate(self, value): value = self.type_convert(value) errors = list(e.message for e in self.validator.iter_errors(value)) - return normalize(self.validator.schema, value)[0], errors + return normalize(self.validator.schema, value, resolver=resolver)[0], errors def request_validate(view): @@ -136,10 +136,10 @@ def response_filter(view): # return resp, status, headers abort(500, message='`%d` is not a defined status code.' % status) - resp, errors = normalize(schemas['schema'], resp) + resp, errors = normalize(schemas['schema'], resp, resolver=resolver) if schemas['headers']: headers, header_errors = normalize( - {'properties': schemas['headers']}, headers) + {'properties': schemas['headers']}, headers, resolver=resolver) errors.extend(header_errors) if errors: abort(500, message='Expectation Failed', errors=errors) diff --git a/swagger_py_codegen/templates/jsonschema/schemas.tpl b/swagger_py_codegen/templates/jsonschema/schemas.tpl index 097fb4c..df5ac34 100644 --- a/swagger_py_codegen/templates/jsonschema/schemas.tpl +++ b/swagger_py_codegen/templates/jsonschema/schemas.tpl @@ -1,16 +1,17 @@ # -*- coding: utf-8 -*- import six +from jsonschema import RefResolver +from swagger_py_codegen.parser import RefNode # TODO: datetime support + {% include '_do_not_change.tpl' %} base_path = '{{base_path}}' -{% for name, value in schemas.items() %} -{{ name }} = {{ value }} -{%- endfor %} +definitions = {{ definitions }} validators = { {%- for name, value in validators.items() %} @@ -30,6 +31,7 @@ scopes = { {%- endfor %} } +resolver = RefResolver.from_schema(definitions) class Security(object): diff --git a/swagger_py_codegen/templates/sanic/validators.tpl b/swagger_py_codegen/templates/sanic/validators.tpl index 1491aa5..c5edc4d 100644 --- a/swagger_py_codegen/templates/sanic/validators.tpl +++ b/swagger_py_codegen/templates/sanic/validators.tpl @@ -17,7 +17,7 @@ from sanic.request import RequestParameters from jsonschema import Draft4Validator from .schemas import ( - validators, filters, scopes, security, base_path, normalize, current) + validators, filters, scopes, security, resolver, base_path, normalize, current) def unpack(value): @@ -63,7 +63,7 @@ class JSONEncoder(json.JSONEncoder): class SanicValidatorAdaptor(object): def __init__(self, schema): - self.validator = Draft4Validator(schema) + self.validator = Draft4Validator(schema, resolver=resolver) def validate_number(self, type_, value): try: @@ -106,7 +106,7 @@ class SanicValidatorAdaptor(object): def validate(self, value): value = self.type_convert(value) errors = list(e.message for e in self.validator.iter_errors(value)) - return normalize(self.validator.schema, value)[0], errors + return normalize(self.validator.schema, value, resolver=resolver)[0], errors def request_validate(view): @@ -175,10 +175,10 @@ def response_filter(view): # return resp, status, headers raise ServerError('`%d` is not a defined status code.' % status, 500) - resp, errors = normalize(schemas['schema'], resp) + resp, errors = normalize(schemas['schema'], resp, resolver=resolver) if schemas['headers']: headers, header_errors = normalize( - {'properties': schemas['headers']}, headers) + {'properties': schemas['headers']}, headers, resolver=resolver) errors.extend(header_errors) if errors: raise ServerError('Expectation Failed', 500) diff --git a/swagger_py_codegen/templates/tornado/validators.tpl b/swagger_py_codegen/templates/tornado/validators.tpl index fdec55d..08b6592 100644 --- a/swagger_py_codegen/templates/tornado/validators.tpl +++ b/swagger_py_codegen/templates/tornado/validators.tpl @@ -11,13 +11,13 @@ import six from functools import wraps from jsonschema import Draft4Validator -from .schemas import validators, scopes, normalize, filters +from .schemas import validators, scopes, resolver, normalize, filters class ValidatorAdaptor(object): def __init__(self, schema): - self.validator = Draft4Validator(schema) + self.validator = Draft4Validator(schema, resolver=resolver) def validate_number(self, type_, value): try: @@ -66,7 +66,7 @@ class ValidatorAdaptor(object): def validate(self, value): value = self.type_convert(value) errors = list(e.message for e in self.validator.iter_errors(value)) - return normalize(self.validator.schema, value)[0], errors + return normalize(self.validator.schema, value, resolver=resolver)[0], errors def request_validate(obj): def _request_validate(view): @@ -134,10 +134,10 @@ def response_filter(obj): raise tornado.web.HTTPError( 500, message='`%d` is not a defined status code.' % status) - resp, errors = normalize(schemas['schema'], resp) + resp, errors = normalize(schemas['schema'], resp, resolver=resolver) if schemas['headers']: headers, header_errors = normalize( - {'properties': schemas['headers']}, headers) + {'properties': schemas['headers']}, headers, resolver=resolver) errors.extend(header_errors) if errors: raise tornado.web.HTTPError( @@ -167,4 +167,4 @@ def unpack(value): except ValueError: pass - return value, 200, {} \ No newline at end of file + return value, 200, {} diff --git a/tests/test_jsonschema.py b/tests/test_jsonschema.py index 863064a..d633802 100644 --- a/tests/test_jsonschema.py +++ b/tests/test_jsonschema.py @@ -1,5 +1,5 @@ from __future__ import absolute_import -from swagger_py_codegen.parser import Swagger +from swagger_py_codegen.parser import Swagger, RefNode from swagger_py_codegen.jsonschema import build_data @@ -7,7 +7,7 @@ def test_schema_base_01(): data = {} swagger = Swagger(data) data = build_data(swagger) - assert len(data['schemas']) == 0 + assert len(data['definitions']['definitions']) == 0 def test_schema_base_02(): @@ -23,7 +23,7 @@ def test_schema_base_02(): } swagger = Swagger(data) data = build_data(swagger) - assert len(data['schemas']) == 1 + assert len(data['definitions']['definitions']) == 1 def test_schema_base_03(): @@ -41,7 +41,7 @@ def test_schema_base_03(): } swagger = Swagger(data) data = build_data(swagger) - assert len(data['schemas']) == 0 + assert len(data['definitions']['definitions']) == 0 def test_schema_ref_01(): @@ -66,8 +66,8 @@ def test_schema_ref_01(): } swagger = Swagger(data) data = build_data(swagger) - assert len(data['schemas']) == 2 - assert list(data['schemas'].keys())[0] == 'DefinitionsUser' + print("!!!!!!!!!!!!!!!!!!", data['definitions']['definitions']) + assert len(data['definitions']['definitions']) == 2 def test_validators(): @@ -116,13 +116,13 @@ def test_validators(): } swagger = Swagger(data) data = build_data(swagger) - schemas = data['schemas'] + schemas = data['definitions']['definitions'] validators = data['validators'] # body parameters assert ('/users', 'POST') in validators v1 = validators[('/users', 'POST')]['body'] - assert v1 == schemas['DefinitionsUser'] + assert(v1 == RefNode(schemas['User'], '#/definitions/User')) # query parameters v2 = validators[('/users', 'POST')]['query'] @@ -136,8 +136,8 @@ def test_validators(): assert 'path' not in validators[('/users', 'POST')] # definitions - assert 'DefinitionsUser' in schemas - assert 'DefinitionsProduct' in schemas + assert 'User' in schemas + assert 'Product' in schemas assert len(schemas) == 2 @@ -192,7 +192,7 @@ def test_filters(): } swagger = Swagger(data) data = build_data(swagger) - schemas = data['schemas'] + definitions = data['definitions']['definitions'] filters = data['filters'] assert 201 in filters[('/users', 'POST')] @@ -201,7 +201,7 @@ def test_filters(): r1 = filters[('/users', 'POST')][201] r2 = filters[('/users', 'POST')][422] - assert r1['schema'] == schemas['DefinitionsUser'] + assert r1['schema'] == RefNode(definitions['User'], '#/definitions/User') assert r2['schema']['properties']['code'] == {'type': 'string'} diff --git a/tests/test_parser.py b/tests/test_parser.py index dcec8b4..d2fa8c3 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,5 +1,4 @@ from __future__ import absolute_import -import pytest from swagger_py_codegen.parser import Swagger @@ -24,8 +23,8 @@ def test_swagger_ref_count_01(): } } swagger = Swagger(data) - assert swagger.definitions[0] == ('definitions', 'User') - assert swagger.definitions[1] == ('definitions', 'Product') + assert swagger.definitions[0] == ('definitions', 'Product') + assert swagger.definitions[1] == ('definitions', 'User') def test_swagger_ref_count_02(): @@ -65,9 +64,9 @@ def test_swagger_ref_count_02(): } } swagger = Swagger(data) - assert swagger.definitions[0] == ('definitions', 'User') + assert swagger.definitions[0] == ('definitions', 'Order') assert swagger.definitions[1] == ('definitions', 'Product') - assert swagger.definitions[2] == ('definitions', 'Order') + assert swagger.definitions[2] == ('definitions', 'User') def test_swagger_ref_count_03(): @@ -114,10 +113,10 @@ def test_swagger_ref_count_03(): } } swagger = Swagger(data) - assert swagger.definitions[0] == ('definitions', 'User') - assert swagger.definitions[1] == ('definitions', 'Product') - assert swagger.definitions[2] == ('definitions', 'Order') - assert swagger.definitions[3] == ('definitions', 'OrderList') + assert swagger.definitions[0] == ('definitions', 'Order') + assert swagger.definitions[1] == ('definitions', 'OrderList') + assert swagger.definitions[2] == ('definitions', 'Product') + assert swagger.definitions[3] == ('definitions', 'User') def test_swagger_ref_count_04(): @@ -162,9 +161,7 @@ def test_swagger_ref_count_04(): } } } - with pytest.raises(ValueError) as excinfo: - Swagger(data) - assert excinfo.type == ValueError + Swagger(data) def test_swagger_ref_node():