Skip to content

Commit

Permalink
Add read-only endpoints for DAG Model (#9045)
Browse files Browse the repository at this point in the history
Co-authored-by: Tomek Urbaszek <turbaszek@gmail.com>
Co-authored-by: Tomek Urbaszek <tomasz.urbaszek@polidea.com>
GitOrigin-RevId: 8b94ace597f47e350161d799b6b45aad80f45ae4
  • Loading branch information
3 people authored and Cloud Composer Team committed Sep 12, 2024
1 parent 4f370a3 commit 9c9e89e
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 15 deletions.
35 changes: 25 additions & 10 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from flask import current_app
from sqlalchemy import func

from airflow import DAG
from airflow.api_connexion.exceptions import NotFound
# TODO(mik-laj): We have to implement it.
# Do you want to help? Please look at:
# * https://github.com/apache/airflow/issues/8128
# * https://github.com/apache/airflow/issues/8138
from airflow.api_connexion.schemas.dag_schema import dag_detail_schema
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.dag_schema import (
DAGCollection, dag_detail_schema, dag_schema, dags_collection_schema,
)
from airflow.models.dag import DagModel
from airflow.utils.session import provide_session


@provide_session
def get_dag(dag_id, session):
    """
    Get basic information about a DAG.

    :param dag_id: identifier of the DAG to look up
    :param session: database session, injected by the ``provide_session`` decorator
    :return: the matching ``DagModel`` row serialized with ``dag_schema``
    :raises NotFound: if no ``DagModel`` row has the given ``dag_id``
    """
    # one_or_none() yields None for an unknown id instead of raising,
    # which lets us answer with an API-level NotFound below.
    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()

    if dag is None:
        raise NotFound("DAG not found")

    return dag_schema.dump(dag)


def get_dag_details(dag_id):
Expand All @@ -43,11 +50,19 @@ def get_dag_details(dag_id):
return dag_detail_schema.dump(dag)


@format_parameters({
    'limit': check_limit
})
@provide_session
def get_dags(session, limit, offset=0):
    """
    Get all DAGs.

    :param session: database session, injected by the ``provide_session`` decorator
    :param limit: maximum number of DAGs in the returned page; the raw request
        value is passed through ``check_limit`` by ``format_parameters`` first
    :param offset: number of DAGs (in ``dag_id`` order) to skip before the page starts
    :return: ``DAGCollection`` serialized with ``dags_collection_schema``, holding the
        requested page in ``dags`` and the size of the whole table in ``total_entries``
    """
    # Order by dag_id so that offset/limit paging is deterministic.
    dags = session.query(DagModel).order_by(DagModel.dag_id).offset(offset).limit(limit).all()

    # Counted separately, without offset/limit: this is the total, not the page size.
    total_entries = session.query(func.count(DagModel.dag_id)).scalar()

    return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))


def patch_dag():
Expand Down
1 change: 1 addition & 0 deletions airflow/api_connexion/schemas/dag_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,5 @@ class DAGCollectionSchema(Schema):

# Module-level schema instances, created once at import time.
# The DAG endpoint module imports these names to serialize its responses.
dags_collection_schema = DAGCollectionSchema()
dag_schema = DAGSchema()

dag_detail_schema = DAGDetailSchema()
130 changes: 125 additions & 5 deletions tests/api_connexion/endpoints/test_dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
from datetime import datetime

import pytest
from parameterized import parameterized

from airflow import DAG
from airflow.models import DagBag
from airflow.models import DagBag, DagModel
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.session import provide_session
from airflow.www import app
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags

Expand Down Expand Up @@ -58,13 +60,41 @@ def setUp(self) -> None:
def tearDown(self) -> None:
    """Run the shared ``clean_db`` cleanup once a test has finished."""
    self.clean_db()

@provide_session
def _create_dag_models(self, count, session=None):
    """
    Add ``count`` DagModel rows to the session.

    Rows are numbered from 1, i.e. ``TEST_DAG_1`` .. ``TEST_DAG_<count>``, each with
    a matching ``/tmp/dag_<n>.py`` fileloc and the same cron schedule.
    """
    for index in range(count):
        suffix = index + 1
        session.add(
            DagModel(
                dag_id=f"TEST_DAG_{suffix}",
                fileloc=f"/tmp/dag_{suffix}.py",
                schedule_interval="2 2 * * *",
            )
        )


class TestGetDag(TestDagEndpoint):
    """Tests for fetching a single DAG via ``/api/v1/dags/<dag_id>``."""

    def test_should_response_200(self):
        self._create_dag_models(1)
        response = self.client.get("/api/v1/dags/TEST_DAG_1")
        assert response.status_code == 200

        # Compare the untouched payload so that the stored fileloc
        # (set by _create_dag_models) is actually verified.
        self.assertEqual({
            'dag_id': 'TEST_DAG_1',
            'description': None,
            'fileloc': '/tmp/dag_1.py',
            'is_paused': False,
            'is_subdag': False,
            'owners': [],
            'root_dag_id': None,
            'schedule_interval': {'__type': 'CronExpression', 'value': '2 2 * * *'},
            'tags': []
        }, response.json)

    def test_should_response_404(self):
        # No DagModel rows are created, so any dag_id is unknown.
        response = self.client.get("/api/v1/dags/INVALID_DAG")
        assert response.status_code == 404


class TestGetDagDetails(TestDagEndpoint):
def test_should_response_200(self):
Expand Down Expand Up @@ -133,11 +163,101 @@ def test_should_response_200_serialized(self):


class TestGetDags(TestDagEndpoint):
    """Tests for listing DAGs via ``/api/v1/dags``, including pagination."""

    def test_should_response_200(self):
        self._create_dag_models(2)

        response = self.client.get("api/v1/dags")

        assert response.status_code == 200

        self.assertEqual(
            {
                "dags": [
                    {
                        "dag_id": "TEST_DAG_1",
                        "description": None,
                        "fileloc": "/tmp/dag_1.py",
                        "is_paused": False,
                        "is_subdag": False,
                        "owners": [],
                        "root_dag_id": None,
                        "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"},
                        "tags": [],
                    },
                    {
                        "dag_id": "TEST_DAG_2",
                        "description": None,
                        "fileloc": "/tmp/dag_2.py",
                        "is_paused": False,
                        "is_subdag": False,
                        "owners": [],
                        "root_dag_id": None,
                        "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"},
                        "tags": [],
                    },
                ],
                "total_entries": 2,
            },
            response.json,
        )

    # Expected ids follow the string ordering of dag_id, which is why
    # TEST_DAG_10 sorts directly after TEST_DAG_1.
    @parameterized.expand(
        [
            ("api/v1/dags?limit=1", ["TEST_DAG_1"]),
            ("api/v1/dags?limit=2", ["TEST_DAG_1", "TEST_DAG_10"]),
            (
                "api/v1/dags?offset=5",
                [
                    "TEST_DAG_5",
                    "TEST_DAG_6",
                    "TEST_DAG_7",
                    "TEST_DAG_8",
                    "TEST_DAG_9",
                ],
            ),
            (
                "api/v1/dags?offset=0",
                [
                    "TEST_DAG_1",
                    "TEST_DAG_10",
                    "TEST_DAG_2",
                    "TEST_DAG_3",
                    "TEST_DAG_4",
                    "TEST_DAG_5",
                    "TEST_DAG_6",
                    "TEST_DAG_7",
                    "TEST_DAG_8",
                    "TEST_DAG_9",
                ],
            ),
            ("api/v1/dags?limit=1&offset=5", ["TEST_DAG_5"]),
            ("api/v1/dags?limit=1&offset=1", ["TEST_DAG_10"]),
            ("api/v1/dags?limit=2&offset=2", ["TEST_DAG_2", "TEST_DAG_3"]),
        ]
    )
    def test_should_response_200_and_handle_pagination(self, url, expected_dag_ids):
        self._create_dag_models(10)

        response = self.client.get(url)

        assert response.status_code == 200

        dag_ids = [dag["dag_id"] for dag in response.json['dags']]

        self.assertEqual(expected_dag_ids, dag_ids)
        # total_entries reports the whole table, independent of the page requested.
        self.assertEqual(10, response.json['total_entries'])

    def test_should_response_200_default_limit(self):
        # One row more than the expected default page size of 100.
        self._create_dag_models(101)

        response = self.client.get("api/v1/dags")

        assert response.status_code == 200

        self.assertEqual(100, len(response.json['dags']))
        self.assertEqual(101, response.json['total_entries'])


class TestPatchDag(TestDagEndpoint):
@pytest.mark.skip(reason="Not implemented yet")
Expand Down

0 comments on commit 9c9e89e

Please sign in to comment.