Skip to content

Commit

Permalink
Provider package Edge: Edge worker supports queue handling (#43115)
Browse files Browse the repository at this point in the history
* Edge worker supports queue handling

* Fix pytests

* Increment version

* remove unused import

* description to changelog

* version incremented

* Remove duplicate changelog

---------

Co-authored-by: Marco Küttelwesch <marco.kuettelwesch@de.bosch.com>
Co-authored-by: Majoros Donat (XC-DX/EET2-Bp) <donat.majoros2@hu.bosch.com>
  • Loading branch information
3 people authored Oct 18, 2024
1 parent 0de5587 commit 7767642
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 30 deletions.
8 changes: 8 additions & 0 deletions providers/src/airflow/providers/edge/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
Changelog
---------

0.2.0pre0
.........

Misc
~~~~

* ``Edge Worker can add or remove queues in the queue field in the DB (#43115)``

0.1.0pre0
.........

Expand Down
24 changes: 12 additions & 12 deletions providers/src/airflow/providers/edge/cli/edge_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,6 @@ def _hostname() -> str:
return os.uname()[1]


def _get_sysinfo() -> dict:
"""Produce the sysinfo from worker to post to central site."""
return {
"airflow_version": airflow_version,
"edge_provider_version": edge_provider_version,
}


def _pid_file_path(pid_file: str | None) -> str:
return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[0]

Expand Down Expand Up @@ -145,11 +137,19 @@ def signal_handler(sig, frame):
logger.info("Request to show down Edge Worker received, waiting for jobs to complete.")
_EdgeWorkerCli.drain = True

def _get_sysinfo(self) -> dict:
"""Produce the sysinfo from worker to post to central site."""
return {
"airflow_version": airflow_version,
"edge_provider_version": edge_provider_version,
"concurrency": self.concurrency,
}

def start(self):
"""Start the execution in a loop until terminated."""
try:
self.last_hb = EdgeWorker.register_worker(
self.hostname, EdgeWorkerState.STARTING, self.queues, _get_sysinfo()
self.hostname, EdgeWorkerState.STARTING, self.queues, self._get_sysinfo()
).last_update
except AirflowException as e:
if "404:NOT FOUND" in str(e):
Expand All @@ -162,7 +162,7 @@ def start(self):
self.loop()

logger.info("Quitting worker, signal being offline.")
EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, _get_sysinfo())
EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, self._get_sysinfo())
finally:
remove_existing_pidfile(self.pid_file_path)

Expand Down Expand Up @@ -230,8 +230,8 @@ def heartbeat(self) -> None:
if self.jobs
else EdgeWorkerState.IDLE
)
sysinfo = _get_sysinfo()
EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo)
sysinfo = self._get_sysinfo()
self.queues = EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo)

def interruptible_sleep(self):
"""Sleeps but stops sleeping if drain is made."""
Expand Down
55 changes: 51 additions & 4 deletions providers/src/airflow/providers/edge/models/edge_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import ast
import json
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -71,7 +72,7 @@ class EdgeWorkerModel(Base, LoggingMixin):
__tablename__ = "edge_worker"
worker_name = Column(String(64), primary_key=True, nullable=False)
state = Column(String(20))
queues = Column(String(256))
_queues = Column("queues", String(256))
first_online = Column(UtcDateTime)
last_update = Column(UtcDateTime)
jobs_active = Column(Integer, default=0)
Expand All @@ -90,7 +91,7 @@ def __init__(
):
self.worker_name = worker_name
self.state = state
self.queues = ", ".join(queues) if queues else None
self.queues = queues
self.first_online = first_online or timezone.utcnow()
self.last_update = last_update
super().__init__()
Expand All @@ -99,6 +100,33 @@ def __init__(
def sysinfo_json(self) -> dict:
return json.loads(self.sysinfo) if self.sysinfo else None

@property
def queues(self) -> list[str] | None:
"""Return list of queues which are stored in queues field."""
if self._queues:
return ast.literal_eval(self._queues)
return None

@queues.setter
def queues(self, queues: list[str] | None) -> None:
"""Set all queues of list into queues field."""
self._queues = str(queues) if queues else None

def add_queues(self, new_queues: list[str]) -> None:
"""Add new queue to the queues field."""
queues = self.queues if self.queues else []
queues.extend(new_queues)
# remove duplicated items
self.queues = list(set(queues))

def remove_queues(self, remove_queues: list[str]) -> None:
"""Remove queue from queues field."""
queues = self.queues if self.queues else []
for queue_name in remove_queues:
if queue_name in queues:
queues.remove(queue_name)
self.queues = queues


class EdgeWorker(BaseModel, LoggingMixin):
"""Accessor for Edge Worker instances as logical model."""
Expand Down Expand Up @@ -168,7 +196,7 @@ def register_worker(
return EdgeWorker(
worker_name=worker_name,
state=state,
queues=worker.queues,
queues=queues,
first_online=worker.first_online,
last_update=worker.last_update,
jobs_active=worker.jobs_active or 0,
Expand All @@ -187,7 +215,8 @@ def set_state(
jobs_active: int,
sysinfo: dict[str, str],
session: Session = NEW_SESSION,
):
) -> list[str] | None:
"""Set state of worker and returns the current assigned queues."""
EdgeWorker.assert_version(sysinfo)
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
Expand All @@ -196,6 +225,24 @@ def set_state(
worker.sysinfo = json.dumps(sysinfo)
worker.last_update = timezone.utcnow()
session.commit()
return worker.queues

@staticmethod
@provide_session
def add_and_remove_queues(
worker_name: str,
new_queues: list[str] | None = None,
remove_queues: list[str] | None = None,
session: Session = NEW_SESSION,
) -> None:
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
if new_queues:
worker.add_queues(new_queues)
if remove_queues:
worker.remove_queues(remove_queues)
session.add(worker)
session.commit()


EdgeWorker.model_rebuild()
Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ state: not-ready
source-date-epoch: 1720863625
# note that those versions are maintained by release manager - do not update them manually
versions:
- 0.1.0pre0
- 0.2.0pre0

dependencies:
- apache-airflow>=2.10.0
Expand Down
21 changes: 14 additions & 7 deletions providers/tests/edge/cli/test_edge_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from airflow.exceptions import AirflowException
from airflow.providers.edge.cli.edge_command import (
_EdgeWorkerCli,
_get_sysinfo,
_Job,
)
from airflow.providers.edge.models.edge_job import EdgeJob
Expand All @@ -42,12 +41,6 @@
# mypy: disable-error-code="attr-defined"


def test_get_sysinfo():
sysinfo = _get_sysinfo()
assert "airflow_version" in sysinfo
assert "edge_provider_version" in sysinfo


class TestEdgeWorkerCli:
@pytest.fixture
def dummy_joblist(self, tmp_path: Path) -> list[_Job]:
Expand Down Expand Up @@ -208,9 +201,14 @@ def test_heartbeat(self, mock_set_state, drain, jobs, expected_state, worker_wit
if not jobs:
worker_with_job.jobs = []
_EdgeWorkerCli.drain = drain
mock_set_state.return_value = ["queue1", "queue2"]
with conf_vars({("edge", "api_url"): "https://mock.server"}):
worker_with_job.heartbeat()
assert mock_set_state.call_args.args[1] == expected_state
queue_list = worker_with_job.queues or []
assert len(queue_list) == 2
assert "queue1" in (queue_list)
assert "queue2" in (queue_list)

@patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker")
def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _EdgeWorkerCli):
Expand Down Expand Up @@ -258,3 +256,12 @@ def stop_running():
mock_register_worker.assert_called_once()
mock_loop.assert_called_once()
mock_set_state.assert_called_once()

def test_get_sysinfo(self, worker_with_job: _EdgeWorkerCli):
concurrency = 8
worker_with_job.concurrency = concurrency
sysinfo = worker_with_job._get_sysinfo()
assert "airflow_version" in sysinfo
assert "edge_provider_version" in sysinfo
assert "concurrency" in sysinfo
assert sysinfo["concurrency"] == concurrency
70 changes: 64 additions & 6 deletions providers/tests/edge/models/test_edge_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# under the License.
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import pytest

from airflow.providers.edge.cli.edge_command import _get_sysinfo
from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli
from airflow.providers.edge.models.edge_worker import (
EdgeWorker,
EdgeWorkerModel,
Expand All @@ -36,6 +37,11 @@


class TestEdgeWorker:
@pytest.fixture
def cli_worker(self, tmp_path: Path) -> _EdgeWorkerCli:
test_worker = _EdgeWorkerCli(tmp_path / "dummy.pid", "dummy", None, 8, 5, 5)
return test_worker

@pytest.fixture(autouse=True)
def setup_test_cases(self, session: Session):
session.query(EdgeWorkerModel).delete()
Expand Down Expand Up @@ -67,28 +73,80 @@ def test_assert_version(self):
{"airflow_version": airflow_version, "edge_provider_version": edge_provider_version}
)

def test_register_worker(self, session: Session):
@pytest.mark.parametrize(
"input_queues",
[
pytest.param(None, id="empty-queues"),
pytest.param(["default", "default2"], id="with-queues"),
],
)
def test_register_worker(
self, session: Session, input_queues: list[str] | None, cli_worker: _EdgeWorkerCli
):
EdgeWorker.register_worker(
"test_worker", EdgeWorkerState.STARTING, queues=None, sysinfo=_get_sysinfo()
"test_worker", EdgeWorkerState.STARTING, queues=input_queues, sysinfo=cli_worker._get_sysinfo()
)

worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
assert len(worker) == 1
assert worker[0].worker_name == "test_worker"
if input_queues:
assert worker[0].queues == input_queues
else:
assert worker[0].queues is None

def test_set_state(self, session: Session):
def test_set_state(self, session: Session, cli_worker: _EdgeWorkerCli):
queues = ["default", "default2"]
rwm = EdgeWorkerModel(
worker_name="test2_worker",
state=EdgeWorkerState.IDLE,
queues=["default"],
queues=queues,
first_online=timezone.utcnow(),
)
session.add(rwm)
session.commit()

EdgeWorker.set_state("test2_worker", EdgeWorkerState.RUNNING, 1, _get_sysinfo())
return_queues = EdgeWorker.set_state(
"test2_worker", EdgeWorkerState.RUNNING, 1, cli_worker._get_sysinfo()
)

worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
assert len(worker) == 1
assert worker[0].worker_name == "test2_worker"
assert worker[0].state == EdgeWorkerState.RUNNING
assert worker[0].queues == queues
assert return_queues == ["default", "default2"]

@pytest.mark.parametrize(
"add_queues, remove_queues, expected_queues",
[
pytest.param(None, None, ["init"], id="no-changes"),
pytest.param(
["queue1", "queue2"], ["queue1", "queue_not_existing"], ["init", "queue2"], id="add-remove"
),
pytest.param(["init"], None, ["init"], id="check-duplicated"),
],
)
def test_add_and_remove_queues(
self,
session: Session,
add_queues: list[str] | None,
remove_queues: list[str] | None,
expected_queues: list[str],
cli_worker: _EdgeWorkerCli,
):
rwm = EdgeWorkerModel(
worker_name="test2_worker",
state=EdgeWorkerState.IDLE,
queues=["init"],
first_online=timezone.utcnow(),
)
session.add(rwm)
session.commit()
EdgeWorker.add_and_remove_queues("test2_worker", add_queues, remove_queues, session)
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
assert len(worker) == 1
assert worker[0].worker_name == "test2_worker"
assert len(expected_queues) == len(worker[0].queues or [])
for expected_queue in expected_queues:
assert expected_queue in (worker[0].queues or [])

0 comments on commit 7767642

Please sign in to comment.