Skip to content

Commit

Permalink
Revert "resolver"
Browse files Browse the repository at this point in the history
  • Loading branch information
gusibi authored Aug 29, 2018
1 parent fd02724 commit c309463
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 104 deletions.
30 changes: 10 additions & 20 deletions swagger_py_codegen/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import getsource

from .base import Code, CodeGenerator
from .parser import RefNode
from .parser import schema_var_name


class Schema(Code):
Expand Down Expand Up @@ -89,8 +89,10 @@ def build_data(swagger):
scopes[(endpoint, method)] = list(security.values()).pop()
break

schemas = OrderedDict([(schema_var_name(path), swagger.get(path)) for path in swagger.definitions])

data = dict(
definitions={'definitions':swagger.origin_data.get('definitions', {})},
schemas=schemas,
validators=validators,
filters=filters,
scopes=scopes,
Expand All @@ -107,7 +109,7 @@ def _process(self):
yield Schema(build_data(self.swagger))


def merge_default(schema, value, get_first=True, resolver=None):
def merge_default(schema, value, get_first=True):
# TODO: more types support
type_defaults = {
'integer': 9573,
Expand All @@ -117,17 +119,17 @@ def merge_default(schema, value, get_first=True, resolver=None):
'boolean': False
}

results = normalize(schema, value, type_defaults, resolver=resolver)
results = normalize(schema, value, type_defaults)
if get_first:
return results[0]
return results


def build_default(schema, resolver=None):
return merge_default(schema, None, resolver=resolver)
def build_default(schema):
return merge_default(schema, None)


def normalize(schema, data, required_defaults=None, resolver=None):
def normalize(schema, data, required_defaults=None):
if required_defaults is None:
required_defaults = {}
errors = []
Expand Down Expand Up @@ -215,7 +217,7 @@ def _normalize_dict(schema, data):

def _normalize_list(schema, data):
result = []
if hasattr(data, '__iter__') and not isinstance(data, (dict, RefNode)):
if hasattr(data, '__iter__') and not isinstance(data, dict):
for item in data:
result.append(_normalize(schema.get('items'), item))
elif 'default' in schema:
Expand All @@ -228,15 +230,6 @@ def _normalize_default(schema, data):
else:
return data

def _normalize_ref(schema, data):
if resolver == None:
raise TypeError("resolver must be provided")
ref = schema.get(u"$ref")
scope, resolved = resolver.resolve(ref)
return _normalize(resolved, data)



def _normalize(schema, data):
if schema is True or schema == {}:
return data
Expand All @@ -246,13 +239,10 @@ def _normalize(schema, data):
'object': _normalize_dict,
'array': _normalize_list,
'default': _normalize_default,
'ref': _normalize_ref
}
type_ = schema.get('type', 'object')
if type_ not in funcs:
type_ = 'default'
if schema.get(u'$ref', None):
type_ = 'ref'

return funcs[type_](schema, data)

Expand Down
50 changes: 14 additions & 36 deletions swagger_py_codegen/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,14 @@ def schema_var_name(path):
return ''.join(map(str.capitalize, map(str, path)))


class RefNode(object):
class RefNode(dict):

def __init__(self, data, ref):
self.ref = ref
self._data = data

def __getitem__(self, key):
return self._data.__getitem__(key)

def __settiem__(self, key, value):
return self._data.__settiem__(key, value)

def get(self, key, default=None):
return self._data.get(key, default)

def has(self, key, default=None):
return self._data.has(key)

def keys(self):
return self._data.keys()

def __iter__(self):
return self._data.__iter__()
super(RefNode, self).__init__(data)

def __repr__(self):
return repr({'$ref':self.ref})

def __eq__(self, other):
if isinstance(other, RefNode):
return self._data == other._data and self.ref == other.ref
else:
return object.__eq__(other)
return schema_var_name(self.ref)


class Swagger(object):
Expand All @@ -64,10 +40,14 @@ def _process_ref(self):
"""
resolve all references util no reference exists
"""
for path, ref in self.search(['**', '$ref']):
data = resolve(self.data, ref)
path = path[:-1]
self.set(path, RefNode(data, ref))
while 1:
li = list(self.search(['**', '$ref']))
if not li:
break
for path, ref in li:
data = resolve(self.data, ref)
path = path[:-1]
self.set(path, data)

def _resolve_definitions(self):
"""
Expand Down Expand Up @@ -96,19 +76,17 @@ def get_definition_refs():
while definition_refs:
ready = {
definition for definition, refs
in six.iteritems(definition_refs)
in six.iteritems(definition_refs) if not refs
}
if not ready:
continue
#msg = '$ref circular references found!\n'
#raise ValueError(msg)
msg = '$ref circular references found!\n'
raise ValueError(msg)
for definition in ready:
del definition_refs[definition]
for refs in six.itervalues(definition_refs):
refs.difference_update(ready)

self._definitions += ready
self._definitions.sort(key=lambda x :x[1])

def search(self, path):
for p, d in dpath.util.search(
Expand Down
12 changes: 6 additions & 6 deletions swagger_py_codegen/templates/falcon/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ from werkzeug.datastructures import MultiDict, Headers
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, resolver, security, base_path, normalize)
validators, filters, scopes, security, base_path, normalize)


if six.PY3:
Expand Down Expand Up @@ -44,7 +44,7 @@ class JSONEncoder(json.JSONEncoder):
class FalconValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema, resolver=resolver)
self.validator = Draft4Validator(schema)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -87,7 +87,7 @@ class FalconValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = {e.path[0]: e.message for e in self.validator.iter_errors(value)}
return normalize(self.validator.schema, value, resolver=resolver)[0], errors
return normalize(self.validator.schema, value)[0], errors


def request_validate(req, resp, resource, params):
Expand Down Expand Up @@ -154,15 +154,15 @@ def response_filter(req, resp, resource):
'Not defined',
description='`%d` is not a defined status code.' % status)

_resp, errors = normalize(schemas['schema'], req.context['result'], resolver=resolver)
_resp, errors = normalize(schemas['schema'], req.context['result'])
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers, resolver=resolver)
{'properties': schemas['headers']}, headers)
errors.extend(header_errors)
if errors:
raise falcon.HTTPInternalServerError(title='Expectation Failed',
description=errors)

if 'result' not in req.context:
return
resp.body = json.dumps(_resp)
resp.body = json.dumps(_resp)
10 changes: 5 additions & 5 deletions swagger_py_codegen/templates/flask/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from flask_restful.utils import unpack
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, resolver, security, merge_default, normalize)
validators, filters, scopes, security, merge_default, normalize)


class JSONEncoder(json.JSONEncoder):
Expand All @@ -29,7 +29,7 @@ class JSONEncoder(json.JSONEncoder):
class FlaskValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema, resolver=resolver)
self.validator = Draft4Validator(schema)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -72,7 +72,7 @@ class FlaskValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value, resolver=resolver)[0], errors
return normalize(self.validator.schema, value)[0], errors


def request_validate(view):
Expand Down Expand Up @@ -136,10 +136,10 @@ def response_filter(view):
# return resp, status, headers
abort(500, message='`%d` is not a defined status code.' % status)

resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
resp, errors = normalize(schemas['schema'], resp)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers, resolver=resolver)
{'properties': schemas['headers']}, headers)
errors.extend(header_errors)
if errors:
abort(500, message='Expectation Failed', errors=errors)
Expand Down
8 changes: 3 additions & 5 deletions swagger_py_codegen/templates/jsonschema/schemas.tpl
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
# -*- coding: utf-8 -*-

import six
from jsonschema import RefResolver
from swagger_py_codegen.parser import RefNode

# TODO: datetime support


{% include '_do_not_change.tpl' %}

base_path = '{{base_path}}'

definitions = {{ definitions }}
{% for name, value in schemas.items() %}
{{ name }} = {{ value }}
{%- endfor %}

validators = {
{%- for name, value in validators.items() %}
Expand All @@ -31,7 +30,6 @@ scopes = {
{%- endfor %}
}

resolver = RefResolver.from_schema(definitions)

class Security(object):

Expand Down
10 changes: 5 additions & 5 deletions swagger_py_codegen/templates/sanic/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from sanic.request import RequestParameters
from jsonschema import Draft4Validator

from .schemas import (
validators, filters, scopes, security, resolver, base_path, normalize, current)
validators, filters, scopes, security, base_path, normalize, current)


def unpack(value):
Expand Down Expand Up @@ -63,7 +63,7 @@ class JSONEncoder(json.JSONEncoder):
class SanicValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema, resolver=resolver)
self.validator = Draft4Validator(schema)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -106,7 +106,7 @@ class SanicValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value, resolver=resolver)[0], errors
return normalize(self.validator.schema, value)[0], errors


def request_validate(view):
Expand Down Expand Up @@ -175,10 +175,10 @@ def response_filter(view):
# return resp, status, headers
raise ServerError('`%d` is not a defined status code.' % status, 500)

resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
resp, errors = normalize(schemas['schema'], resp)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers, resolver=resolver)
{'properties': schemas['headers']}, headers)
errors.extend(header_errors)
if errors:
raise ServerError('Expectation Failed', 500)
Expand Down
12 changes: 6 additions & 6 deletions swagger_py_codegen/templates/tornado/validators.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import six
from functools import wraps
from jsonschema import Draft4Validator

from .schemas import validators, scopes, resolver, normalize, filters
from .schemas import validators, scopes, normalize, filters


class ValidatorAdaptor(object):

def __init__(self, schema):
self.validator = Draft4Validator(schema, resolver=resolver)
self.validator = Draft4Validator(schema)

def validate_number(self, type_, value):
try:
Expand Down Expand Up @@ -66,7 +66,7 @@ class ValidatorAdaptor(object):
def validate(self, value):
value = self.type_convert(value)
errors = list(e.message for e in self.validator.iter_errors(value))
return normalize(self.validator.schema, value, resolver=resolver)[0], errors
return normalize(self.validator.schema, value)[0], errors

def request_validate(obj):
def _request_validate(view):
Expand Down Expand Up @@ -134,10 +134,10 @@ def response_filter(obj):
raise tornado.web.HTTPError(
500, message='`%d` is not a defined status code.' % status)

resp, errors = normalize(schemas['schema'], resp, resolver=resolver)
resp, errors = normalize(schemas['schema'], resp)
if schemas['headers']:
headers, header_errors = normalize(
{'properties': schemas['headers']}, headers, resolver=resolver)
{'properties': schemas['headers']}, headers)
errors.extend(header_errors)
if errors:
raise tornado.web.HTTPError(
Expand Down Expand Up @@ -167,4 +167,4 @@ def unpack(value):
except ValueError:
pass

return value, 200, {}
return value, 200, {}
Loading

0 comments on commit c309463

Please sign in to comment.