48 changes: 46 additions & 2 deletions airflow/bin/cli.py
@@ -37,7 +37,7 @@
import psutil

import airflow
from airflow import jobs, settings
from airflow import jobs, settings, utils
from airflow import configuration as conf
from airflow.exceptions import AirflowException
from airflow.executors import DEFAULT_EXECUTOR
@@ -503,6 +503,44 @@ def clear(args):
        confirm_prompt=not args.no_confirm,
        include_subdags=not args.exclude_subdags)

def mark_success(args):
    dag = get_dag(args)
    runs = dag.get_dagruns(start_date=args.start_date,
                           end_date=args.end_date,
                           include_subdags=not args.exclude_subdags)

    tis = []
    for run in runs:
        tis += run.mark_success(task_regex=args.task_regex,
                                include_downstream=args.downstream,
                                include_upstream=args.upstream,
                                dry_run=True)

    if len(tis) == 0:
        print("No task instances to mark as successful")
        return
    if len(tis) > 1000:
        print("Too many tasks (>1000)")
        return

    do_it = True
    if not args.no_confirm:
        ti_list = "\n".join([str(t) for t in tis])
        count = len(tis)
        question = (
            "You are about to mark these {count} tasks success:\n"
            "{ti_list}\n\n"
            "Are you sure? (yes/no): ").format(**locals())
        do_it = utils.helpers.ask_yesno(question)
    if do_it:
        count = 0
        for run in runs:
            count += run.mark_success(task_regex=args.task_regex,
                                      include_downstream=args.downstream,
                                      include_upstream=args.upstream)
        print("{} task instances have been marked success".format(count))
    else:
        print("Bail. Nothing was marked success.")

def restart_workers(gunicorn_master_proc, num_workers_expected):
    """
@@ -942,7 +980,7 @@ class CLIFactory(object):
        # list_dags
        'report': Arg(
            ("-r", "--report"), "Show DagBag loading report", "store_true"),
        # clear
        # clear, mark_success
        'upstream': Arg(
            ("-u", "--upstream"), "Include upstream tasks", "store_true"),
        'only_failed': Arg(
@@ -1132,6 +1170,12 @@ class CLIFactory(object):
            'dag_id', 'task_regex', 'start_date', 'end_date', 'subdir',
            'upstream', 'downstream', 'no_confirm', 'only_failed',
            'only_running', 'exclude_subdags'),
    }, {
        'func': mark_success,
        'help': "Mark success a set of task instances",
        'args': (
            'dag_id', 'task_regex', 'start_date', 'end_date', 'subdir',
            'upstream', 'downstream', 'no_confirm', 'exclude_subdags'),
    }, {
        'func': pause,
        'help': "Pause a DAG",
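With the handler and its CLIFactory registration in place, the new subcommand should be callable in the same style as `airflow clear`. Assuming the shared Arg definitions keep their existing short flags (`-t` for the task regex, `-s`/`-e` for the start and end dates), a hypothetical invocation would look like `airflow mark_success example_dag -t 'stage_.*' -s 2016-01-01 -e 2016-01-07`; it previews the matching task instances and asks for confirmation unless the `no_confirm` flag is set.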
82 changes: 77 additions & 5 deletions airflow/models.py
@@ -1059,15 +1059,15 @@ def evaluate_trigger_rule(self, successes, skipped, failed,
        path to add the feature
        :type flag_upstream_failed: boolean
        :param successes: Number of successful upstream tasks
        :type successes: boolean
        :type successes: int
        :param skipped: Number of skipped upstream tasks
        :type skipped: boolean
        :type skipped: int
        :param failed: Number of failed upstream tasks
        :type failed: boolean
        :type failed: int
        :param upstream_failed: Number of upstream_failed upstream tasks
        :type upstream_failed: boolean
        :type upstream_failed: int
        :param done: Number of completed upstream tasks
        :type done: boolean
        :type done: int
        """
        TR = TriggerRule

@@ -3365,6 +3365,30 @@ def deactivate_stale_dags(expiration_date, session=None):
            session.merge(dag)
            session.commit()

    @provide_session
    def get_dagruns(self, start_date=None, end_date=None, session=None, include_subdags=False):
        """
        Return a list of dag runs from this dag with execution dates between
        start_date and end_date (both inclusive).
        """
        DR = DagRun
        dr = session.query(DR).filter(DR.dag_id == self.dag_id)
        if start_date:
            dr = dr.filter(DR.execution_date >= start_date)
        if end_date:
            dr = dr.filter(DR.execution_date <= end_date)
Contributor: It annoys me that this is inclusive on end_date. If other CLI params behave this way, then OK. If not, would prefer exclusive end_date.

Author: Inclusive end_date has been used so far, and the only CLI use case is airflow clear, so it seems OK to change.

Author: Decided to keep it as is for now, since the airflow clear CLI is probably used in prod.

        dagruns = dr.all()
        for run in dagruns:
            run.dag = self

        if include_subdags:
            # recursively get all subdags' dagruns
            for subdag in self.subdags:
                dagruns += subdag.get_dagruns(start_date=start_date,
                                              end_date=end_date,
                                              include_subdags=include_subdags)
        return dagruns
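A minimal usage sketch of the new helper; the DAG id and date window are placeholders, and `DagBag` is the usual way to obtain a `DAG` object:

```python
from datetime import datetime

from airflow.models import DagBag

# 'example_dag' is a placeholder id; DagBag loads DAGs from the
# configured dags folder.
dag = DagBag().get_dag('example_dag')

# Runs whose execution_date falls in the inclusive window, plus the
# matching runs of any subdags when include_subdags=True.
runs = dag.get_dagruns(start_date=datetime(2016, 1, 1),
                       end_date=datetime(2016, 1, 7),
                       include_subdags=True)
```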


class Chart(Base):
    __tablename__ = "chart"
@@ -3919,6 +3943,54 @@ def is_backfill(self):

        return False

    @provide_session
    def mark_success(self,
                     task_regex=None,
                     include_downstream=False,
                     include_upstream=True,
                     dry_run=False,
                     session=None):
        """
        Mark as successful the task instances associated with this dagrun,
        optionally restricted to tasks matching a regex pattern.

        :param task_regex: regex pattern that task ids are matched against
        :type task_regex: string
        :param include_downstream: set to true to include downstream tasks of matched tasks
        :type include_downstream: boolean
        :param include_upstream: set to true to include upstream tasks of matched tasks
        :type include_upstream: boolean
        :param dry_run: set to true to return the TIs without actually marking them as success
        :type dry_run: boolean
        :param session: database session
        :type session: Session
        """
        states = State.all()
        states.remove(State.SUCCESS)
        tis = self.get_task_instances(state=states, session=session)
        if task_regex:
            dag = self.get_dag()
            regex_match = [t for t in dag.tasks if re.findall(task_regex, t.task_id)]
            include = []
            for task in regex_match:
                include.append(task.task_id)
                if include_downstream:
                    include += [t.task_id for t in task.get_flat_relatives(upstream=False)]
                if include_upstream:
                    include += [t.task_id for t in task.get_flat_relatives(upstream=True)]
            tis = [ti for ti in tis if ti.task_id in include and ti.state != State.SUCCESS]

        if dry_run:
            return tis

        session.expunge_all()
        for ti in tis:
            ti.state = State.SUCCESS
            session.merge(ti)
        session.commit()
        self.update_state()
        return len(tis)
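The `dry_run` switch is what enables the preview-then-apply flow used by both the CLI handler and the web view. A sketch, continuing from a `runs` list like the one above (the regex is a placeholder):

```python
# Preview: collect the task instances that would be flipped,
# without writing anything to the database.
preview = []
for run in runs:
    preview += run.mark_success(task_regex='stage_.*', dry_run=True)

# Apply: actually set the states; each call returns the number of
# task instances it changed.
count = sum(run.mark_success(task_regex='stage_.*') for run in runs)
```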


class Pool(Base):
    __tablename__ = "slot_pool"
5 changes: 5 additions & 0 deletions airflow/utils/state.py
@@ -92,6 +92,7 @@ def finished(cls):
            cls.SUCCESS,
            cls.SHUTDOWN,
            cls.FAILED,
            cls.UPSTREAM_FAILED,
            cls.SKIPPED,
        ]

@@ -108,3 +109,7 @@ def unfinished(cls):
            cls.RUNNING,
            cls.UP_FOR_RETRY
        ]

    @classmethod
    def all(cls):
        return cls.finished() + cls.unfinished()
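`State.all()` is the building block `DagRun.mark_success` uses to select every state except SUCCESS; the same pattern in isolation:

```python
from airflow.utils.state import State

# Every known state, minus SUCCESS: the task instances that
# still need to be flipped. all() returns a fresh list, so
# remove() does not mutate anything shared.
states = State.all()
states.remove(State.SUCCESS)
```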
112 changes: 23 additions & 89 deletions airflow/www/views.py
@@ -24,6 +24,7 @@
from datetime import datetime, timedelta
import dateutil.parser
import copy
import re
from itertools import chain, product
import json
from lxml import html
@@ -982,7 +983,6 @@ def success(self):
        task_id = request.args.get('task_id')
        origin = request.args.get('origin')
        dag = dagbag.get_dag(dag_id)
        task = dag.get_task(task_id)

        execution_date = request.args.get('execution_date')
        execution_date = dateutil.parser.parse(execution_date)
@@ -992,102 +992,36 @@
        future = request.args.get('future') == "true"
        past = request.args.get('past') == "true"
        recursive = request.args.get('recursive') == "true"
        MAX_PERIODS = 1000

        # Flagging tasks as successful
        session = settings.Session()
        task_ids = [task_id]
        dag_ids = [dag_id]
        task_id_to_dag = {
            task_id: dag
        }
        end_date = ((dag.latest_execution_date or datetime.now())
                    if future else execution_date)

        if 'start_date' in dag.default_args:
            start_date = dag.default_args['start_date']
        elif dag.start_date:
            start_date = dag.start_date
        else:
            start_date = execution_date

        start_date = execution_date if not past else start_date

        if recursive:
            recurse_tasks(task, task_ids, dag_ids, task_id_to_dag)
        start_date = execution_date if not past else None
        end_date = execution_date if not future else None

        if downstream:
            relatives = task.get_flat_relatives(upstream=False)
            task_ids += [t.task_id for t in relatives]
            if recursive:
                recurse_tasks(relatives, task_ids, dag_ids, task_id_to_dag)
        if upstream:
            relatives = task.get_flat_relatives(upstream=True)
            task_ids += [t.task_id for t in relatives]
            if recursive:
                recurse_tasks(relatives, task_ids, dag_ids, task_id_to_dag)
        TI = models.TaskInstance

        if dag.schedule_interval == '@once':
            dates = [start_date]
        else:
            dates = dag.date_range(start_date, end_date=end_date)

        tis = session.query(TI).filter(
            TI.dag_id.in_(dag_ids),
            TI.execution_date.in_(dates),
            TI.task_id.in_(task_ids)).all()
        tis_to_change = session.query(TI).filter(
            TI.dag_id.in_(dag_ids),
            TI.execution_date.in_(dates),
            TI.task_id.in_(task_ids),
            TI.state != State.SUCCESS).all()
        tasks = list(product(task_ids, dates))
        tis_to_create = list(
            set(tasks) -
            set([(ti.task_id, ti.execution_date) for ti in tis]))

        tis_all_altered = list(chain(
            [(ti.task_id, ti.execution_date) for ti in tis_to_change],
            tis_to_create))

        if len(tis_all_altered) > MAX_PERIODS:
            flash("Too many tasks at once (>{0})".format(
                MAX_PERIODS), 'error')
            return redirect(origin)
        runs = dag.get_dagruns(start_date=start_date,
                               end_date=end_date,
                               include_subdags=recursive)

        if confirmed:
            for ti in tis_to_change:
                ti.state = State.SUCCESS
            session.commit()

            for task_id, task_execution_date in tis_to_create:
                ti = TI(
                    task=task_id_to_dag[task_id].get_task(task_id),
                    execution_date=task_execution_date,
                    state=State.SUCCESS)
                session.add(ti)
                session.commit()

            session.commit()
            session.close()
            flash("Marked success on {} task instances".format(
                len(tis_all_altered)))

            count = 0
            for run in runs:
                count += run.mark_success(task_regex=re.escape(task_id),
                                          include_downstream=downstream,
                                          include_upstream=upstream)
            flash("{0} task instances have been marked success".format(count))
            return redirect(origin)
        else:
            if not tis_all_altered:
            tis = []
            for run in runs:
                tis += run.mark_success(task_regex=re.escape(task_id),
                                        include_downstream=downstream,
                                        include_upstream=upstream,
                                        dry_run=True)
            if len(tis) == 0:
                flash("No task instances to mark as successful", 'error')
                response = redirect(origin)
            elif len(tis) > 1000:
                flash("Too many tasks (>1000)", 'error')
                response = redirect(origin)
            else:
                tis = []
                for task_id, task_execution_date in tis_all_altered:
                    tis.append(TI(
                        task=task_id_to_dag[task_id].get_task(task_id),
                        execution_date=task_execution_date,
                        state=State.SUCCESS))
                details = "\n".join([str(t) for t in tis])

                details = "\n".join([str(ti) for ti in tis])
                response = self.render(
                    'airflow/confirm.html',
                    message=(
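One detail worth calling out: the view reuses `mark_success`'s regex parameter to target a single task, so the task id goes through `re.escape` first, ensuring ids containing regex metacharacters match literally. A small illustration:

```python
import re

task_id = 'load.users'        # '.' would otherwise match any character
pattern = re.escape(task_id)  # 'load\\.users'
assert re.findall(pattern, 'load.users') == ['load.users']
assert not re.findall(pattern, 'loadXusers')
```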