Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dependencies] Fix, support marshmallow 3 #1334

Merged
merged 17 commits into from
Jun 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions docs/rest_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1294,19 +1294,20 @@ And we get an HTTP 422 (Unprocessable Entity).
How to add custom validation? On our next example we only allow
group names that start with a capital "A"::

from marshmallow import Schema, fields, ValidationError, post_load
from flask_appbuilder.api.schemas import BaseModelSchema


def validate_name(n):
if n[0] != 'A':
raise ValidationError('Name must start with an A')

class GroupCustomSchema(Schema):
class GroupCustomSchema(BaseModelSchema):
model_cls = ContactGroup
name = fields.Str(validate=validate_name)

@post_load
def process(self, data):
return ContactGroup(**data)
Note that `BaseModelSchema` extends marshmallow `Schema` class, to support automatic SQLAlchemy model creation and
update, it's a lighter version of marshmallow-sqlalchemy `ModelSchema`. Declare your SQLAlchemy model on `model_cls`
so that a model is created on schema load.

Then on our Api class::

Expand Down
90 changes: 44 additions & 46 deletions flask_appbuilder/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import Dict, Optional
import urllib.parse

from apispec import yaml_utils
from apispec import APISpec, yaml_utils
from apispec.exceptions import DuplicateComponentNameError
from flask import Blueprint, current_app, jsonify, make_response, request, Response
from flask_babel import lazy_gettext as _
import jsonschema
Expand Down Expand Up @@ -484,7 +485,8 @@ def create_blueprint(self, appbuilder, endpoint=None, static_folder=None):
self._register_urls()
return self.blueprint

def add_api_spec(self, api_spec):
def add_api_spec(self, api_spec: APISpec) -> None:
self.add_apispec_components(api_spec)
for attr_name in dir(self):
attr = getattr(self, attr_name)
if hasattr(attr, "_urls"):
Expand All @@ -509,27 +511,17 @@ def add_api_spec(self, api_spec):
self.openapi_spec_tag or self.__class__.__name__
)
api_spec._paths[path][operation]["tags"] = [openapi_spec_tag]
self.add_apispec_components(api_spec)

def add_apispec_components(self, api_spec):
def add_apispec_components(self, api_spec: APISpec) -> None:
for k, v in self.responses.items():
api_spec.components._responses[k] = v
for k, v in self._apispec_parameter_schemas.items():
if k not in api_spec.components._parameters:
_v = {
"in": "query",
"name": API_URI_RIS_KEY,
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/{}".format(k)}
}
},
}
# Using private because parameter method does not behave correctly
api_spec.components._schemas[k] = v
api_spec.components._parameters[k] = _v
try:
api_spec.components.schema(k, v)
except DuplicateComponentNameError:
pass

def _register_urls(self):
def _register_urls(self) -> None:
for attr_name in dir(self):
if (
self.include_route_methods is not None
Expand All @@ -547,9 +539,11 @@ def _register_urls(self):
)
self.blueprint.add_url_rule(url, attr_name, attr, methods=methods)

def path_helper(self, path=None, operations=None, **kwargs):
def path_helper(
self, path: str = None, operations: Dict[str, Dict] = None, **kwargs
) -> str:
"""
Works like a apispec plugin
Works like an apispec plugin
May return a path as string and mutate operations dict.

:param str path: Path to the resource
Expand All @@ -561,7 +555,7 @@ def path_helper(self, path=None, operations=None, **kwargs):
"""
RE_URL = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>")
path = RE_URL.sub(r"{\1}", path)
return "/{}{}".format(self.resource_name, path)
return f"/{self.resource_name}{path}"

def operation_helper(
self, path=None, operations=None, methods=None, func=None, **kwargs
Expand Down Expand Up @@ -1248,7 +1242,12 @@ def info(self, **kwargs):
---
get:
parameters:
- $ref: '#/components/parameters/get_info_schema'
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/get_info_schema'
responses:
200:
description: Item from Model
Expand Down Expand Up @@ -1306,7 +1305,7 @@ def get_headless(self, pk, **kwargs) -> Response:
_show_model_schema = self.show_model_schema

_response["id"] = pk
_response[API_RESULT_RES_KEY] = _show_model_schema.dump(item, many=False).data
_response[API_RESULT_RES_KEY] = _show_model_schema.dump(item, many=False)
self.pre_get(_response)
return self.response(200, **_response)

Expand All @@ -1328,7 +1327,12 @@ def get(self, pk, **kwargs):
schema:
type: integer
name: pk
- $ref: '#/components/parameters/get_item_schema'
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/get_item_schema'
responses:
200:
description: Item from Model
Expand Down Expand Up @@ -1407,7 +1411,7 @@ def get_list_headless(self, **kwargs) -> Response:
select_columns=query_select_columns,
)
pks = self.datamodel.get_keys(lst)
_response[API_RESULT_RES_KEY] = _list_model_schema.dump(lst, many=True).data
_response[API_RESULT_RES_KEY] = _list_model_schema.dump(lst, many=True)
_response["ids"] = pks
_response["count"] = count
self.pre_get_list(_response)
Expand All @@ -1428,7 +1432,12 @@ def get_list(self, **kwargs):
---
get:
parameters:
- $ref: '#/components/parameters/get_list_schema'
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/get_list_schema'
responses:
200:
description: Items from Model
Expand Down Expand Up @@ -1484,19 +1493,15 @@ def post_headless(self) -> Response:
except ValidationError as err:
return self.response_422(message=err.messages)
# This validates custom Schema with custom validations
if isinstance(item.data, dict):
return self.response_422(message=item.errors)
self.pre_add(item.data)
self.pre_add(item)
try:
self.datamodel.add(item.data, raise_exception=True)
self.post_add(item.data)
self.datamodel.add(item, raise_exception=True)
self.post_add(item)
return self.response(
201,
**{
API_RESULT_RES_KEY: self.add_model_schema.dump(
item.data, many=False
).data,
"id": self.datamodel.get_pk_value(item.data),
API_RESULT_RES_KEY: self.add_model_schema.dump(item, many=False),
"id": self.datamodel.get_pk_value(item),
},
)
except IntegrityError as e:
Expand Down Expand Up @@ -1554,20 +1559,13 @@ def put_headless(self, pk) -> Response:
item = self.edit_model_schema.load(data, instance=item)
except ValidationError as err:
return self.response_422(message=err.messages)
# This validates custom Schema with custom validations
if isinstance(item.data, dict):
return self.response_422(message=item.errors)
self.pre_update(item.data)
self.pre_update(item)
try:
self.datamodel.edit(item.data, raise_exception=True)
self.datamodel.edit(item, raise_exception=True)
self.post_update(item)
return self.response(
200,
**{
API_RESULT_RES_KEY: self.edit_model_schema.dump(
item.data, many=False
).data
},
**{API_RESULT_RES_KEY: self.edit_model_schema.dump(item, many=False)},
)
except IntegrityError as e:
return self.response_422(message=str(e.orig))
Expand Down Expand Up @@ -1828,7 +1826,7 @@ def _merge_update_item(self, model_item, data):
:param data: python data structure
:return: python data structure
"""
data_item = self.edit_model_schema.dump(model_item, many=False).data
data_item = self.edit_model_schema.dump(model_item, many=False)
for _col in self.edit_columns:
if _col not in data.keys():
data[_col] = data_item[_col]
Expand Down
14 changes: 11 additions & 3 deletions flask_appbuilder/api/convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from marshmallow import fields
from marshmallow_enum import EnumField
from marshmallow_sqlalchemy import field_for
from marshmallow_sqlalchemy.schema import ModelSchema
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema


class TreeNode:
Expand Down Expand Up @@ -92,19 +92,21 @@ def _meta_schema_factory(self, columns, model, class_mixin):
_model = model
if columns:

class MetaSchema(ModelSchema, class_mixin):
class MetaSchema(SQLAlchemyAutoSchema, class_mixin):
class Meta:
model = _model
fields = columns
strict = True
load_instance = True
sqla_session = self.datamodel.session

else:

class MetaSchema(ModelSchema, class_mixin):
class MetaSchema(SQLAlchemyAutoSchema, class_mixin):
class Meta:
model = _model
strict = True
load_instance = True
sqla_session = self.datamodel.session

return MetaSchema
Expand Down Expand Up @@ -168,12 +170,18 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False)
# is custom property method field?
if hasattr(getattr(_model, column.data), "fget"):
return fields.Raw(dump_only=True)
# its a model function
if hasattr(getattr(_model, column.data), "__call__"):
return fields.Function(getattr(_model, column.data), dump_only=True)
# is a normal model field not a function?
if not hasattr(getattr(_model, column.data), "__call__"):
field = field_for(_model, column.data)
field.unique = datamodel.is_unique(column.data)
if column.data in self.validators_columns:
if field.validate is None:
field.validate = []
field.validate.append(self.validators_columns[column.data])
field.validators.append(self.validators_columns[column.data])
return field

def convert(self, columns, model=None, nested=True, enum_dump_by_name=False):
Expand Down
32 changes: 32 additions & 0 deletions flask_appbuilder/api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from marshmallow import post_load, Schema

from ..const import (
API_ADD_COLUMNS_RIS_KEY,
API_ADD_TITLE_RIS_KEY,
Expand All @@ -20,6 +22,36 @@
API_SHOW_TITLE_RIS_KEY,
)


class BaseModelSchema(Schema):
"""
Extends marshmallow Schema to add functionality similar to marshmallow-sqlalchemy
for creating and updating SQLAlchemy models on load
"""

model_cls = None
"""Declare the SQLAlchemy model when creating a new model on load"""

def __init__(self, *arg, **kwargs):
super().__init__()
self.instance = None

@post_load
def process(self, data, **kwargs):
if self.instance is not None:
for key, value in data.items():
setattr(self.instance, key, value)
return self.instance
return self.model_cls(**data)

def load(self, data, *, instance=None, **kwargs):
self.instance = instance
try:
return super().load(data, **kwargs)
finally:
self.instance = None


get_list_schema = {
"type": "object",
"properties": {
Expand Down
17 changes: 9 additions & 8 deletions flask_appbuilder/tests/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import enum

from flask_appbuilder import Model
from marshmallow import fields, post_load, Schema, ValidationError
from flask_appbuilder.api.schemas import BaseModelSchema
from marshmallow import fields, ValidationError
from sqlalchemy import (
Column,
Date,
Expand Down Expand Up @@ -44,12 +45,12 @@ def validate_field_string(n):
raise ValidationError("Name must start with an A")


class Model1CustomSchema(Schema):
name = fields.Str(validate=validate_name)

@post_load
def process(self, data):
return Model1(**data)
class Model1CustomSchema(BaseModelSchema):
model_cls = Model1
field_string = fields.String(validate=validate_name)
field_integer = fields.Integer(allow_none=True)
field_float = fields.Float(allow_none=True)
field_date = fields.Date(allow_none=True)


class Model2(Model):
Expand All @@ -67,7 +68,7 @@ def __repr__(self):
return str(self.field_string)

def field_method(self):
return "field_method_value"
return f"{self.field_string}_field_method"


class Model3(Model):
Expand Down
Loading