Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 13 additions & 83 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from flask import g
from marshmallow import ValidationError
from sqlalchemy import and_, or_, select
from sqlalchemy import or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload

Expand Down Expand Up @@ -48,7 +48,6 @@
from airflow.api_connexion.security import get_readable_dags
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.exceptions import TaskNotFound
from airflow.models import SlaMiss
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH
Expand Down Expand Up @@ -84,27 +83,18 @@ def get_task_instance(
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
.join(TI.dag_run)
.outerjoin(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.execution_date == DR.execution_date,
SlaMiss.task_id == TI.task_id,
),
)
.add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)

try:
task_instance = session.execute(query).one_or_none()
task_instance = session.scalar(query)
except MultipleResultsFound:
raise NotFound(
"Task instance not found", detail="Task instance is mapped, add the map_index value to the URL"
)
if task_instance is None:
raise NotFound("Task instance not found")
if task_instance[0].map_index != -1:
if task_instance.map_index != -1:
raise NotFound(
"Task instance not found", detail="Task instance is mapped, add the map_index value to the URL"
)
Expand All @@ -127,18 +117,9 @@ def get_mapped_task_instance(
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index == map_index)
.join(TI.dag_run)
.outerjoin(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.execution_date == DR.execution_date,
SlaMiss.task_id == TI.task_id,
),
)
.add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)
task_instance = session.execute(query).one_or_none()
task_instance = session.scalar(query)

if task_instance is None:
raise NotFound("Task instance not found")
Expand Down Expand Up @@ -232,28 +213,13 @@ def get_mapped_task_instances(
# Count elements before joining extra columns
total_entries = get_query_count(base_query, session=session)

# Add SLA miss
entry_query = (
base_query.outerjoin(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.task_id == TI.task_id,
SlaMiss.execution_date == DR.execution_date,
),
)
.add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)

try:
order_by_params = _get_order_by_params(order_by)
entry_query = entry_query.order_by(*order_by_params)
entry_query = base_query.order_by(*order_by_params)
except _UnsupportedOrderBy as e:
raise BadRequest(detail=f"Ordering with {e.order_by!r} is not supported")

# using execute because we want the SlaMiss entity. Scalars don't return None for missing entities
task_instances = session.execute(entry_query.offset(offset).limit(limit)).all()
task_instances = session.scalars(entry_query.offset(offset).limit(limit))
return task_instance_collection_schema.dump(
TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries)
)
Expand Down Expand Up @@ -384,28 +350,13 @@ def get_task_instances(
# Count elements before joining extra columns
total_entries = get_query_count(base_query, session=session)

# Add join
entry_query = (
base_query.outerjoin(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.task_id == TI.task_id,
SlaMiss.execution_date == DR.execution_date,
),
)
.add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)

try:
order_by_params = _get_order_by_params(order_by)
entry_query = entry_query.order_by(*order_by_params)
entry_query = base_query.order_by(*order_by_params)
except _UnsupportedOrderBy as e:
raise BadRequest(detail=f"Ordering with {e.order_by!r} is not supported")

# using execute because we want the SlaMiss entity. Scalars don't return None for missing entities
task_instances = session.execute(entry_query.offset(offset).limit(limit)).all()
task_instances = session.scalars(entry_query.offset(offset).limit(limit))
return task_instance_collection_schema.dump(
TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries)
)
Expand Down Expand Up @@ -463,16 +414,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:

# Count elements before joining extra columns
total_entries = get_query_count(base_query, session=session)
# Add join
base_query = base_query.join(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.task_id == TI.task_id,
SlaMiss.execution_date == DR.execution_date,
),
isouter=True,
).add_columns(SlaMiss)

ti_query = base_query.options(
joinedload(TI.rendered_task_instance_fields), joinedload(TI.task_instance_note)
)
Expand All @@ -483,8 +425,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
except _UnsupportedOrderBy as e:
raise BadRequest(detail=f"Ordering with {e.order_by!r} is not supported")

# using execute because we want the SlaMiss entity. Scalars don't return None for missing entities
task_instances = session.execute(ti_query).all()
task_instances = session.scalars(ti_query)

return task_instance_collection_schema.dump(
TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries)
Expand Down Expand Up @@ -690,15 +631,6 @@ def set_task_instance_note(
select(TI)
.where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
.join(TI.dag_run)
.outerjoin(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.execution_date == DR.execution_date,
SlaMiss.task_id == TI.task_id,
),
)
.add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)
if map_index == -1:
Expand All @@ -707,25 +639,23 @@ def set_task_instance_note(
query = query.where(TI.map_index == map_index)

try:
result = session.execute(query).one_or_none()
ti = session.scalar(query)
except MultipleResultsFound:
raise NotFound(
"Task instance not found", detail="Task instance is mapped, add the map_index value to the URL"
)
if result is None:
if ti is None:
error_message = f"Task Instance not found for dag_id={dag_id}, run_id={dag_run_id}, task_id={task_id}"
raise NotFound(error_message)

ti, sla_miss = result

current_user_id = get_auth_manager().get_user_id()
if ti.task_instance_note is None:
ti.note = (new_note, current_user_id)
else:
ti.task_instance_note.content = new_note
ti.task_instance_note.user_id = current_user_id
session.commit()
return task_instance_schema.dump((ti, sla_miss))
return task_instance_schema.dump(ti)


@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE)
Expand Down
38 changes: 0 additions & 38 deletions airflow/api_connexion/schemas/sla_miss_schema.py

This file was deleted.

21 changes: 5 additions & 16 deletions airflow/api_connexion/schemas/task_instance_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, NamedTuple
from typing import NamedTuple

from marshmallow import Schema, ValidationError, fields, validate, validates_schema
from marshmallow.utils import get_value
Expand All @@ -26,16 +26,12 @@
from airflow.api_connexion.schemas.common_schema import JsonObjectField
from airflow.api_connexion.schemas.enum_schemas import TaskInstanceStateField
from airflow.api_connexion.schemas.job_schema import JobSchema
from airflow.api_connexion.schemas.sla_miss_schema import SlaMissSchema
from airflow.api_connexion.schemas.trigger_schema import TriggerSchema
from airflow.models import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.utils.helpers import exactly_one
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.models import SlaMiss


class TaskInstanceSchema(SQLAlchemySchema):
"""Task instance schema."""
Expand Down Expand Up @@ -69,22 +65,15 @@ class Meta:
executor = auto_field()
executor_config = auto_field()
note = auto_field()
sla_miss = fields.Nested(SlaMissSchema, dump_default=None)
rendered_map_index = auto_field()
rendered_fields = JsonObjectField(dump_default={})
trigger = fields.Nested(TriggerSchema)
triggerer_job = fields.Nested(JobSchema)

def get_attribute(self, obj, attr, default):
if attr == "sla_miss":
# Object is a tuple of task_instance and slamiss
# and the get_value expects a dict with key, value
# corresponding to the attr.
slamiss_instance = {"sla_miss": obj[1]}
return get_value(slamiss_instance, attr, default)
elif attr == "rendered_fields":
return get_value(obj[0], "rendered_task_instance_fields.rendered_fields", default)
return get_value(obj[0], attr, default)
if attr == "rendered_fields":
return get_value(obj, "rendered_task_instance_fields.rendered_fields", default)
return get_value(obj, attr, default)


class TaskInstanceHistorySchema(SQLAlchemySchema):
Expand Down Expand Up @@ -122,7 +111,7 @@ class Meta:
class TaskInstanceCollection(NamedTuple):
"""List of task instances with metadata."""

task_instances: list[tuple[TaskInstance, SlaMiss | None]]
task_instances: list[TaskInstance | None]
total_entries: int


Expand Down
1 change: 0 additions & 1 deletion airflow/example_dags/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
# 'priority_weight': 10,
# 'end_date': datetime(2016, 1, 1),
# 'wait_for_downstream': False,
# 'sla': timedelta(hours=2),
# 'execution_timeout': timedelta(seconds=300),
# 'on_failure_callback': some_function, # or list of functions
# 'on_success_callback': some_other_function, # or list of functions
Expand Down
3 changes: 0 additions & 3 deletions airflow/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
"Pool",
"RenderedTaskInstanceFields",
"SkipMixin",
"SlaMiss",
"TaskFail",
"TaskInstance",
"TaskReschedule",
Expand Down Expand Up @@ -104,7 +103,6 @@ def __getattr__(name):
"Pool": "airflow.models.pool",
"RenderedTaskInstanceFields": "airflow.models.renderedtifields",
"SkipMixin": "airflow.models.skipmixin",
"SlaMiss": "airflow.models.slamiss",
"TaskFail": "airflow.models.taskfail",
"TaskInstance": "airflow.models.taskinstance",
"TaskReschedule": "airflow.models.taskreschedule",
Expand Down Expand Up @@ -134,7 +132,6 @@ def __getattr__(name):
from airflow.models.pool import Pool
from airflow.models.renderedtifields import RenderedTaskInstanceFields
from airflow.models.skipmixin import SkipMixin
from airflow.models.slamiss import SlaMiss
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskinstancehistory import TaskInstanceHistory
Expand Down
46 changes: 0 additions & 46 deletions airflow/models/slamiss.py

This file was deleted.

Loading