Skip to content

Commit

Permalink
Decouple parameters formatting and endpoint logic (#9405)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj authored Jun 19, 2020
1 parent 416334e commit d7ef352
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 51 deletions.
22 changes: 15 additions & 7 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from sqlalchemy import func

from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection, dagrun_collection_schema, dagrun_schema,
)
from airflow.api_connexion.utils import conn_parse_datetime
from airflow.models import DagRun
from airflow.utils.session import provide_session

Expand All @@ -47,6 +47,14 @@ def get_dag_run(dag_id, dag_run_id, session):
return dagrun_schema.dump(dag_run)


@format_parameters({
'start_date_gte': format_datetime,
'start_date_lte': format_datetime,
'execution_date_gte': format_datetime,
'execution_date_lte': format_datetime,
'end_date_gte': format_datetime,
'end_date_lte': format_datetime,
})
@provide_session
def get_dag_runs(session, dag_id, start_date_gte=None, start_date_lte=None,
execution_date_gte=None, execution_date_lte=None,
Expand All @@ -63,24 +71,24 @@ def get_dag_runs(session, dag_id, start_date_gte=None, start_date_lte=None,

# filter start date
if start_date_gte:
query = query.filter(DagRun.start_date >= conn_parse_datetime(start_date_gte))
query = query.filter(DagRun.start_date >= start_date_gte)

if start_date_lte:
query = query.filter(DagRun.start_date <= conn_parse_datetime(start_date_lte))
query = query.filter(DagRun.start_date <= start_date_lte)

# filter execution date
if execution_date_gte:
query = query.filter(DagRun.execution_date >= conn_parse_datetime(execution_date_gte))
query = query.filter(DagRun.execution_date >= execution_date_gte)

if execution_date_lte:
query = query.filter(DagRun.execution_date <= conn_parse_datetime(execution_date_lte))
query = query.filter(DagRun.execution_date <= execution_date_lte)

# filter end date
if end_date_gte:
query = query.filter(DagRun.end_date >= conn_parse_datetime(end_date_gte))
query = query.filter(DagRun.end_date >= end_date_gte)

if end_date_lte:
query = query.filter(DagRun.end_date <= conn_parse_datetime(end_date_lte))
query = query.filter(DagRun.end_date <= end_date_lte)

# apply offset and limit
dag_run = query.order_by(DagRun.id).offset(offset).limit(limit).all()
Expand Down
45 changes: 44 additions & 1 deletion airflow/api_connexion/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,54 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import wraps
from typing import Callable, Dict

# Pagination parameters
from pendulum.parsing import ParserError

from airflow.api_connexion.exceptions import BadRequest
from airflow.utils import timezone

# Page parameters
page_offset = "offset"
page_limit = "limit"

# Database entity fields
dag_id = "dag_id"
pool_id = "pool_id"


def format_datetime(value: str):
"""
Datetime format parser for args since connexion doesn't parse datetimes
https://github.com/zalando/connexion/issues/476
This should only be used within connection views because it raises 400
"""
if value[-1] != 'Z':
value = value.replace(" ", '+')
try:
return timezone.parse(value)
except (ParserError, TypeError) as err:
raise BadRequest(
"Incorrect datetime argument", detail=str(err)
)


def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
"""
Decorator factory that create decorator that convert parameters using given formatters.
Using it allows you to separate parameter formatting from endpoint logic.
:param params_formatters: Map of key name and formatter function
"""
def format_parameters_decorator(func):
@wraps(func)
def wrapped_function(*args, **kwargs):
for key, formatter in params_formatters.items():
if key in kwargs:
kwargs[key] = formatter(kwargs[key])
return func(*args, **kwargs)
return wrapped_function
return format_parameters_decorator
39 changes: 0 additions & 39 deletions airflow/api_connexion/utils.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
# under the License.

import unittest
from unittest import mock

from pendulum import DateTime
from pendulum.tz.timezone import Timezone

from airflow.api_connexion.exceptions import BadRequest
from airflow.api_connexion.utils import conn_parse_datetime
from airflow.api_connexion.parameters import format_datetime, format_parameters
from airflow.utils import timezone


Expand All @@ -29,18 +33,37 @@ def setUp(self) -> None:
self.default_time_2 = '2020-06-13T22:44:00Z'

def test_works_with_datestring_ending_00_00(self):
datetime = conn_parse_datetime(self.default_time)
datetime = format_datetime(self.default_time)
datetime2 = timezone.parse(self.default_time)
assert datetime == datetime2
assert datetime.isoformat() == self.default_time

def test_works_with_datestring_ending_with_zed(self):
datetime = conn_parse_datetime(self.default_time_2)
datetime = format_datetime(self.default_time_2)
datetime2 = timezone.parse(self.default_time_2)
assert datetime == datetime2
assert datetime.isoformat() == self.default_time # python uses +00:00 instead of Z

def test_raises_400_for_invalid_arg(self):
invalid_datetime = '2020-06-13T22:44:00P'
with self.assertRaises(BadRequest):
conn_parse_datetime(invalid_datetime)
format_datetime(invalid_datetime)


class TestFormatParameters(unittest.TestCase):

def test_should_works_with_datetime_formatter(self):
decorator = format_parameters({"param_a": format_datetime})
endpoint = mock.MagicMock()
decorated_endpoint = decorator(endpoint)

decorated_endpoint(param_a='2020-01-01T0:0:00+00:00')

endpoint.assert_called_once_with(param_a=DateTime(2020, 1, 1, 0, tzinfo=Timezone('UTC')))

def test_should_propagate_exceptions(self):
decorator = format_parameters({"param_a": format_datetime})
endpoint = mock.MagicMock()
decorated_endpoint = decorator(endpoint)
with self.assertRaises(BadRequest):
decorated_endpoint(param_a='XXXXX')

0 comments on commit d7ef352

Please sign in to comment.