Skip to content

Commit

Permalink
Merge pull request #120 from foodszhang/ref_resolver
Browse files Browse the repository at this point in the history
ref resolver
  • Loading branch information
gusibi authored Sep 7, 2018
2 parents ab8e170 + 9972ae5 commit 278b546
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 73 deletions.
30 changes: 20 additions & 10 deletions swagger_py_codegen/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import getsource

from .base import Code, CodeGenerator
from .parser import schema_var_name
from .parser import RefNode


class Schema(Code):
Expand Down Expand Up @@ -89,10 +89,8 @@ def build_data(swagger):
scopes[(endpoint, method)] = list(security.values()).pop()
break

schemas = OrderedDict([(schema_var_name(path), swagger.get(path)) for path in swagger.definitions])

data = dict(
schemas=schemas,
definitions={'definitions':swagger.origin_data.get('definitions', {})},
validators=validators,
filters=filters,
scopes=scopes,
Expand All @@ -109,7 +107,7 @@ def _process(self):
yield Schema(build_data(self.swagger))


def merge_default(schema, value, get_first=True):
def merge_default(schema, value, get_first=True, resolver=None):
# TODO: more types support
type_defaults = {
'integer': 9573,
Expand All @@ -119,17 +117,17 @@ def merge_default(schema, value, get_first=True):
'boolean': False
}

results = normalize(schema, value, type_defaults)
results = normalize(schema, value, type_defaults, resolver=resolver)
if get_first:
return results[0]
return results


def build_default(schema):
return merge_default(schema, None)
def build_default(schema, resolver=None):
return merge_default(schema, None, resolver=resolver)


def normalize(schema, data, required_defaults=None):
def normalize(schema, data, required_defaults=None, resolver=None):
if required_defaults is None:
required_defaults = {}
errors = []
Expand Down Expand Up @@ -217,7 +215,7 @@ def _normalize_dict(schema, data):

def _normalize_list(schema, data):
result = []
if hasattr(data, '__iter__') and not isinstance(data, dict):
if hasattr(data, '__iter__') and not isinstance(data, (dict, RefNode)):
for item in data:
result.append(_normalize(schema.get('items'), item))
elif 'default' in schema:
Expand All @@ -230,6 +228,15 @@ def _normalize_default(schema, data):
else:
return data

def _normalize_ref(schema, data):
if resolver == None:
raise TypeError("resolver must be provided")
ref = schema.get(u"$ref")
scope, resolved = resolver.resolve(ref)
return _normalize(resolved, data)



def _normalize(schema, data):
if schema is True or schema == {}:
return data
Expand All @@ -239,10 +246,13 @@ def _normalize(schema, data):
'object': _normalize_dict,
'array': _normalize_list,
'default': _normalize_default,
'ref': _normalize_ref
}
type_ = schema.get('type', 'object')
if type_ not in funcs:
type_ = 'default'
if schema.get(u'$ref', None):
type_ = 'ref'

return funcs[type_](schema, data)

Expand Down
47 changes: 33 additions & 14 deletions swagger_py_codegen/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,36 @@ def schema_var_name(path):
return ''.join(map(str.capitalize, map(str, path)))


class RefNode(dict):
class RefNode(object):

def __init__(self, data, ref):
self.ref = ref
super(RefNode, self).__init__(data)
self._data = data


def __getitem__(self, key):
return self._data.__getitem__(key)

def __setitem__(self, key, value):
return self._data.__setitem__(key, value)

def __getattr__(self, key):
return self._data.__getattribute__(key)

def __iter__(self):
return self._data.__iter__()

def __repr__(self):
return schema_var_name(self.ref)
return repr({'$ref':self.ref})

def __eq__(self, other):
if isinstance(other, RefNode):
return self._data == other._data and self.ref == other.ref
else:
return object.__eq__(other)

def copy(self):
return RefNode(self._data, self.ref)

class Swagger(object):

Expand All @@ -40,14 +61,10 @@ def _process_ref(self):
"""
resolve all references util no reference exists
"""
while 1:
li = list(self.search(['**', '$ref']))
if not li:
break
for path, ref in li:
data = resolve(self.data, ref)
path = path[:-1]
self.set(path, data)
for path, ref in self.search(['**', '$ref']):
data = resolve(self.data, ref)
path = path[:-1]
self.set(path, RefNode(data, ref))

def _resolve_definitions(self):
"""
Expand Down Expand Up @@ -76,17 +93,19 @@ def get_definition_refs():
while definition_refs:
ready = {
definition for definition, refs
in six.iteritems(definition_refs) if not refs
in six.iteritems(definition_refs)
}
if not ready:
msg = '$ref circular references found!\n'
raise ValueError(msg)
continue
#msg = '$ref circular references found!\n'
#raise ValueError(msg)
for definition in ready:
del definition_refs[definition]
for refs in six.itervalues(definition_refs):
refs.difference_update(ready)

self._definitions += ready
self._definitions.sort(key=lambda x :x[1])

def search(self, path):
for p, d in dpath.util.search(
Expand Down
12 changes: 6 additions & 6 deletions swagger_py_codegen/templates/falcon/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ from werkzeug.datastructures import MultiDict, Headers
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, security, base_path, normalize)
validators, filters, scopes, resolver, security, base_path, normalize)


if six.PY3:
Expand Down Expand Up @@ -44,7 +44,7 @@ class JSONEncoder(json.JSONEncoder):
class FalconValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema)
self.validator = Draft4Validator(schema, resolver=resolver)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -87,7 +87,7 @@ class FalconValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = {e.path[0]: e.message for e in self.validator.iter_errors(value)}
return normalize(self.validator.schema, value)[0], errors
return normalize(self.validator.schema, value, resolver=resolver)[0], errors


def request_validate(req, resp, resource, params):
Expand Down Expand Up @@ -154,15 +154,15 @@ def response_filter(req, resp, resource):
'Not defined',
description='`%d` is not a defined status code.' % status)

_resp, errors = normalize(schemas['schema'], req.context['result'])
_resp, errors = normalize(schemas['schema'], req.context['result'], resolver=resolver)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers)
{'properties': schemas['headers']}, headers, resolver=resolver)
errors.extend(header_errors)
if errors:
raise falcon.HTTPInternalServerError(title='Expectation Failed',
description=errors)

if 'result' not in req.context:
return
resp.body = json.dumps(_resp)
resp.body = json.dumps(_resp)
10 changes: 5 additions & 5 deletions swagger_py_codegen/templates/flask/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from flask_restful.utils import unpack
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, security, merge_default, normalize)
validators, filters, scopes, resolver, security, merge_default, normalize)


class JSONEncoder(json.JSONEncoder):
Expand All @@ -29,7 +29,7 @@ class JSONEncoder(json.JSONEncoder):
class FlaskValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema)
self.validator = Draft4Validator(schema, resolver=resolver)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -72,7 +72,7 @@ class FlaskValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value)[0], errors
return normalize(self.validator.schema, value, resolver=resolver)[0], errors


def request_validate(view):
Expand Down Expand Up @@ -136,10 +136,10 @@ def response_filter(view):
# return resp, status, headers
abort(500, message='`%d` is not a defined status code.' % status)

resp, errors = normalize(schemas['schema'], resp)
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers)
{'properties': schemas['headers']}, headers, resolver=resolver)
errors.extend(header_errors)
if errors:
abort(500, message='Expectation Failed', errors=errors)
Expand Down
8 changes: 5 additions & 3 deletions swagger_py_codegen/templates/jsonschema/schemas.tpl
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# -*- coding: utf-8 -*-

import six
from jsonschema import RefResolver
from swagger_py_codegen.parser import RefNode

# TODO: datetime support


{% include '_do_not_change.tpl' %}

base_path = '{{base_path}}'

{% for name, value in schemas.items() %}
{{ name }} = {{ value }}
{%- endfor %}
definitions = {{ definitions }}

validators = {
{%- for name, value in validators.items() %}
Expand All @@ -30,6 +31,7 @@ scopes = {
{%- endfor %}
}

resolver = RefResolver.from_schema(definitions)

class Security(object):

Expand Down
10 changes: 5 additions & 5 deletions swagger_py_codegen/templates/sanic/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from sanic.request import RequestParameters
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, security, base_path, normalize, current)
validators, filters, scopes, security, resolver, base_path, normalize, current)


def unpack(value):
Expand Down Expand Up @@ -63,7 +63,7 @@ class JSONEncoder(json.JSONEncoder):
class SanicValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema)
self.validator = Draft4Validator(schema, resolver=resolver)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -106,7 +106,7 @@ class SanicValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value)[0], errors
return normalize(self.validator.schema, value, resolver=resolver)[0], errors


def request_validate(view):
Expand Down Expand Up @@ -175,10 +175,10 @@ def response_filter(view):
# return resp, status, headers
raise ServerError('`%d` is not a defined status code.' % status, 500)

resp, errors = normalize(schemas['schema'], resp)
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers)
{'properties': schemas['headers']}, headers, resolver=resolver)
errors.extend(header_errors)
if errors:
raise ServerError('Expectation Failed', 500)
Expand Down
12 changes: 6 additions & 6 deletions swagger_py_codegen/templates/tornado/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import six
from functools import wraps
from jsonschema import Draft4Validator

from .schemas import validators, scopes, normalize, filters
from .schemas import validators, scopes, resolver, normalize, filters


class ValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema)
self.validator = Draft4Validator(schema, resolver=resolver)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -66,7 +66,7 @@ class ValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value)[0], errors
return normalize(self.validator.schema, value, resolver=resolver)[0], errors

def request_validate(obj):
def _request_validate(view):
Expand Down Expand Up @@ -134,10 +134,10 @@ def response_filter(obj):
raise tornado.web.HTTPError(
500, message='`%d` is not a defined status code.' % status)

resp, errors = normalize(schemas['schema'], resp)
resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers)
{'properties': schemas['headers']}, headers, resolver=resolver)
errors.extend(header_errors)
if errors:
raise tornado.web.HTTPError(
Expand Down Expand Up @@ -167,4 +167,4 @@ def unpack(value):
except ValueError:
pass

return value, 200, {}
return value, 200, {}
Loading

0 comments on commit 278b546

Please sign in to comment.