Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
@providers_configuration_loaded
def force_use_internal_api_on_edge_worker():
"""
Ensure that the environment is configured for the internal API without needing to declare it outside.
Ensure the environment is configured for the internal API without explicit declaration.

This is only required for an Edge worker and must to be done before the Click CLI wrapper is initiated.
That is because the CLI wrapper will attempt to establish a DB connection, which will fail before the
Expand Down
17 changes: 9 additions & 8 deletions providers/edge3/src/airflow/providers/edge3/cli/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from subprocess import Popen
from time import sleep
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from lockfile.pidlockfile import remove_existing_pidfile
from requests import HTTPError
Expand Down Expand Up @@ -186,11 +187,13 @@
setproctitle(f"airflow edge worker: {workload.ti.key}")

try:
base_url = conf.get("api", "base_url", fallback="/")
# If it's a relative URL, use localhost:8080 as the default
if base_url.startswith("/"):
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"
api_url = conf.get("edge", "api_url")
execution_api_server_url = conf.get("core", "execution_api_server_url", fallback=...)
if execution_api_server_url is ...:
parsed = urlparse(api_url)
execution_api_server_url = f"{parsed.scheme}://{parsed.netloc}/execution/"

logger.info("Worker starting up server=execution_api_server_url=%s", execution_api_server_url)

supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
Expand All @@ -199,9 +202,7 @@
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get(
"core", "execution_api_server_url", fallback=default_execution_api_server
),
server=execution_api_server_url,
log_path=workload.log_path,
)
return 0
Expand Down
46 changes: 46 additions & 0 deletions providers/edge3/tests/unit/edge3/cli/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,52 @@ def test_launch_job(self, mock_popen, mock_logfile_path, mock_process, worker_wi
assert len(EdgeWorker.jobs) == 1
assert EdgeWorker.jobs[0].edge_job == edge_job

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+")
@pytest.mark.parametrize(
"configs, expected_url",
[
(
{("edge", "api_url"): "https://api-endpoint"},
"https://api-endpoint/execution/",
),
(
{("edge", "api_url"): "https://api:1234/endpoint"},
"https://api:1234/execution/",
),
(
{
("edge", "api_url"): "https://api-endpoint",
("core", "execution_api_server_url"): "https://other-endpoint",
},
"https://other-endpoint",
),
],
)
@patch("airflow.sdk.execution_time.supervisor.supervise")
@patch("airflow.providers.edge3.cli.worker.Process")
@patch("airflow.providers.edge3.cli.worker.Popen")
def test_use_execution_api_server_url(
self,
mock_popen,
mock_process,
mock_supervise,
configs,
expected_url,
worker_with_job: EdgeWorker,
):
mock_popen.side_effect = [MagicMock()]
mock_process_instance = MagicMock()
mock_process.side_effect = [mock_process_instance]

edge_job = EdgeWorker.jobs.pop().edge_job
with conf_vars(configs):
worker_with_job._launch_job(edge_job)

mock_process_callback = mock_process.call_args.kwargs["target"]
mock_process_callback(workload=MagicMock())

assert mock_supervise.call_args.kwargs["server"] == expected_url

@pytest.mark.parametrize(
"reserve_result, fetch_result, expected_calls",
[
Expand Down