diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py index 1b93528ad32b1..2dc08750fed4e 100644 --- a/airflow/sensors/external_task_sensor.py +++ b/airflow/sensors/external_task_sensor.py @@ -104,7 +104,7 @@ def poke(self, context, session=None): if self.execution_delta: dttm = context['execution_date'] - self.execution_delta elif self.execution_date_fn: - dttm = self.execution_date_fn(context['execution_date']) + dttm = self._handle_execution_date_fn(context=context) else: dttm = context['execution_date'] @@ -159,6 +159,26 @@ def poke(self, context, session=None): session.commit() return count == len(dttm_filter) + def _handle_execution_date_fn(self, context): + """ + This function is to handle backwards compatibility with how this operator was + previously where it only passes the execution date, but also allow for the newer + implementation to pass all context through as well, to allow for more sophisticated + returns of dates to return. + Namely, this function check the number of arguments in the execution_date_fn + signature and if its 1, treat the legacy way, if it's 2, pass the context as + the 2nd argument, and if its more, throw an exception. + """ + num_fxn_params = self.execution_date_fn.__code__.co_argcount + if num_fxn_params == 1: + return self.execution_date_fn(context['execution_date']) + elif num_fxn_params == 2: + return self.execution_date_fn(context['execution_date'], context) + else: + raise AirflowException( + 'execution_date_fn passed {} args but only allowed up to 2'.format(num_fxn_params) + ) + class ExternalTaskMarker(DummyOperator): """ diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 25ef8b70490ab..0e5e960abbe3f 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -251,6 +251,28 @@ def test_external_task_sensor_fn(self): ignore_ti_state=True ) + def test_external_task_sensor_fn_multiple_args(self): + """Check this task sensor passes multiple args with full context. If no failure, means clean run.""" + self.test_time_sensor() + + def my_func(dt, context): + assert context['execution_date'] == dt + return dt + timedelta(0) + + op1 = ExternalTaskSensor( + task_id='test_external_task_sensor_multiple_arg_fn', + external_dag_id=TEST_DAG_ID, + external_task_id=TEST_TASK_ID, + execution_date_fn=my_func, + allowed_states=['success'], + dag=self.dag + ) + op1.run( + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, + ignore_ti_state=True + ) + def test_external_task_sensor_error_delta_and_fn(self): self.test_time_sensor() # Test that providing execution_delta and a function raises an error