Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grid data: do not load all mapped instances #23813

Merged
merged 3 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions airflow/www/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from airflow import models
from airflow.models import errors
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.code_utils import get_python_source
Expand Down Expand Up @@ -127,13 +128,23 @@ def get_mapped_summary(parent_instance, task_instances):
}


def encode_ti(
task_instance: Optional[TaskInstance], is_mapped: Optional[bool], session: Optional[Session]
) -> Optional[Dict[str, Any]]:
def get_task_summary(dag_run: DagRun, task, session: Session) -> Optional[Dict[str, Any]]:
task_instance = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == task.dag_id,
TaskInstance.run_id == dag_run.run_id,
TaskInstance.task_id == task.task_id,
# Only get normal task instances or the first mapped task
TaskInstance.map_index <= 0,
bbovenzi marked this conversation as resolved.
Show resolved Hide resolved
)
.first()
)

if not task_instance:
return None

if is_mapped:
if task_instance.map_index > -1:
return get_mapped_summary(task_instance, task_instances=get_mapped_instances(task_instance, session))

try_count = (
Expand Down
34 changes: 8 additions & 26 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,19 +250,15 @@ def _safe_parse_datetime(v):
abort(400, f"Invalid datetime: {v!r}")


def task_group_to_grid(task_item_or_group, dag, dag_runs, tis, session):
def task_group_to_grid(task_item_or_group, dag, dag_runs, session):
"""
Create a nested dict representation of this TaskGroup and its children used to construct
the Graph.
"""
if isinstance(task_item_or_group, AbstractOperator):
return {
'id': task_item_or_group.task_id,
'instances': [
wwwutils.encode_ti(ti, task_item_or_group.is_mapped, session)
for ti in tis
if ti.task_id == task_item_or_group.task_id
],
'instances': [wwwutils.get_task_summary(dr, task_item_or_group, session) for dr in dag_runs],
'label': task_item_or_group.label,
'extra_links': task_item_or_group.extra_links,
'is_mapped': task_item_or_group.is_mapped,
Expand All @@ -271,21 +267,15 @@ def task_group_to_grid(task_item_or_group, dag, dag_runs, tis, session):
# Task Group
task_group = task_item_or_group

children = [
task_group_to_grid(child, dag, dag_runs, tis, session) for child in task_group.topological_sort()
]
children = [task_group_to_grid(child, dag, dag_runs, session) for child in task_group.topological_sort()]

def get_summary(dag_run, children):
child_instances = [child['instances'] for child in children if 'instances' in child]
child_instances = [item for sublist in child_instances for item in sublist]

children_start_dates = [
item['start_date'] for item in child_instances if item['run_id'] == dag_run.run_id
]
children_end_dates = [
item['end_date'] for item in child_instances if item['run_id'] == dag_run.run_id
]
children_states = [item['state'] for item in child_instances if item['run_id'] == dag_run.run_id]
children_start_dates = [item['start_date'] for item in child_instances if item]
children_end_dates = [item['end_date'] for item in child_instances if item]
children_states = [item['state'] for item in child_instances if item]

group_state = None
for state in wwwutils.priority:
Expand Down Expand Up @@ -2642,12 +2632,8 @@ def grid(self, dag_id, session=None):
else:
external_log_name = None

min_date = min(dag_run_dates, default=None)

tis = dag.get_task_instances(start_date=min_date, end_date=base_date, session=session)

data = {
'groups': task_group_to_grid(dag.task_group, dag, dag_runs, tis, session),
'groups': task_group_to_grid(dag.task_group, dag, dag_runs, session),
'dag_runs': encoded_runs,
}

Expand Down Expand Up @@ -2675,7 +2661,6 @@ def grid(self, dag_id, session=None):
dag_model=dag_model,
auto_refresh_interval=conf.getint('webserver', 'auto_refresh_interval'),
default_dag_run_display_number=default_dag_run_display_number,
task_instances=tis,
filters_drop_down_values=htmlsafe_json_dumps(
{
"taskStates": [state.value for state in TaskInstanceState],
Expand Down Expand Up @@ -3542,11 +3527,8 @@ def grid_data(self):
dag_runs = query.order_by(DagRun.execution_date.desc()).limit(num_runs).all()
dag_runs.reverse()
encoded_runs = [wwwutils.encode_dag_run(dr) for dr in dag_runs]
dag_run_dates = {dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs}
min_date = min(dag_run_dates, default=None)
tis = dag.get_task_instances(start_date=min_date, end_date=base_date, session=session)
data = {
'groups': task_group_to_grid(dag.task_group, dag, dag_runs, tis, session),
'groups': task_group_to_grid(dag.task_group, dag, dag_runs, session),
'dag_runs': encoded_runs,
}

Expand Down