2 changes: 0 additions & 2 deletions airflow/cli/commands/task_command.py
@@ -169,8 +169,6 @@ def task_run(args, dag=None):

task = dag.get_task(task_id=args.task_id)
ti = TaskInstance(task, args.execution_date)
ti.refresh_from_db()

ti.init_run_context(raw=args.raw)

hostname = get_hostname()
29 changes: 23 additions & 6 deletions airflow/jobs/backfill_job.py
@@ -177,16 +177,26 @@ def __init__(
self.run_backwards = run_backwards
super().__init__(*args, **kwargs)

def _update_counters(self, ti_status):
@provide_session
def _update_counters(self, ti_status, session=None):
"""
Updates the counters per state of the tasks that were running. Can re-add
tasks to the list of tasks to run if required.

:param ti_status: the internal status of the backfill job tasks
:type ti_status: BackfillJob._DagRunTaskStatus
"""
for key, ti in list(ti_status.running.items()):
ti.refresh_from_db()
tis_to_be_scheduled = []
refreshed_tis = []
TI = TaskInstance

filter_for_tis = TI.filter_for_tis(list(ti_status.running.values()))
if filter_for_tis is not None:
refreshed_tis = session.query(TI).filter(filter_for_tis).all()

for ti in refreshed_tis:
# Here we remake the key with try_number - 1 so it matches the in-memory key (try_number was already incremented)
key = (ti.dag_id, ti.task_id, ti.execution_date, max(1, ti.try_number - 1))
if ti.state == State.SUCCESS:
ti_status.succeeded.add(key)
self.log.debug("Task instance %s succeeded. Don't rerun.", ti)
@@ -223,10 +233,17 @@ def _update_counters(self, ti_status):
"reaching concurrency limits. Re-adding task to queue.",
ti
)
ti.set_state(State.SCHEDULED)
tis_to_be_scheduled.append(ti)
ti_status.running.pop(key)
ti_status.to_run[key] = ti

# Batch schedule of task instances
if tis_to_be_scheduled:
filter_for_tis = TI.filter_for_tis(tis_to_be_scheduled)
session.query(TI).filter(filter_for_tis).update(
values={TI.state: State.SCHEDULED}, synchronize_session=False
)

def _manage_executor_state(self, running):
"""
Checks if the executor agrees with the state of task instances
@@ -236,6 +253,7 @@ def _manage_executor_state(self, running):
"""
executor = self.executor

# TODO: query all instead of refresh from db
for key, state in list(executor.get_event_buffer().items()):
if key not in running:
self.log.warning(
@@ -406,7 +424,7 @@ def _process_backfill_task_instances(self,
# waiting for their upstream to finish
@provide_session
def _per_task_process(task, key, ti, session=None):
ti.refresh_from_db()
ti.refresh_from_db(lock_for_update=True, session=session)

task = self.dag.get_task(ti.task_id, include_subdags=True)
ti.task = task
Expand Down Expand Up @@ -470,7 +488,6 @@ def _per_task_process(task, key, ti, session=None):
ignore_task_deps=self.ignore_task_deps,
flag_upstream_failed=True)

ti.refresh_from_db(lock_for_update=True, session=session)
# Is the task runnable? -- then run it
# the dependency checker can change states of tis
if ti.are_dependencies_met(
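_update_counters now refreshes every running task instance with a single SELECT built from TaskInstance.filter_for_tis() and rebuilds the in-memory key with try_number - 1, then batch-schedules the re-queued ones with one UPDATE. Below is a minimal sketch of the bulk-refresh half of that pattern on a toy model (SQLAlchemy 1.4+ assumed; the Task model and its columns are illustrative, not Airflow's):

from sqlalchemy import Column, Integer, String, and_, create_engine, or_
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Task(Base):
    # Toy stand-in for TaskInstance; the composite "key" here is (dag_id, task_id).
    __tablename__ = "task"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String(50))
    task_id = Column(String(50))
    state = Column(String(20))


def refresh_all(session, keys):
    """Re-read all running tasks in one round trip instead of one query per task."""
    if not keys:
        return {}
    # OR together one AND clause per composite key -- the same shape filter_for_tis() builds.
    flt = or_(*[and_(Task.dag_id == d, Task.task_id == t) for d, t in keys])
    return {(row.dag_id, row.task_id): row.state for row in session.query(Task).filter(flt)}


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add_all([Task(dag_id="d", task_id="a", state="success"),
                 Task(dag_id="d", task_id="b", state="running")])
session.commit()
print(refresh_all(session, [("d", "a"), ("d", "b")]))  # {('d', 'a'): 'success', ('d', 'b'): 'running'}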
21 changes: 10 additions & 11 deletions airflow/jobs/base_job.py
@@ -21,7 +21,7 @@
from time import sleep
from typing import Optional

from sqlalchemy import Column, Index, Integer, String, and_, or_
from sqlalchemy import Column, Index, Integer, String, and_
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import make_transient

@@ -280,19 +280,18 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None):
return []

def query(result, items):
filter_for_tis = ([and_(TI.dag_id == ti.dag_id,
TI.task_id == ti.task_id,
TI.execution_date == ti.execution_date)
for ti in items])
reset_tis = (
session
.query(TI)
.filter(or_(*filter_for_tis), TI.state.in_(resettable_states))
.with_for_update()
.all())
if not items:
return result

filter_for_tis = TI.filter_for_tis(items)
reset_tis = session.query(TI).filter(
filter_for_tis, TI.state.in_(resettable_states)
).with_for_update().all()

for ti in reset_tis:
ti.state = State.NONE
session.merge(ti)

return result + reset_tis

reset_tis = helpers.reduce_in_chunks(query,
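In reset_state_for_orphaned_tasks the hand-built list of and_ clauses is replaced by TI.filter_for_tis(items), with an early return for empty chunks, while the surrounding helpers.reduce_in_chunks call still bounds how many keys land in a single OR clause. A rough, self-contained sketch of that chunked-reduce idea follows; it shows the assumed semantics and is not necessarily Airflow's exact helper:

from functools import reduce
from typing import Callable, Iterable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def chunks(seq: Sequence[T], size: int) -> Iterable[Sequence[T]]:
    """Yield consecutive slices of at most `size` elements."""
    for i in range(0, len(seq), size):
        yield seq[i:i + size]


def reduce_in_chunks(fn: Callable[[List[R], Sequence[T]], List[R]],
                     items: Sequence[T],
                     initial: List[R],
                     chunk_size: int) -> List[R]:
    """Fold `fn` over `items` one chunk at a time, accumulating the result."""
    if not items:
        return initial
    return reduce(fn, chunks(items, chunk_size), initial)


# Usage: each call to the reducer would issue one bounded query, e.g.
# session.query(TI).filter(TI.filter_for_tis(chunk), TI.state.in_(...)).with_for_update().all()
result = reduce_in_chunks(lambda acc, chunk: acc + list(chunk), list(range(10)), [], 4)
print(result)  # [0, 1, ..., 9], processed in chunks of 4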
46 changes: 26 additions & 20 deletions airflow/jobs/scheduler_job.py
@@ -677,7 +677,7 @@ def _process_dags(self, dagbag, dags, tis_out):
:param dagbag: a collection of DAGs to process
:type dagbag: airflow.models.DagBag
:param dags: the DAGs from the DagBag to process
:type dags: airflow.models.DAG
:type dags: List[airflow.models.DAG]
:param tis_out: A list to add generated TaskInstance objects
:type tis_out: list[TaskInstance]
:rtype: None
@@ -796,15 +796,22 @@ def process_file(self, file_path, zombies, pickle_dags=False, session=None):
# process and due to some unusual behavior. (empty() incorrectly
# returns true as described in https://bugs.python.org/issue23582 )
ti_keys_to_schedule = []
refreshed_tis = []

self._process_dags(dagbag, dags, ti_keys_to_schedule)

for ti_key in ti_keys_to_schedule:
dag = dagbag.dags[ti_key[0]]
task = dag.get_task(ti_key[1])
ti = models.TaskInstance(task, ti_key[2])
# Refresh all task instances that will be scheduled
TI = models.TaskInstance
filter_for_tis = TI.filter_for_tis(ti_keys_to_schedule)

if filter_for_tis is not None:
refreshed_tis = session.query(TI).filter(filter_for_tis).with_for_update().all()

for ti in refreshed_tis:
# Add task to task instance
dag = dagbag.dags[ti.key[0]]
ti.task = dag.get_task(ti.key[1])

ti.refresh_from_db(session=session, lock_for_update=True)
# We check only deps needed to set TI to SCHEDULED state here.
# Deps needed to set TI to QUEUED state will be batch checked later
# by the scheduler for better performance.
@@ -994,8 +1001,7 @@ def _change_state_for_tis_without_dagrun(self,
models.TaskInstance.task_id == subq.c.task_id,
models.TaskInstance.execution_date ==
subq.c.execution_date)) \
.update({models.TaskInstance.state: new_state},
synchronize_session=False)
.update({models.TaskInstance.state: new_state}, synchronize_session=False)
session.commit()

if tis_changed > 0:
@@ -1270,20 +1276,17 @@ def _change_state_for_executable_task_instances(self, task_instances,
return []

# set TIs to queued state
for task_instance in tis_to_set_to_queued:
task_instance.state = State.QUEUED
task_instance.queued_dttm = timezone.utcnow()
session.merge(task_instance)
filter_for_tis = TI.filter_for_tis(tis_to_set_to_queued)
session.query(TI).filter(filter_for_tis).update(
{TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow()}, synchronize_session=False
)
session.commit()

# Generate a list of SimpleTaskInstance for the use of queuing
# them in the executor.
simple_task_instances = [SimpleTaskInstance(ti) for ti in
tis_to_set_to_queued]
simple_task_instances = [SimpleTaskInstance(ti) for ti in tis_to_set_to_queued]
Review comment (Contributor):

@nuclearpinguin I believe that this should be before session.commit().

Session.commit() should flush the session, and then sqlalchemy will need to requery each TI individually to reconstruct the STI
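For context on the comment above: Session.commit() expires loaded instances by default (expire_on_commit=True), so building SimpleTaskInstance objects after the commit makes each attribute access re-SELECT its row. A small toy-model sketch of that behaviour follows; the names are illustrative, and echo=True makes the extra per-row SELECTs visible:

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Item(Base):
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)
    state = Column(String(20))


engine = create_engine("sqlite://", echo=True)
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()  # expire_on_commit=True is the default

session.add_all([Item(state="scheduled") for _ in range(3)])
session.commit()

items = session.query(Item).all()
snapshot_before = [i.state for i in items]   # no extra queries: attributes are already loaded
session.commit()                             # expires every loaded Item
snapshot_after = [i.state for i in items]    # each access now re-SELECTs its row
print(snapshot_before == snapshot_after)     # True, but the second pass cost one query per row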


task_instance_str = "\n\t".join(
[repr(x) for x in tis_to_set_to_queued])

session.commit()
task_instance_str = "\n\t".join([repr(x) for x in tis_to_set_to_queued])
self.log.info("Setting the following %s tasks to queued state:\n\t%s",
len(tis_to_set_to_queued), task_instance_str)
return simple_task_instances
@@ -1398,9 +1401,12 @@ def _change_state_for_tasks_failed_to_execute(self, session):
return

# set TIs to queued state
filter_for_tis = TI.filter_for_tis(tis_to_set_to_scheduled)
session.query(TI).filter(filter_for_tis).update(
{TI.state: State.SCHEDULED, TI.queued_dttm: None}, synchronize_session=False
)

for task_instance in tis_to_set_to_scheduled:
task_instance.state = State.SCHEDULED
task_instance.queued_dttm = None
self.executor.queued_tasks.pop(task_instance.key)

task_instance_str = "\n\t".join(
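Both scheduler changes above swap a loop of per-object mutation plus session.merge() for a single query-level UPDATE built from filter_for_tis, passing synchronize_session=False because the loaded objects are not relied on afterwards. A hedged sketch of that bulk update on a toy model (illustrative names, not Airflow's):

from sqlalchemy import Column, Integer, String, and_, create_engine, or_
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Run(Base):
    __tablename__ = "run"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String(50))
    task_id = Column(String(50))
    state = Column(String(20))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add_all([Run(dag_id="d", task_id=t, state="scheduled") for t in ("a", "b", "c")])
session.commit()

keys = [("d", "a"), ("d", "b")]
flt = or_(*[and_(Run.dag_id == d, Run.task_id == t) for d, t in keys])

# One UPDATE for every selected row; synchronize_session=False skips the in-Python
# bookkeeping since the loaded objects are not reused before the next query.
session.query(Run).filter(flt).update({Run.state: "queued"}, synchronize_session=False)
session.commit()
print([(r.task_id, r.state) for r in session.query(Run).order_by(Run.task_id)])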
20 changes: 17 additions & 3 deletions airflow/models/dagrun.py
@@ -26,6 +26,7 @@

from airflow.exceptions import AirflowException
from airflow.models.base import ID_LEN, Base
from airflow.models.taskinstance import TaskInstance as TI
from airflow.stats import Stats
from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, DepContext
from airflow.utils import timezone
@@ -336,18 +337,31 @@ def update_state(self, session=None):
return ready_tis

def _get_ready_tis(self, scheduleable_tasks, finished_tasks, session):
old_states = {}
ready_tis = []
changed_tis = False

if not scheduleable_tasks:
return ready_tis, changed_tis

# Check dependencies
for st in scheduleable_tasks:
st_old_state = st.state
old_state = st.state
if st.are_dependencies_met(
dep_context=DepContext(
flag_upstream_failed=True,
finished_tasks=finished_tasks),
session=session):
ready_tis.append(st)
elif st_old_state != st.current_state(session=session):
changed_tis = True
else:
old_states[st.key] = old_state

# Check if any ti changed state
tis_filter = TI.filter_for_tis(old_states.keys())
if tis_filter is not None:
fresh_tis = session.query(TI).filter(tis_filter).all()
changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

return ready_tis, changed_tis

def _are_premature_tis(self, unfinished_tasks, finished_tasks, session):
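_get_ready_tis now remembers the pre-check state of every task whose dependencies are not yet met, refetches just those rows in one query through filter_for_tis, and sets changed_tis if any state moved, instead of issuing one current_state() query per task. A tiny pure-Python sketch of that snapshot-and-compare step, using toy keys and states:

from typing import Dict, List, Tuple

Key = Tuple[str, str]   # toy stand-in for a TaskInstance key


def detect_changes(old_states: Dict[Key, str], fresh_rows: List[Tuple[Key, str]]) -> bool:
    """Return True if any refetched row no longer matches the state we remembered.

    fresh_rows stands in for the result of one
    session.query(TI).filter(TI.filter_for_tis(old_states.keys())).all()
    """
    return any(old_states[key] != state for key, state in fresh_rows)


old = {("dag", "a"): "none", ("dag", "b"): "none"}
fresh = [(("dag", "a"), "none"), (("dag", "b"), "up_for_retry")]
print(detect_changes(old, fresh))  # True: task b changed state behind the scheduler's back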
26 changes: 25 additions & 1 deletion airflow/models/taskinstance.py
@@ -31,9 +31,10 @@
import dill
import lazy_object_proxy
import pendulum
from sqlalchemy import Column, Float, Index, Integer, PickleType, String, func
from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_
from sqlalchemy.orm import reconstructor
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList

from airflow import settings
from airflow.configuration import conf
@@ -1524,6 +1525,29 @@ def init_run_context(self, raw=False):
self.raw = raw
self._set_context(self)

@staticmethod
def filter_for_tis(
tis: Iterable[Union["TaskInstance", TaskInstanceKeyType]]
) -> Optional[BooleanClauseList]:
"""Returns SQLAlchemy filter to query selected task instances"""
TI = TaskInstance
if not tis:
return None
if all(isinstance(t, tuple) for t in tis):
filter_for_tis = ([and_(TI.dag_id == dag_id,
TI.task_id == task_id,
TI.execution_date == execution_date)
for dag_id, task_id, execution_date, _ in tis])
return or_(*filter_for_tis)
if all(isinstance(t, TaskInstance) for t in tis):
filter_for_tis = ([and_(TI.dag_id == ti.dag_id, # type: ignore
TI.task_id == ti.task_id, # type: ignore
TI.execution_date == ti.execution_date) # type: ignore
for ti in tis])
return or_(*filter_for_tis)

raise TypeError("All elements must have the same type: `TaskInstance` or `TaskInstanceKey`.")


# State of the task instance.
# Stores string version of the task state.
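filter_for_tis accepts either TaskInstance objects or (dag_id, task_id, execution_date, try_number) key tuples, returns an or_ of per-key and_ clauses, yields None for an empty input, and raises TypeError on mixed input. A hedged usage sketch follows; the import paths and the provide_session decorator reflect the Airflow tree this change targets and should be treated as assumptions:

from airflow.models import TaskInstance as TI
from airflow.utils.db import provide_session


@provide_session
def refresh_states(tis, session=None):
    """Re-read the given task instances (objects or key tuples) with a single query."""
    filter_for_tis = TI.filter_for_tis(tis)
    if filter_for_tis is None:   # empty input: nothing to query
        return []
    return session.query(TI).filter(filter_for_tis).all()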
33 changes: 29 additions & 4 deletions airflow/utils/sqlalchemy.py
@@ -20,48 +20,73 @@
import json
import logging
import os
import time
import traceback

import pendulum
from dateutil import relativedelta
from sqlalchemy import event, exc
from sqlalchemy.types import DateTime, Text, TypeDecorator

from airflow.configuration import conf

log = logging.getLogger(__name__)

utc = pendulum.timezone('UTC')


def setup_event_handlers(engine):
"""
Sets up event handlers.
"""
# pylint: disable=unused-argument
@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record): # pylint: disable=unused-argument
def connect(dbapi_connection, connection_record):
connection_record.info['pid'] = os.getpid()

if engine.dialect.name == "sqlite":
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record): # pylint: disable=unused-argument
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()

# this ensures sanity in mysql when storing datetimes (not required for postgres)
if engine.dialect.name == "mysql":
@event.listens_for(engine, "connect")
def set_mysql_timezone(dbapi_connection, connection_record): # pylint: disable=unused-argument
def set_mysql_timezone(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("SET time_zone = '+00:00'")
cursor.close()

@event.listens_for(engine, "checkout")
def checkout(dbapi_connection, connection_record, connection_proxy): # pylint: disable=unused-argument
def checkout(dbapi_connection, connection_record, connection_proxy):
pid = os.getpid()
if connection_record.info['pid'] != pid:
connection_record.connection = connection_proxy.connection = None
raise exc.DisconnectionError(
"Connection record belongs to pid {}, "
"attempting to check out in pid {}".format(connection_record.info['pid'], pid)
)
if conf.getboolean('debug', 'sqlalchemy_stats', fallback=False):
@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
conn.info.setdefault('query_start_time', []).append(time.monotonic())

@event.listens_for(engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
total = time.monotonic() - conn.info['query_start_time'].pop()
file_name = [
f"'{f.name}':{f.filename}:{f.lineno}" for f
in traceback.extract_stack() if 'sqlalchemy' not in f.filename][-1]
stack = [f for f in traceback.extract_stack() if 'sqlalchemy' not in f.filename]
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack][-3:])
conn.info.setdefault('query_start_time', []).append(time.monotonic())
log.info("@SQLALCHEMY %s |$ %s |$ %s |$ %s ",
total, file_name, stack_info, statement.replace("\n", " ")
)

# pylint: enable=unused-argument


class UtcDateTime(TypeDecorator):
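The new handlers are registered only when sqlalchemy_stats is enabled under the debug section of the Airflow config, and they time each statement by pushing a start timestamp in before_cursor_execute and popping it in after_cursor_execute. A minimal standalone sketch of the same timing pattern with plain SQLAlchemy (no Airflow config; the logger name and log format are illustrative):

import logging
import time

from sqlalchemy import create_engine, event, text

log = logging.getLogger("sqltimer")
logging.basicConfig(level=logging.INFO)

engine = create_engine("sqlite://")


@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    # Push the start time; a stack copes with nested executions.
    conn.info.setdefault("query_start_time", []).append(time.monotonic())


@event.listens_for(engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    total = time.monotonic() - conn.info["query_start_time"].pop()
    log.info("%.6fs |$ %s", total, statement.replace("\n", " "))


with engine.connect() as conn:
    conn.execute(text("SELECT 1"))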