80 changes: 40 additions & 40 deletions providers/edge3/src/airflow/providers/edge3/cli/worker.py
@@ -175,48 +175,46 @@ def _get_state() -> EdgeWorkerState:
             return EdgeWorkerState.MAINTENANCE_MODE
         return EdgeWorkerState.IDLE
 
-    def _launch_job_af3(self, edge_job: EdgeJobFetched) -> tuple[Process, Path]:
-        if TYPE_CHECKING:
-            from airflow.executors.workloads import ExecuteTask
-
-        def _run_job_via_supervisor(
-            workload: ExecuteTask,
-        ) -> int:
-            from airflow.sdk.execution_time.supervisor import supervise
-
-            # Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion
-            signal.signal(signal.SIGINT, signal.SIG_IGN)
-
-            logger.info("Worker starting up pid=%d", os.getpid())
-            setproctitle(f"airflow edge worker: {workload.ti.key}")
-
-            try:
-                api_url = conf.get("edge", "api_url")
-                execution_api_server_url = conf.get("core", "execution_api_server_url", fallback="")
-                if not execution_api_server_url:
-                    parsed = urlparse(api_url)
-                    execution_api_server_url = f"{parsed.scheme}://{parsed.netloc}/execution/"
-
-                logger.info("Worker starting up server=execution_api_server_url=%s", execution_api_server_url)
-
-                supervise(
-                    # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
-                    # Same like in airflow/executors/local_executor.py:_execute_work()
-                    ti=workload.ti,  # type: ignore[arg-type]
-                    dag_rel_path=workload.dag_rel_path,
-                    bundle_info=workload.bundle_info,
-                    token=workload.token,
-                    server=execution_api_server_url,
-                    log_path=workload.log_path,
-                )
-                return 0
-            except Exception as e:
-                logger.exception("Task execution failed: %s", e)
-                return 1
+    @staticmethod
+    def _run_job_via_supervisor(workload) -> int:
+        from airflow.sdk.execution_time.supervisor import supervise
+
+        # Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+        logger.info("Worker starting up pid=%d", os.getpid())
+        setproctitle(f"airflow edge worker: {workload.ti.key}")
+
+        try:
+            api_url = conf.get("edge", "api_url")
+            execution_api_server_url = conf.get("core", "execution_api_server_url", fallback="")
+            if not execution_api_server_url:
+                parsed = urlparse(api_url)
+                execution_api_server_url = f"{parsed.scheme}://{parsed.netloc}/execution/"
+
+            supervise(
+                # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
+                # Same like in airflow/executors/local_executor.py:_execute_work()
+                ti=workload.ti,  # type: ignore[arg-type]
+                dag_rel_path=workload.dag_rel_path,
+                bundle_info=workload.bundle_info,
+                token=workload.token,
+                server=execution_api_server_url,
+                log_path=workload.log_path,
+            )
+            return 0
+        except Exception as e:
+            logger.exception("Task execution failed: %s", e)
+            return 1
+
+    @staticmethod
+    def _launch_job_af3(edge_job: EdgeJobFetched) -> tuple[Process, Path]:
+        if TYPE_CHECKING:
+            from airflow.executors.workloads import ExecuteTask
 
         workload: ExecuteTask = edge_job.command
         process = Process(
-            target=_run_job_via_supervisor,
+            target=EdgeWorker._run_job_via_supervisor,
            kwargs={"workload": workload},
         )
         process.start()
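Reviewer note (not part of the PR itself): hoisting the nested `_run_job_via_supervisor` out of `_launch_job_af3` into a `@staticmethod` matters for `multiprocessing`. Under the `spawn` start method, `Process` pickles its `target`, and a locally defined function cannot be pickled, while a staticmethod is importable by qualified name. A minimal sketch of the difference, with invented names (`Worker._run`, `make_nested_target`):

```python
import multiprocessing as mp
import pickle


class Worker:
    @staticmethod
    def _run(payload: str) -> None:
        print(f"running {payload}")


def make_nested_target():
    # Mirrors the old pattern: a function defined inside another function.
    def _run(payload: str) -> None:
        print(f"running {payload}")

    return _run


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    # Works: Worker._run is pickled by its qualified name and re-imported
    # in the child process.
    p = mp.Process(target=Worker._run, kwargs={"payload": "job-1"})
    p.start()
    p.join()

    # Fails: a nested function is a local object that pickle cannot reference.
    try:
        pickle.dumps(make_nested_target())
    except (pickle.PicklingError, AttributeError) as err:
        print(f"not picklable: {err}")
```

(On Linux the default start method is `fork`, which inherits the parent's memory and does not pickle the target, which is why the old nested-function version could still work there.)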
@@ -226,7 +224,8 @@ def _run_job_via_supervisor(
         logfile = Path(base_log_folder, workload.log_path)
         return process, logfile
 
-    def _launch_job_af2_10(self, edge_job: EdgeJobFetched) -> tuple[Popen, Path]:
+    @staticmethod
+    def _launch_job_af2_10(edge_job: EdgeJobFetched) -> tuple[Popen, Path]:
         """Compatibility for Airflow 2.10 Launch."""
         env = os.environ.copy()
         env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
@@ -237,14 +236,15 @@ def _launch_job_af2_10(self, edge_job: EdgeJobFetched) -> tuple[Popen, Path]:
         logfile = logs_logfile_path(edge_job.key)
         return process, logfile
 
-    def _launch_job(self, edge_job: EdgeJobFetched):
+    @staticmethod
+    def _launch_job(edge_job: EdgeJobFetched):
         """Get the received job executed."""
         process: Popen | Process
         if AIRFLOW_V_3_0_PLUS:
-            process, logfile = self._launch_job_af3(edge_job)
+            process, logfile = EdgeWorker._launch_job_af3(edge_job)
         else:
             # Airflow 2.10
-            process, logfile = self._launch_job_af2_10(edge_job)
+            process, logfile = EdgeWorker._launch_job_af2_10(edge_job)
         EdgeWorker.jobs.append(Job(edge_job, process, logfile, 0))
 
     def start(self):
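The pattern above repeats for every launch helper: once a method is a `@staticmethod`, sibling calls go through the class name instead of `self`. A toy sketch of the same dispatch shape (all names invented):

```python
class Launcher:
    @staticmethod
    def _launch_v3(job: str) -> str:
        return f"launched {job} on the v3 path"

    @staticmethod
    def _launch_v2(job: str) -> str:
        return f"launched {job} on the v2 path"

    @staticmethod
    def launch(job: str, v3: bool) -> str:
        # Reach sibling staticmethods through the class, the same way
        # _launch_job now calls EdgeWorker._launch_job_af3 above.
        return Launcher._launch_v3(job) if v3 else Launcher._launch_v2(job)


print(Launcher.launch("job-1", v3=True))  # callable with no instance at all
```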
@@ -316,7 +316,7 @@ def fetch_job(self) -> bool:
         edge_job = jobs_fetch(self.hostname, self.queues, self.free_concurrency)
         if edge_job:
             logger.info("Received job: %s", edge_job)
-            self._launch_job(edge_job)
+            EdgeWorker._launch_job(edge_job)
             jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
             return True
 
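The final hunk only changes the call site inside `fetch_job`, whose control flow is: fetch a job, launch it, mark it RUNNING. A minimal, dependency-free sketch of that flow with stand-in callables (all names invented):

```python
from typing import Callable, Optional


def fetch_job(
    fetch: Callable[[], Optional[str]],
    launch: Callable[[str], None],
    mark_running: Callable[[str], None],
) -> bool:
    """Same shape as EdgeWorker.fetch_job: fetch, launch, mark RUNNING."""
    job = fetch()
    if job:
        launch(job)
        mark_running(job)
        return True
    return False


assert fetch_job(lambda: "job-1", lambda j: None, lambda j: None) is True
assert fetch_job(lambda: None, lambda j: None, lambda j: None) is False
```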