Skip to content

Commit

Permalink
refactor schema validation to handle arrays with object/ref/union ele…
Browse files Browse the repository at this point in the history
…ments

for #3
  • Loading branch information
snarfed committed Sep 26, 2024
1 parent 511098c commit a19cfe4
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 142 deletions.
252 changes: 122 additions & 130 deletions lexrpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
'bytes': bytes,
'array': list,
'object': dict,
# these could be tokens
# 'ref': dict,
# 'union': dict,
}

# https://atproto.com/specs/data-model#blob-type
Expand Down Expand Up @@ -270,159 +273,148 @@ def validate(self, nsid, type, obj):
}

if self._validate:
self._validate_value(obj, nsid, schema)
self._validate_schema(obj, nsid, name=type, schema=schema)

return obj

def _validate_value(self, obj, type_, lexicon):
"""Validates an ATProto object against a lexicon.
def _validate_schema(self, val, type_, name=None, schema=None):
"""Validates an ATProto value against a lexicon schema.
Returns ``None`` if the object validates, otherwise raises an exception.
Returns ``None`` if the value validates, otherwise raises an exception.
https://atproto.com/specs/lexicon
Args:
obj (dict)
type_ (str): name of type, eg ``app.bsky.feed.post#replyRef``
lexicon (dict): should have at least ``properties`` key
val
type_ (str): name of type, eg ``integer`` or ``app.bsky.feed.post#replyRef``
name (str): field name
schema (dict): schema to validate against if this is a compound
object and not a primitive
Raises:
ValidationError: if the object is invalid
ValidationError: if the value is invalid
"""
assert lexicon

def trunc(val):
def fail(msg):
val_str = repr(val)
return val_str if len(val_str) <= 50 else val_str[:50] + '…'
if len(val_str) > 50:
val_str = val_str[:50] + '…'
raise ValidationError(f'{type_} {name} with value {val_str} {msg}')

if lexicon.get('type') == 'token':
if obj != type_:
raise ValidationError(f'got value {trunc(obj)} for type token')
# TODO: anything else to do here?
return

if not isinstance(obj, dict):
raise ValidationError(f'expected object for {type_} property; got {trunc(obj)}')

for name, schema in lexicon.get('properties', {}).items():
if name not in obj:
if name in lexicon.get('required', []):
raise ValidationError(f'{type_} missing required property {name}')
continue
if const := schema.get('const'):
if val != const:
fail(f'is not const value {const}')

prop_type = schema['type']
if enums := schema.get('enum'):
if val not in enums:
fail('is not one of enum values')

def fail(msg):
raise ValidationError(f'{prop_type} property {name} value {trunc(val)} {msg}')
if type_ == 'unknown':
return

val = obj[name]
if val is None:
if prop_type != 'null' and name not in lexicon.get('nullable', []):
fail('is not nullable')
continue
if expected := FIELD_TYPES.get(type_):
if type(val) != expected:
fail(f'has unexpected type {type(val).__name__}')

if prop_type == 'unknown':
if type_ in ('array', 'bytes', 'string'):
min_length = schema.get('minLength')
max_length = schema.get('maxLength')
length = len(val.encode('utf-8') if type_ == 'string' else val)
if max_length and length > max_length:
fail(f'is longer ({length}) than maxLength {max_length}')
elif min_length and length < min_length:
fail(f'is shorter ({length}) than minLength {min_length}')

if type_ == 'string':
if format := schema.get('format'):
try:
self._validate_string_format(val, format)
except ValidationError as e:
fail(e.args[0])

min_graphemes = schema.get('minGraphemes')
max_graphemes = schema.get('maxGraphemes')
if min_graphemes or max_graphemes:
length = grapheme.length(val)
if min_graphemes and length < min_graphemes:
fail(f'is shorter than minGraphemes {min_graphemes}')
if max_graphemes and length > max_graphemes:
fail(f'is longer than maxGraphemes {max_graphemes}')

if minimum := schema.get('minimum'):
if val < minimum:
fail(f'is lower than minimum {minimum}')
if maximum := schema.get('maximum'):
if val > maximum:
fail(f'is higher than maximum {maximum}')

if schema and schema.get('type') == 'token':
if val != type_:
fail(f'is not token {type_}')
elif val not in self.defs:
fail(f'not found')

if type_ in ('ref', 'union'):
if not isinstance(val, (str, dict)):
fail("is invalid")

if type_ == 'ref':
inner_type = schema['ref']
else:
inner_type = val.get('$type') if isinstance(val, dict) else val
if not inner_type:
fail('missing $type')
refs = schema['refs']
if inner_type not in refs:
fail(f"isn't one of {refs}")

# if it's a fragment, fully qualify it
schema = self._get_def(urljoin(type_, inner_type))

if type_ == 'blob':
max_size = schema.get('maxSize')
if max_size and val['size'] > max_size:
fail(f'has size {val["size"]} over maxSize {max_size}')

accept = schema.get('accept')
mime = val['mimeType']
if (accept and mime not in accept and '*/*' not in accept
and (mime.split('/')[0] + '/*') not in accept):
fail(f'has unsupported MIME type {mime}')

if type_ == 'array':
for item in val:
self._validate_schema(item, schema['items']['type'], name=name,
schema=schema['items'])

props = schema.get('properties', {})
if props and not isinstance(val, dict):
fail('should be object')

required = schema.get('required', [])
nullable = schema.get('nullable', [])
for inner_name, inner_schema in props.items():
if inner_name not in val:
if inner_name in required:
fail(f'missing required property {inner_name}')
continue

if prop_type == 'token':
if val not in self.defs:
fail(f'not found')

if prop_type in ('blob', 'object', 'ref', 'union'):
if prop_type == 'object':
inner_type = type_

elif prop_type == 'blob':
inner_type = 'blob'
max_size = schema.get('maxSize')
accept = schema.get('accept')

elif prop_type == 'ref':
inner_type = schema['ref']

elif prop_type == 'union':
inner_type = (val if isinstance(val, str) # token
else val.get('$type'))
if not inner_type:
fail('missing $type')
refs = schema['refs']
if (not isinstance(val, (str, dict))
or isinstance(val, str) and val not in refs):
fail("is invalid")

if prop_type != 'object':
# if it's a fragment, fully qualify it
inner_type = urljoin(type_, inner_type)
schema = self._get_def(inner_type)

if schema.get('type') == 'token':
if val != inner_type:
fail('is not token value')
elif not isinstance(val, dict):
fail('is invalid')

self._validate_value(val, inner_type, schema)

if prop_type == 'blob':
if max_size and val['size'] > max_size:
fail(f'has size {val["size"]} over maxSize {max_size}')

mime = val['mimeType']
if (accept and mime not in accept and '*/*' not in accept
and (mime.split('/')[0] + '/*') not in accept):
fail(f'has unsupported MIME type {mime}')

inner_type = inner_schema['type']
inner_val = val[inner_name]
if inner_val is None:
if inner_type != 'null' and inner_name not in nullable:
fail(f'property {inner_name} is not nullable')
continue

if type(val) is not FIELD_TYPES[prop_type]:
fail(f'has unexpected type {type(val).__name__}')
if inner_type == 'ref':
# if it's a fragment, fully qualify it
inner_type = urljoin(type_, inner_schema['ref'])
inner_schema = self._get_def(inner_type)
# TODO: union

if minimum := schema.get('minimum'):
if val < minimum:
fail(f'is less than minimum {minimum}')
if maximum := schema.get('maximum'):
if val > maximum:
fail(f'is longer than maximum {maximum}')

if prop_type in ('array', 'bytes', 'string'):
min_length = schema.get('minLength')
max_length = schema.get('maxLength')
length = len(val.encode('utf-8') if prop_type == 'string' else val)
if max_length and length > max_length:
fail(f'is longer ({length}) than maxLength {max_length}')
elif min_length and length < min_length:
fail(f'is shorter ({length}) than minLength {min_length}')

if prop_type == 'string':
if format := schema.get('format'):
try:
self._validate_string_format(val, format)
except ValidationError as e:
fail(e.args[0])

min_graphemes = schema.get('minGraphemes')
max_graphemes = schema.get('maxGraphemes')
if min_graphemes or max_graphemes:
length = grapheme.length(val)
if min_graphemes and length < min_graphemes:
fail(f'is shorter than minGraphemes {min_graphemes}')
if max_graphemes and length > max_graphemes:
fail(f'is longer than maxGraphemes {max_graphemes}')

if prop_type == 'array':
for item in val:
if type(item) is not FIELD_TYPES[schema['items']['type']]:
fail(f'has element {trunc(item)} with invalid type {type(item).__name__}')

if enums := schema.get('enum'):
if val not in enums:
fail('is not one of enum values')

if const := schema.get('const'):
if val != const:
fail(f'is not const value {const}')
self._validate_schema(inner_val, inner_type, name=inner_name,
schema=inner_schema)

return obj
return val

def _validate_string_format(self, val, format):
"""Validates an ATProto string value against a format.
Expand Down
2 changes: 1 addition & 1 deletion lexrpc/tests/lexicons.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@
'main': {
'type': 'record',
'record': {
'required': ['bar'],
'properties': {
'foo': {
'type': 'array',
'items': {
'type': 'object',
'required': ['bar'],
'properties': {
'bar': {'type': 'integer'},
'baj': {'type': 'string'},
Expand Down
27 changes: 16 additions & 11 deletions lexrpc/tests/test_flask_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,37 +227,42 @@ def subscribe(**kwargs):
def test_procedure_missing_input(self):
resp = self.client.post('/xrpc/io.example.procedure')
self.assertEqual(400, resp.status_code)
self.assertEqual('io.example.procedure missing required property foo',
resp.json['message'])
self.assertEqual(
'io.example.procedure input with value {} missing required property foo',
resp.json['message'])

resp = self.client.post('/xrpc/io.example.procedure', json={'bar': 3})
self.assertEqual(400, resp.status_code)
self.assertEqual('io.example.procedure missing required property foo',
resp.json['message'])
self.assertEqual(
"io.example.procedure input with value {'bar': 3} missing required property foo",
resp.json['message'])

def test_procedure_bad_input(self):
resp = self.client.post('/xrpc/io.example.procedure',
json={'foo': 2, 'bar': 3})
self.assertEqual(400, resp.status_code)
self.assertEqual('string property foo value 2 has unexpected type int',
self.assertEqual('string foo with value 2 has unexpected type int',
resp.json['message'])

def test_query_bad_output(self):
resp = self.client.get('/xrpc/io.example.query?foo=abc')
self.assertEqual(400, resp.status_code)
self.assertEqual('string property foo value None is not nullable',
resp.json['message'])
self.assertEqual(
"io.example.query output with value {'foo': None, 'bar': 5} property foo is not nullable",
resp.json['message'])

def test_missing_params(self):
resp = self.client.post('/xrpc/io.example.params')
self.assertEqual(400, resp.status_code)
self.assertEqual('io.example.params missing required property bar',
resp.json['message'])
self.assertEqual(
'io.example.params parameters with value {} missing required property bar',
resp.json['message'])

resp = self.client.post('/xrpc/io.example.params?foo=a')
self.assertEqual(400, resp.status_code)
self.assertEqual('io.example.params missing required property bar',
resp.json['message'])
self.assertEqual(
"io.example.params parameters with value {'foo': 'a'} missing required property bar",
resp.json['message'])

def test_raises_valueerror(self):
@server.method('io.example.valueError')
Expand Down

0 comments on commit a19cfe4

Please sign in to comment.