diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst b/providers/src/airflow/providers/edge/CHANGELOG.rst index 57309ff6cde66..6422484c4cb4f 100644 --- a/providers/src/airflow/providers/edge/CHANGELOG.rst +++ b/providers/src/airflow/providers/edge/CHANGELOG.rst @@ -27,6 +27,14 @@ Changelog --------- +0.2.0pre0 +......... + +Misc +~~~~ + +* ``Edge Worker can add or remove queues in the queue field in the DB (#43115)`` + 0.1.0pre0 ......... diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py index 09998ffe80281..7e0b3f5b0b6b3 100644 --- a/providers/src/airflow/providers/edge/cli/edge_command.py +++ b/providers/src/airflow/providers/edge/cli/edge_command.py @@ -91,14 +91,6 @@ def _hostname() -> str: return os.uname()[1] -def _get_sysinfo() -> dict: - """Produce the sysinfo from worker to post to central site.""" - return { - "airflow_version": airflow_version, - "edge_provider_version": edge_provider_version, - } - - def _pid_file_path(pid_file: str | None) -> str: return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[0] @@ -145,11 +137,19 @@ def signal_handler(sig, frame): logger.info("Request to show down Edge Worker received, waiting for jobs to complete.") _EdgeWorkerCli.drain = True + def _get_sysinfo(self) -> dict: + """Produce the sysinfo from worker to post to central site.""" + return { + "airflow_version": airflow_version, + "edge_provider_version": edge_provider_version, + "concurrency": self.concurrency, + } + def start(self): """Start the execution in a loop until terminated.""" try: self.last_hb = EdgeWorker.register_worker( - self.hostname, EdgeWorkerState.STARTING, self.queues, _get_sysinfo() + self.hostname, EdgeWorkerState.STARTING, self.queues, self._get_sysinfo() ).last_update except AirflowException as e: if "404:NOT FOUND" in str(e): @@ -162,7 +162,7 @@ def start(self): self.loop() logger.info("Quitting worker, signal being offline.") - EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, _get_sysinfo()) + EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, self._get_sysinfo()) finally: remove_existing_pidfile(self.pid_file_path) @@ -230,8 +230,8 @@ def heartbeat(self) -> None: if self.jobs else EdgeWorkerState.IDLE ) - sysinfo = _get_sysinfo() - EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo) + sysinfo = self._get_sysinfo() + self.queues = EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo) def interruptible_sleep(self): """Sleeps but stops sleeping if drain is made.""" diff --git a/providers/src/airflow/providers/edge/models/edge_worker.py b/providers/src/airflow/providers/edge/models/edge_worker.py index 193795e37d069..ff33c13c72ffd 100644 --- a/providers/src/airflow/providers/edge/models/edge_worker.py +++ b/providers/src/airflow/providers/edge/models/edge_worker.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import ast import json from datetime import datetime from enum import Enum @@ -71,7 +72,7 @@ class EdgeWorkerModel(Base, LoggingMixin): __tablename__ = "edge_worker" worker_name = Column(String(64), primary_key=True, nullable=False) state = Column(String(20)) - queues = Column(String(256)) + _queues = Column("queues", String(256)) first_online = Column(UtcDateTime) last_update = Column(UtcDateTime) jobs_active = Column(Integer, default=0) @@ -90,7 +91,7 @@ def __init__( ): self.worker_name = worker_name self.state = state - self.queues = ", ".join(queues) if queues else None + self.queues = queues self.first_online = first_online or timezone.utcnow() self.last_update = last_update super().__init__() @@ -99,6 +100,33 @@ def __init__( def sysinfo_json(self) -> dict: return json.loads(self.sysinfo) if self.sysinfo else None + @property + def queues(self) -> list[str] | None: + """Return list of queues which are stored in queues field.""" + if self._queues: + return ast.literal_eval(self._queues) + return None + + @queues.setter + def queues(self, queues: list[str] | None) -> None: + """Set all queues of list into queues field.""" + self._queues = str(queues) if queues else None + + def add_queues(self, new_queues: list[str]) -> None: + """Add new queue to the queues field.""" + queues = self.queues if self.queues else [] + queues.extend(new_queues) + # remove duplicated items + self.queues = list(set(queues)) + + def remove_queues(self, remove_queues: list[str]) -> None: + """Remove queue from queues field.""" + queues = self.queues if self.queues else [] + for queue_name in remove_queues: + if queue_name in queues: + queues.remove(queue_name) + self.queues = queues + class EdgeWorker(BaseModel, LoggingMixin): """Accessor for Edge Worker instances as logical model.""" @@ -168,7 +196,7 @@ def register_worker( return EdgeWorker( worker_name=worker_name, state=state, - queues=worker.queues, + queues=queues, first_online=worker.first_online, last_update=worker.last_update, jobs_active=worker.jobs_active or 0, @@ -187,7 +215,8 @@ def set_state( jobs_active: int, sysinfo: dict[str, str], session: Session = NEW_SESSION, - ): + ) -> list[str] | None: + """Set state of worker and returns the current assigned queues.""" EdgeWorker.assert_version(sysinfo) query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) worker: EdgeWorkerModel = session.scalar(query) @@ -196,6 +225,24 @@ def set_state( worker.sysinfo = json.dumps(sysinfo) worker.last_update = timezone.utcnow() session.commit() + return worker.queues + + @staticmethod + @provide_session + def add_and_remove_queues( + worker_name: str, + new_queues: list[str] | None = None, + remove_queues: list[str] | None = None, + session: Session = NEW_SESSION, + ) -> None: + query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) + worker: EdgeWorkerModel = session.scalar(query) + if new_queues: + worker.add_queues(new_queues) + if remove_queues: + worker.remove_queues(remove_queues) + session.add(worker) + session.commit() EdgeWorker.model_rebuild() diff --git a/providers/src/airflow/providers/edge/provider.yaml b/providers/src/airflow/providers/edge/provider.yaml index 7f61ad128c5c7..7e5767021ef68 100644 --- a/providers/src/airflow/providers/edge/provider.yaml +++ b/providers/src/airflow/providers/edge/provider.yaml @@ -26,7 +26,7 @@ state: not-ready source-date-epoch: 1720863625 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.1.0pre0 + - 0.2.0pre0 dependencies: - apache-airflow>=2.10.0 diff --git a/providers/tests/edge/cli/test_edge_command.py b/providers/tests/edge/cli/test_edge_command.py index 13bcfcc55700c..53e93117ac00d 100644 --- a/providers/tests/edge/cli/test_edge_command.py +++ b/providers/tests/edge/cli/test_edge_command.py @@ -27,7 +27,6 @@ from airflow.exceptions import AirflowException from airflow.providers.edge.cli.edge_command import ( _EdgeWorkerCli, - _get_sysinfo, _Job, ) from airflow.providers.edge.models.edge_job import EdgeJob @@ -42,12 +41,6 @@ # mypy: disable-error-code="attr-defined" -def test_get_sysinfo(): - sysinfo = _get_sysinfo() - assert "airflow_version" in sysinfo - assert "edge_provider_version" in sysinfo - - class TestEdgeWorkerCli: @pytest.fixture def dummy_joblist(self, tmp_path: Path) -> list[_Job]: @@ -208,9 +201,14 @@ def test_heartbeat(self, mock_set_state, drain, jobs, expected_state, worker_wit if not jobs: worker_with_job.jobs = [] _EdgeWorkerCli.drain = drain + mock_set_state.return_value = ["queue1", "queue2"] with conf_vars({("edge", "api_url"): "https://mock.server"}): worker_with_job.heartbeat() assert mock_set_state.call_args.args[1] == expected_state + queue_list = worker_with_job.queues or [] + assert len(queue_list) == 2 + assert "queue1" in (queue_list) + assert "queue2" in (queue_list) @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker") def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _EdgeWorkerCli): @@ -258,3 +256,12 @@ def stop_running(): mock_register_worker.assert_called_once() mock_loop.assert_called_once() mock_set_state.assert_called_once() + + def test_get_sysinfo(self, worker_with_job: _EdgeWorkerCli): + concurrency = 8 + worker_with_job.concurrency = concurrency + sysinfo = worker_with_job._get_sysinfo() + assert "airflow_version" in sysinfo + assert "edge_provider_version" in sysinfo + assert "concurrency" in sysinfo + assert sysinfo["concurrency"] == concurrency diff --git a/providers/tests/edge/models/test_edge_worker.py b/providers/tests/edge/models/test_edge_worker.py index f0e0ac9dfa056..d67cfdbb2cd61 100644 --- a/providers/tests/edge/models/test_edge_worker.py +++ b/providers/tests/edge/models/test_edge_worker.py @@ -16,11 +16,12 @@ # under the License. from __future__ import annotations +from pathlib import Path from typing import TYPE_CHECKING import pytest -from airflow.providers.edge.cli.edge_command import _get_sysinfo +from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli from airflow.providers.edge.models.edge_worker import ( EdgeWorker, EdgeWorkerModel, @@ -36,6 +37,11 @@ class TestEdgeWorker: + @pytest.fixture + def cli_worker(self, tmp_path: Path) -> _EdgeWorkerCli: + test_worker = _EdgeWorkerCli(tmp_path / "dummy.pid", "dummy", None, 8, 5, 5) + return test_worker + @pytest.fixture(autouse=True) def setup_test_cases(self, session: Session): session.query(EdgeWorkerModel).delete() @@ -67,28 +73,80 @@ def test_assert_version(self): {"airflow_version": airflow_version, "edge_provider_version": edge_provider_version} ) - def test_register_worker(self, session: Session): + @pytest.mark.parametrize( + "input_queues", + [ + pytest.param(None, id="empty-queues"), + pytest.param(["default", "default2"], id="with-queues"), + ], + ) + def test_register_worker( + self, session: Session, input_queues: list[str] | None, cli_worker: _EdgeWorkerCli + ): EdgeWorker.register_worker( - "test_worker", EdgeWorkerState.STARTING, queues=None, sysinfo=_get_sysinfo() + "test_worker", EdgeWorkerState.STARTING, queues=input_queues, sysinfo=cli_worker._get_sysinfo() ) worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all() assert len(worker) == 1 assert worker[0].worker_name == "test_worker" + if input_queues: + assert worker[0].queues == input_queues + else: + assert worker[0].queues is None - def test_set_state(self, session: Session): + def test_set_state(self, session: Session, cli_worker: _EdgeWorkerCli): + queues = ["default", "default2"] rwm = EdgeWorkerModel( worker_name="test2_worker", state=EdgeWorkerState.IDLE, - queues=["default"], + queues=queues, first_online=timezone.utcnow(), ) session.add(rwm) session.commit() - EdgeWorker.set_state("test2_worker", EdgeWorkerState.RUNNING, 1, _get_sysinfo()) + return_queues = EdgeWorker.set_state( + "test2_worker", EdgeWorkerState.RUNNING, 1, cli_worker._get_sysinfo() + ) worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all() assert len(worker) == 1 assert worker[0].worker_name == "test2_worker" assert worker[0].state == EdgeWorkerState.RUNNING + assert worker[0].queues == queues + assert return_queues == ["default", "default2"] + + @pytest.mark.parametrize( + "add_queues, remove_queues, expected_queues", + [ + pytest.param(None, None, ["init"], id="no-changes"), + pytest.param( + ["queue1", "queue2"], ["queue1", "queue_not_existing"], ["init", "queue2"], id="add-remove" + ), + pytest.param(["init"], None, ["init"], id="check-duplicated"), + ], + ) + def test_add_and_remove_queues( + self, + session: Session, + add_queues: list[str] | None, + remove_queues: list[str] | None, + expected_queues: list[str], + cli_worker: _EdgeWorkerCli, + ): + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=EdgeWorkerState.IDLE, + queues=["init"], + first_online=timezone.utcnow(), + ) + session.add(rwm) + session.commit() + EdgeWorker.add_and_remove_queues("test2_worker", add_queues, remove_queues, session) + worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all() + assert len(worker) == 1 + assert worker[0].worker_name == "test2_worker" + assert len(expected_queues) == len(worker[0].queues or []) + for expected_queue in expected_queues: + assert expected_queue in (worker[0].queues or [])