8 changes: 4 additions & 4 deletions airflow/executors/base_executor.py
@@ -33,7 +33,7 @@
 from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 PARALLELISM: int = conf.getint("core", "PARALLELISM")
 
@@ -296,7 +296,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
             self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
             self.running.add(key)
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         """
         Changes state of the task.
 
@@ -318,7 +318,7 @@ def fail(self, key: TaskInstanceKey, info=None) -> None:
         :param info: Executor information for the task instance
         :param key: Unique key for the task instance
         """
-        self.change_state(key, State.FAILED, info)
+        self.change_state(key, TaskInstanceState.FAILED, info)
@@ -327,7 +327,7 @@ def success(self, key: TaskInstanceKey, info=None) -> None:
         :param info: Executor information for the task instance
         :param key: Unique key for the task instance
         """
-        self.change_state(key, State.SUCCESS, info)
+        self.change_state(key, TaskInstanceState.SUCCESS, info)
 
     def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
         """
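Note: tightening `change_state` from `state: str` to `state: TaskInstanceState` is backward compatible because `TaskInstanceState` in `airflow.utils.state` is a `str`-backed enum, and the old `State.*` aggregate attributes already resolve to the same members. A minimal sketch of that property:

```python
from airflow.utils.state import State, TaskInstanceState

# TaskInstanceState subclasses str, so members behave like their raw values.
assert TaskInstanceState.SUCCESS == "success"
assert TaskInstanceState("failed") is TaskInstanceState.FAILED

# State.SUCCESS already pointed at the same enum member, so call sites that
# switch from State.SUCCESS to TaskInstanceState.SUCCESS change nothing at runtime.
assert State.SUCCESS is TaskInstanceState.SUCCESS
```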
28 changes: 14 additions & 14 deletions airflow/executors/debug_executor.py
@@ -29,7 +29,7 @@
 from typing import TYPE_CHECKING, Any
 
 from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
@@ -68,15 +68,15 @@ def sync(self) -> None:
         while self.tasks_to_run:
             ti = self.tasks_to_run.pop(0)
             if self.fail_fast and not task_succeeded:
-                self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
-                ti.set_state(State.UPSTREAM_FAILED)
-                self.change_state(ti.key, State.UPSTREAM_FAILED)
+                self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED)
+                ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+                self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
                 continue
 
             if self._terminated.is_set():
-                self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
-                ti.set_state(State.FAILED)
-                self.change_state(ti.key, State.FAILED)
+                self.log.info("Executor is terminated! Stopping %s to %s", ti.key, TaskInstanceState.FAILED)
+                ti.set_state(TaskInstanceState.FAILED)
+                self.change_state(ti.key, TaskInstanceState.FAILED)
                 continue
 
             task_succeeded = self._run_task(ti)
@@ -87,11 +87,11 @@ def _run_task(self, ti: TaskInstance) -> bool:
         try:
             params = self.tasks_params.pop(ti.key, {})
             ti.run(job_id=ti.job_id, **params)
-            self.change_state(key, State.SUCCESS)
+            self.change_state(key, TaskInstanceState.SUCCESS)
             return True
         except Exception as e:
-            ti.set_state(State.FAILED)
-            self.change_state(key, State.FAILED)
+            ti.set_state(TaskInstanceState.FAILED)
+            self.change_state(key, TaskInstanceState.FAILED)
             self.log.exception("Failed to execute task: %s.", str(e))
             return False
 
@@ -148,14 +148,14 @@ def trigger_tasks(self, open_slots: int) -> None:
     def end(self) -> None:
         """Set states of queued tasks to UPSTREAM_FAILED marking them as not executed."""
         for ti in self.tasks_to_run:
-            self.log.info("Setting %s to %s", ti.key, State.UPSTREAM_FAILED)
-            ti.set_state(State.UPSTREAM_FAILED)
-            self.change_state(ti.key, State.UPSTREAM_FAILED)
+            self.log.info("Setting %s to %s", ti.key, TaskInstanceState.UPSTREAM_FAILED)
+            ti.set_state(TaskInstanceState.UPSTREAM_FAILED)
+            self.change_state(ti.key, TaskInstanceState.UPSTREAM_FAILED)
 
     def terminate(self) -> None:
         self._terminated.set()
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         self.log.debug("Popping %s from executor task queue.", key)
         self.running.remove(key)
         self.event_buffer[key] = state, info
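For context, `DebugExecutor.change_state` overrides the base implementation to drop the key from `running` and stash a `(state, info)` tuple in `event_buffer`, which the scheduler later drains via `get_event_buffer()`. A rough sketch of that flow, using a hypothetical task key and assuming an initialized Airflow environment:

```python
from airflow.executors.debug_executor import DebugExecutor
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.state import TaskInstanceState

executor = DebugExecutor()

# Hypothetical key, for illustration only.
key = TaskInstanceKey(dag_id="example_dag", task_id="example_task", run_id="manual", try_number=1)

executor.running.add(key)
executor.change_state(key, TaskInstanceState.FAILED)

# The scheduler picks terminal events up by draining the buffer.
events = executor.get_event_buffer()
assert events[key] == (TaskInstanceState.FAILED, None)
```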
6 changes: 3 additions & 3 deletions airflow/executors/sequential_executor.py
@@ -28,7 +28,7 @@
 from typing import TYPE_CHECKING, Any
 
 from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.executors.base_executor import CommandType
@@ -75,9 +75,9 @@ def sync(self) -> None:
 
             try:
                 subprocess.check_call(command, close_fds=True)
-                self.change_state(key, State.SUCCESS)
+                self.change_state(key, TaskInstanceState.SUCCESS)
             except subprocess.CalledProcessError as e:
-                self.change_state(key, State.FAILED)
+                self.change_state(key, TaskInstanceState.FAILED)
                 self.log.error("Failed to execute task %s.", str(e))
 
         self.commands_to_run = []
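The SUCCESS/FAILED split above hinges entirely on `subprocess.check_call`, which raises `CalledProcessError` for any nonzero exit code. A self-contained sketch of the same mapping (the `true`/`false` commands are illustrative and POSIX-only):

```python
from __future__ import annotations

import subprocess

from airflow.utils.state import TaskInstanceState

def run_and_classify(command: list[str]) -> TaskInstanceState:
    """Mirror the executor's exit-code-to-state mapping."""
    try:
        subprocess.check_call(command, close_fds=True)
        return TaskInstanceState.SUCCESS
    except subprocess.CalledProcessError:
        return TaskInstanceState.FAILED

assert run_and_classify(["true"]) == TaskInstanceState.SUCCESS
assert run_and_classify(["false"]) == TaskInstanceState.FAILED
```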
8 changes: 4 additions & 4 deletions airflow/providers/celery/executors/celery_executor.py
@@ -75,7 +75,7 @@
 from airflow.exceptions import AirflowTaskTimeout
 from airflow.executors.base_executor import BaseExecutor
 from airflow.stats import Stats
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 log = logging.getLogger(__name__)
 
@@ -300,7 +300,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
             self.task_publish_retries.pop(key, None)
             if isinstance(result, ExceptionWithTraceback):
                 self.log.error(CELERY_SEND_ERR_MSG_HEADER + ": %s\n%s\n", result.exception, result.traceback)
-                self.event_buffer[key] = (State.FAILED, None)
+                self.event_buffer[key] = (TaskInstanceState.FAILED, None)
             elif result is not None:
                 result.backend = cached_celery_backend
                 self.running.add(key)
@@ -309,7 +309,7 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
                 # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
                 # has another event, but that is fine, because the only other events are success/failed at
                 # which point we don't need the ID anymore anyway
-                self.event_buffer[key] = (State.QUEUED, result.task_id)
+                self.event_buffer[key] = (TaskInstanceState.QUEUED, result.task_id)
 
                 # If the task runs _really quickly_ we may already have a result!
                 self.update_task_state(key, result.state, getattr(result, "info", None))
@@ -356,7 +356,7 @@ def update_all_task_states(self) -> None:
             if state:
                 self.update_task_state(key, state, info)
 
-    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, info=None) -> None:
         super().change_state(key, state, info)
         self.tasks.pop(key, None)
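One subtlety worth noting: `result.state` passed to `update_task_state` is a Celery state string (`"SUCCESS"`, `"FAILURE"`, ...), not an Airflow state, so a translation step sits between the two vocabularies. A hypothetical, simplified version of such a mapping (not the provider's actual code):

```python
from __future__ import annotations

import celery.states as celery_states

from airflow.utils.state import TaskInstanceState

def to_task_instance_state(celery_state: str) -> TaskInstanceState | None:
    """Translate Celery's state strings into Airflow TaskInstanceState members."""
    if celery_state == celery_states.SUCCESS:
        return TaskInstanceState.SUCCESS
    if celery_state in (celery_states.FAILURE, celery_states.REVOKED):
        return TaskInstanceState.FAILED
    if celery_state == celery_states.STARTED:
        return TaskInstanceState.RUNNING
    return None  # PENDING/RECEIVED etc.: nothing to report to Airflow yet
```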
airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -79,7 +79,7 @@
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.log.logging_mixin import remove_escape_codes
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
     from kubernetes import client
@@ -426,20 +426,20 @@ def sync(self) -> None:
     def _change_state(
         self,
         key: TaskInstanceKey,
-        state: str | None,
+        state: TaskInstanceState | None,
         pod_name: str,
         namespace: str,
         session: Session = NEW_SESSION,
     ) -> None:
         if TYPE_CHECKING:
             assert self.kube_scheduler
 
-        if state == State.RUNNING:
+        if state == TaskInstanceState.RUNNING:
             self.event_buffer[key] = state, None
             return
 
         if self.kube_config.delete_worker_pods:
-            if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure:
+            if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure:
                 self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
                 self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace))
         else:
@@ -456,6 +456,7 @@ def _change_state(
             from airflow.models.taskinstance import TaskInstance
 
             state = session.query(TaskInstance.state).filter(TaskInstance.filter_for_tis([key])).scalar()
+            state = TaskInstanceState(state)
 
         self.event_buffer[key] = state, None
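The only behavioral addition in this hunk is `state = TaskInstanceState(state)`: the metadata database stores the state column as a plain string, so the value read back through SQLAlchemy must be re-wrapped into the enum before it enters the now-typed event buffer. The round-trip in isolation:

```python
from airflow.utils.state import TaskInstanceState

# The DB column holds the enum member's raw string value...
stored = TaskInstanceState.SUCCESS.value
assert stored == "success"

# ...so the value read back is coerced into the enum for the event buffer.
state = TaskInstanceState(stored)
assert state is TaskInstanceState.SUCCESS
```

Note that `TaskInstanceState(None)` raises `ValueError`, so this coercion assumes the query found a row with a non-null state.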
airflow/providers/cncf/kubernetes/executors/kubernetes_executor_types.py
@@ -21,15 +21,16 @@
 if TYPE_CHECKING:
     from airflow.executors.base_executor import CommandType
     from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.utils.state import TaskInstanceState
 
     # TaskInstance key, command, configuration, pod_template_file
     KubernetesJobType = Tuple[TaskInstanceKey, CommandType, Any, Optional[str]]
 
     # key, pod state, pod_name, namespace, resource_version
-    KubernetesResultsType = Tuple[TaskInstanceKey, Optional[str], str, str, str]
+    KubernetesResultsType = Tuple[TaskInstanceKey, Optional[TaskInstanceState], str, str, str]
 
     # pod_name, namespace, pod state, annotations, resource_version
-    KubernetesWatchType = Tuple[str, str, Optional[str], Dict[str, str], str]
+    KubernetesWatchType = Tuple[str, str, Optional[TaskInstanceState], Dict[str, str], str]
 
 ALL_NAMESPACES = "ALL_NAMESPACES"
 POD_EXECUTOR_DONE_KEY = "airflow_executor_done"
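With these aliases tightened, anything put on the watcher or results queues should carry a `TaskInstanceState` member (or `None`) in the pod-state slot instead of a bare string. A hypothetical tuple matching the `KubernetesWatchType` layout:

```python
from airflow.utils.state import TaskInstanceState

# (pod_name, namespace, pod state, annotations, resource_version) -- all
# values below are illustrative.
watch_event = (
    "example-pod-abc123",
    "default",
    TaskInstanceState.FAILED,  # previously a plain "failed" string
    {"dag_id": "example_dag", "task_id": "example_task"},
    "12345",
)
```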
airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py
@@ -36,7 +36,7 @@
 )
 from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.state import State
+from airflow.utils.state import TaskInstanceState
 
 try:
     from airflow.providers.cncf.kubernetes.executors.kubernetes_executor_types import (
@@ -223,12 +223,16 @@ def process_status(
             # since kube server have received request to delete pod set TI state failed
             if event["type"] == "DELETED" and pod.metadata.deletion_timestamp:
                 self.log.info("Event: Failed to start pod %s, annotations: %s", pod_name, annotations_string)
-                self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+                self.watcher_queue.put(
+                    (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+                )
             else:
                 self.log.debug("Event: %s Pending, annotations: %s", pod_name, annotations_string)
         elif status == "Failed":
             self.log.error("Event: %s Failed, annotations: %s", pod_name, annotations_string)
-            self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+            self.watcher_queue.put(
+                (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+            )
         elif status == "Succeeded":
             # We get multiple events once the pod hits a terminal state, and we only want to
             # send it along to the scheduler once.
@@ -256,7 +260,9 @@ def process_status(
                     pod_name,
                     annotations_string,
                 )
-                self.watcher_queue.put((pod_name, namespace, State.FAILED, annotations, resource_version))
+                self.watcher_queue.put(
+                    (pod_name, namespace, TaskInstanceState.FAILED, annotations, resource_version)
+                )
             else:
                 self.log.info("Event: %s is Running, annotations: %s", pod_name, annotations_string)
         else:
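Taken together, the watcher reduces each pod event to at most one `TaskInstanceState` before queuing it. A condensed, hypothetical restatement of that dispatch (the real `process_status` also handles deletion timestamps, the executor-done annotation, and resource versions):

```python
from __future__ import annotations

from airflow.utils.state import TaskInstanceState

def pod_phase_to_state(phase: str, deleted_while_pending: bool) -> TaskInstanceState | None:
    """Simplified sketch of how pod phases map onto queued task states."""
    if phase == "Pending" and deleted_while_pending:
        return TaskInstanceState.FAILED  # pod deleted before it could start
    if phase == "Failed":
        return TaskInstanceState.FAILED
    # "Succeeded" and "Running" queue no failure state; the final task state
    # is resolved from the database once the pod finishes.
    return None
```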
Expand Down