Skip to content

Commit

Permalink
[Airflow 1332] Split logs based on try number
Browse files Browse the repository at this point in the history
This PR splits logs based on try number and add
tabs to display different task instance tries.

**Note this PR is a temporary change for
separating task attempts. The code in this PR will
be refactored in the future. Please refer to apache#2422
for Airflow logging abstractions redesign.**

Testing:
1. Added unit tests.
2. Tested on localhost.
3. Tested on production environment with S3 remote
storage, MySQL database, Redis, one Airflow
scheduler and two airflow workers.

Closes apache#2383 from AllisonWang/allison--add-task-
attempt
  • Loading branch information
allisonwang authored and aoen committed Jul 21, 2017
1 parent b9576d5 commit b49986c
Show file tree
Hide file tree
Showing 11 changed files with 503 additions and 195 deletions.
81 changes: 43 additions & 38 deletions airflow/bin/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
from sqlalchemy import func
from sqlalchemy.orm import exc


api.load_auth()
api_module = import_module(conf.get('cli', 'api_client'))
api_client = api_module.Client(api_base_url=conf.get('cli', 'endpoint_url'),
Expand Down Expand Up @@ -316,7 +315,7 @@ def run(args, dag=None):
# Load custom airflow config
if args.cfg_path:
with open(args.cfg_path, 'r') as conf_file:
conf_dict = json.load(conf_file)
conf_dict = json.load(conf_file)

if os.path.exists(args.cfg_path):
os.remove(args.cfg_path)
Expand All @@ -327,6 +326,21 @@ def run(args, dag=None):
settings.configure_vars()
settings.configure_orm()

if not args.pickle and not dag:
dag = get_dag(args)
elif not dag:
session = settings.Session()
logging.info('Loading pickle id {args.pickle}'.format(args=args))
dag_pickle = session.query(
DagPickle).filter(DagPickle.id == args.pickle).first()
if not dag_pickle:
raise AirflowException("Who hid the pickle!? [missing pickle]")
dag = dag_pickle.pickle

task = dag.get_task(task_id=args.task_id)
ti = TaskInstance(task, args.execution_date)
ti.refresh_from_db()

logging.root.handlers = []
if args.raw:
# Output to STDOUT for the parent process to read and log
Expand All @@ -350,19 +364,23 @@ def run(args, dag=None):
# writable by both users, then it's possible that re-running a task
# via the UI (or vice versa) results in a permission error as the task
# tries to write to a log file created by the other user.
try_number = ti.try_number
log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args)
log_relative_dir = logging_utils.get_log_directory(args.dag_id, args.task_id,
args.execution_date)
directory = os.path.join(log_base, log_relative_dir)
# Create the log file and give it group writable permissions
# TODO(aoen): Make log dirs and logs globally readable for now since the SubDag
# operator is not compatible with impersonation (e.g. if a Celery executor is used
# for a SubDag operator and the SubDag operator has a different owner than the
# parent DAG)
if not os.path.exists(directory):
if not os.path.isdir(directory):
# Create the directory as globally writable using custom mkdirs
# as os.makedirs doesn't set mode properly.
mkdirs(directory, 0o775)
iso = args.execution_date.isoformat()
filename = "{directory}/{iso}".format(**locals())
log_relative = logging_utils.get_log_filename(
args.dag_id, args.task_id, args.execution_date, try_number)
filename = os.path.join(log_base, log_relative)

if not os.path.exists(filename):
open(filename, "a").close()
Expand All @@ -376,21 +394,6 @@ def run(args, dag=None):
hostname = socket.getfqdn()
logging.info("Running on host {}".format(hostname))

if not args.pickle and not dag:
dag = get_dag(args)
elif not dag:
session = settings.Session()
logging.info('Loading pickle id {args.pickle}'.format(**locals()))
dag_pickle = session.query(
DagPickle).filter(DagPickle.id == args.pickle).first()
if not dag_pickle:
raise AirflowException("Who hid the pickle!? [missing pickle]")
dag = dag_pickle.pickle
task = dag.get_task(task_id=args.task_id)

ti = TaskInstance(task, args.execution_date)
ti.refresh_from_db()

if args.local:
print("Logging into: " + filename)
run_job = jobs.LocalTaskJob(
Expand Down Expand Up @@ -424,8 +427,8 @@ def run(args, dag=None):
session.commit()
pickle_id = pickle.id
print((
'Pickled dag {dag} '
'as pickle_id:{pickle_id}').format(**locals()))
'Pickled dag {dag} '
'as pickle_id:{pickle_id}').format(**locals()))
except Exception as e:
print('Could not pickle the DAG')
print(e)
Expand Down Expand Up @@ -475,7 +478,8 @@ def run(args, dag=None):
with open(filename, 'r') as logfile:
log = logfile.read()

remote_log_location = filename.replace(log_base, remote_base)
remote_log_location = os.path.join(remote_base, log_relative)
logging.debug("Uploading to remote log location {}".format(remote_log_location))
# S3
if remote_base.startswith('s3:/'):
logging_utils.S3Log().write(log, remote_log_location)
Expand Down Expand Up @@ -669,10 +673,10 @@ def start_refresh(gunicorn_master_proc):
gunicorn_master_proc.send_signal(signal.SIGTTIN)
excess += 1
wait_until_true(lambda: num_workers_expected + excess ==
get_num_workers_running(gunicorn_master_proc))
get_num_workers_running(gunicorn_master_proc))

wait_until_true(lambda: num_workers_expected ==
get_num_workers_running(gunicorn_master_proc))
get_num_workers_running(gunicorn_master_proc))

while True:
num_workers_running = get_num_workers_running(gunicorn_master_proc)
Expand All @@ -695,7 +699,7 @@ def start_refresh(gunicorn_master_proc):
gunicorn_master_proc.send_signal(signal.SIGTTOU)
excess -= 1
wait_until_true(lambda: num_workers_expected + excess ==
get_num_workers_running(gunicorn_master_proc))
get_num_workers_running(gunicorn_master_proc))

# Start a new worker by asking gunicorn to increase number of workers
elif num_workers_running == num_workers_expected:
Expand Down Expand Up @@ -887,6 +891,7 @@ def serve_logs(filename): # noqa
filename,
mimetype="application/json",
as_attachment=False)

WORKER_LOG_SERVER_PORT = \
int(conf.get('celery', 'WORKER_LOG_SERVER_PORT'))
flask_app.run(
Expand Down Expand Up @@ -947,8 +952,8 @@ def initdb(args): # noqa
def resetdb(args):
print("DB: " + repr(settings.engine.url))
if args.yes or input(
"This will drop existing tables if they exist. "
"Proceed? (y/n)").upper() == "Y":
"This will drop existing tables if they exist. "
"Proceed? (y/n)").upper() == "Y":
logging.basicConfig(level=settings.LOGGING_LEVEL,
format=settings.SIMPLE_LOG_FORMAT)
db_utils.resetdb()
Expand All @@ -966,7 +971,7 @@ def upgradedb(args): # noqa
if not ds_rows:
qry = (
session.query(DagRun.dag_id, DagRun.state, func.count('*'))
.group_by(DagRun.dag_id, DagRun.state)
.group_by(DagRun.dag_id, DagRun.state)
)
for dag_id, state, count in qry:
session.add(DagStat(dag_id=dag_id, state=state, count=count))
Expand Down Expand Up @@ -1065,8 +1070,8 @@ def connections(args):

session = settings.Session()
if not (session
.query(Connection)
.filter(Connection.conn_id == new_conn.conn_id).first()):
.query(Connection)
.filter(Connection.conn_id == new_conn.conn_id).first()):
session.add(new_conn)
session.commit()
msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
Expand Down Expand Up @@ -1168,16 +1173,16 @@ class CLIFactory(object):
'dry_run': Arg(
("-dr", "--dry_run"), "Perform a dry run", "store_true"),
'pid': Arg(
("--pid", ), "PID file location",
("--pid",), "PID file location",
nargs='?'),
'daemon': Arg(
("-D", "--daemon"), "Daemonize instead of running "
"in the foreground",
"store_true"),
'stderr': Arg(
("--stderr", ), "Redirect stderr to this file"),
("--stderr",), "Redirect stderr to this file"),
'stdout': Arg(
("--stdout", ), "Redirect stdout to this file"),
("--stdout",), "Redirect stdout to this file"),
'log_file': Arg(
("-l", "--log-file"), "Location of the log file"),

Expand Down Expand Up @@ -1333,19 +1338,19 @@ class CLIFactory(object):
"Serialized pickle object of the entire dag (used internally)"),
'job_id': Arg(("-j", "--job_id"), argparse.SUPPRESS),
'cfg_path': Arg(
("--cfg_path", ), "Path to config file to use instead of airflow.cfg"),
("--cfg_path",), "Path to config file to use instead of airflow.cfg"),
# webserver
'port': Arg(
("-p", "--port"),
default=conf.get('webserver', 'WEB_SERVER_PORT'),
type=int,
help="The port on which to run the server"),
'ssl_cert': Arg(
("--ssl_cert", ),
("--ssl_cert",),
default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'),
help="Path to the SSL certificate for the webserver"),
'ssl_key': Arg(
("--ssl_key", ),
("--ssl_key",),
default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'),
help="Path to the key to use with the SSL certificate"),
'workers': Arg(
Expand Down
50 changes: 31 additions & 19 deletions airflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_fernet():
_CONTEXT_MANAGER_DAG = None


def clear_task_instances(tis, session, activate_dag_runs=True):
def clear_task_instances(tis, session, activate_dag_runs=True, dag=None):
"""
Clears a set of task instances, but makes sure the running ones
get killed.
Expand All @@ -119,12 +119,20 @@ def clear_task_instances(tis, session, activate_dag_runs=True):
if ti.job_id:
ti.state = State.SHUTDOWN
job_ids.append(ti.job_id)
# todo: this creates an issue with the webui tests
# elif ti.state != State.REMOVED:
# ti.state = State.NONE
# session.merge(ti)
else:
session.delete(ti)
task_id = ti.task_id
if dag and dag.has_task(task_id):
task = dag.get_task(task_id)
task_retries = task.retries
ti.max_tries = ti.try_number + task_retries
else:
# Ignore errors when updating max_tries if dag is None or
# task not found in dag since database records could be
# outdated. We make max_tries the maximum value of its
# original max_tries or the current task try number.
ti.max_tries = max(ti.max_tries, ti.try_number)
ti.state = State.NONE
session.merge(ti)

if job_ids:
from airflow.jobs import BaseJob as BJ
Expand Down Expand Up @@ -1316,8 +1324,8 @@ def run(
# not 0-indexed lists (i.e. Attempt 1 instead of
# Attempt 0 for the first attempt).
msg = "Starting attempt {attempt} of {total}".format(
attempt=self.try_number % (task.retries + 1) + 1,
total=task.retries + 1)
attempt=self.try_number + 1,
total=self.max_tries + 1)
self.start_date = datetime.now()

dep_context = DepContext(
Expand All @@ -1338,8 +1346,8 @@ def run(
self.state = State.NONE
msg = ("FIXME: Rescheduling due to concurrency limits reached at task "
"runtime. Attempt {attempt} of {total}. State set to NONE.").format(
attempt=self.try_number % (task.retries + 1) + 1,
total=task.retries + 1)
attempt=self.try_number + 1,
total=self.max_tries + 1)
logging.warning(hr + msg + hr)

self.queued_dttm = datetime.now()
Expand Down Expand Up @@ -1486,7 +1494,11 @@ def handle_failure(self, error, test_mode=False, context=None):

# Let's go deeper
try:
if task.retries and self.try_number % (task.retries + 1) != 0:
# try_number is incremented by 1 during task instance run. So the
# current task instance try_number is the try_number for the next
# task instance run. We only mark task instance as FAILED if the
# next task instance try_number exceeds the max_tries.
if task.retries and self.try_number <= self.max_tries:
self.state = State.UP_FOR_RETRY
logging.info('Marking task as UP_FOR_RETRY')
if task.email_on_retry and task.email:
Expand Down Expand Up @@ -1641,15 +1653,17 @@ def email_alert(self, exception, is_retry=False):
task = self.task
title = "Airflow alert: {self}".format(**locals())
exception = str(exception).replace('\n', '<br>')
try_ = task.retries + 1
# For reporting purposes, we report based on 1-indexed,
# not 0-indexed lists (i.e. Try 1 instead of
# Try 0 for the first attempt).
body = (
"Try {self.try_number} out of {try_}<br>"
"Try {try_number} out of {max_tries}<br>"
"Exception:<br>{exception}<br>"
"Log: <a href='{self.log_url}'>Link</a><br>"
"Host: {self.hostname}<br>"
"Log file: {self.log_filepath}<br>"
"Mark success: <a href='{self.mark_success_url}'>Link</a><br>"
).format(**locals())
).format(try_number=self.try_number + 1, max_tries=self.max_tries + 1, **locals())
send_email(task.email, title, body)

def set_duration(self):
Expand Down Expand Up @@ -2382,9 +2396,7 @@ def downstream_list(self):
def downstream_task_ids(self):
return self._downstream_task_ids

def clear(
self, start_date=None, end_date=None,
upstream=False, downstream=False):
def clear(self, start_date=None, end_date=None, upstream=False, downstream=False):
"""
Clears the state of task instances associated with the task, following
the parameters specified.
Expand Down Expand Up @@ -2413,7 +2425,7 @@ def clear(

count = qry.count()

clear_task_instances(qry.all(), session)
clear_task_instances(qry.all(), session, dag=self.dag)

session.commit()
session.close()
Expand Down Expand Up @@ -3244,7 +3256,7 @@ def clear(
do_it = utils.helpers.ask_yesno(question)

if do_it:
clear_task_instances(tis.all(), session)
clear_task_instances(tis.all(), session, dag=self)
if reset_dag_runs:
self.set_dag_runs_state(session=session)
else:
Expand Down
Loading

0 comments on commit b49986c

Please sign in to comment.