Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions providers/src/airflow/providers/edge/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
Changelog
---------

0.8.2pre0
.........

Misc
~~~~

* ``Migrate worker job calls to FastAPI.``

0.8.1pre0
.........

Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

__all__ = ["__version__"]

__version__ = "0.8.1pre0"
__version__ = "0.8.2pre0"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
Expand Down
34 changes: 32 additions & 2 deletions providers/src/airflow/providers/edge/cli/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.edge.worker_api.auth import jwt_signer
from airflow.providers.edge.worker_api.datamodels import PushLogsBody, WorkerStateBody
from airflow.providers.edge.worker_api.datamodels import (
EdgeJobFetched,
PushLogsBody,
WorkerQueuesBody,
WorkerStateBody,
)
from airflow.utils.state import TaskInstanceState # noqa: TC001

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
Expand Down Expand Up @@ -114,6 +120,28 @@ def worker_set_state(
)


def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
    """Ask the central site for the next job this edge worker should run.

    Returns the fetched job description, or ``None`` when no job is available.
    """
    # Advertise this worker's queues and free slots so the scheduler can match a job.
    request_body = WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency)
    response = _make_generic_request(
        "GET",
        f"jobs/fetch/{quote(hostname)}",
        request_body.model_dump_json(exclude_unset=True),
    )
    return EdgeJobFetched(**response) if response else None


def jobs_set_state(key: TaskInstanceKey, state: TaskInstanceState) -> None:
    """Report the state of a job (identified by its task-instance key) to the central site."""
    # The task-instance key fields and the target state are all encoded in the URL path.
    endpoint = (
        f"jobs/state/{key.dag_id}/{key.task_id}/{key.run_id}/"
        f"{key.try_number}/{key.map_index}/{state}"
    )
    _make_generic_request("PATCH", endpoint)


def logs_logfile_path(task: TaskInstanceKey) -> Path:
"""Elaborate the path and filename to expect from task execution."""
result = _make_generic_request(
Expand All @@ -133,5 +161,7 @@ def logs_push(
_make_generic_request(
"POST",
f"logs/push/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(),
PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(
exclude_unset=True
),
)
21 changes: 12 additions & 9 deletions providers/src/airflow/providers/edge/cli/edge_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pathlib import Path
from subprocess import Popen
from time import sleep
from typing import TYPE_CHECKING

import psutil
from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile
Expand All @@ -37,18 +38,22 @@
from airflow.exceptions import AirflowException
from airflow.providers.edge import __version__ as edge_provider_version
from airflow.providers.edge.cli.api_client import (
jobs_fetch,
jobs_set_state,
logs_logfile_path,
logs_push,
worker_register,
worker_set_state,
)
from airflow.providers.edge.models.edge_job import EdgeJob
from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.platform import IS_WINDOWS
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched

logger = logging.getLogger(__name__)
EDGE_WORKER_PROCESS_NAME = "edge-worker"
EDGE_WORKER_HEADER = "\n".join(
Expand Down Expand Up @@ -81,7 +86,7 @@ def force_use_internal_api_on_edge_worker():
if AIRFLOW_V_3_0_PLUS:
# Obvious TODO Make EdgeWorker compatible with Airflow 3 (again)
raise SystemExit(
"Error: EdgeWorker is currently broken on AIrflow 3/main due to removal of AIP-44, rework for AIP-72."
"Error: EdgeWorker is currently broken on Airflow 3/main due to removal of AIP-44, rework for AIP-72."
)

api_url = conf.get("edge", "api_url")
Expand Down Expand Up @@ -141,7 +146,7 @@ def _write_pid_to_pidfile(pid_file_path: str):
class _Job:
"""Holds all information for a task/job to be executed as bundle."""

edge_job: EdgeJob
edge_job: EdgeJobFetched
process: Popen
logfile: Path
logsize: int
Expand Down Expand Up @@ -240,9 +245,7 @@ def loop(self):
def fetch_job(self) -> bool:
"""Fetch and start a new job from central site."""
logger.debug("Attempting to fetch a new job...")
edge_job = EdgeJob.reserve_task(
worker_name=self.hostname, free_concurrency=self.free_concurrency, queues=self.queues
)
edge_job = jobs_fetch(self.hostname, self.queues, self.free_concurrency)
if edge_job:
logger.info("Received job: %s", edge_job)
env = os.environ.copy()
Expand All @@ -252,7 +255,7 @@ def fetch_job(self) -> bool:
process = Popen(edge_job.command, close_fds=True, env=env, start_new_session=True)
logfile = logs_logfile_path(edge_job.key)
self.jobs.append(_Job(edge_job, process, logfile, 0))
EdgeJob.set_state(edge_job.key, TaskInstanceState.RUNNING)
jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
return True

logger.info("No new job to process%s", f", {len(self.jobs)} still running" if self.jobs else "")
Expand All @@ -268,10 +271,10 @@ def check_running_jobs(self) -> None:
self.jobs.remove(job)
if job.process.returncode == 0:
logger.info("Job completed: %s", job.edge_job)
EdgeJob.set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
jobs_set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
else:
logger.error("Job failed: %s", job.edge_job)
EdgeJob.set_state(job.edge_job.key, TaskInstanceState.FAILED)
jobs_set_state(job.edge_job.key, TaskInstanceState.FAILED)
else:
used_concurrency += job.edge_job.concurrency_slots

Expand Down
Loading