@@ -140,20 +140,40 @@ def execute_async(
del self.edge_queued_tasks[key]

self.validate_airflow_tasks_run_command(command) # type: ignore[attr-defined]

        # Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
        existing_job = (
            session.query(EdgeJobModel)
            .filter_by(
                dag_id=key.dag_id,
                task_id=key.task_id,
                run_id=key.run_id,
                map_index=key.map_index,
                try_number=key.try_number,
            )
            .first()
        )

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
existing_job.queue = queue or DEFAULT_QUEUE
existing_job.concurrency_slots = task_instance.pool_slots
existing_job.command = str(command)
else:
session.add(
EdgeJobModel(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.QUEUED,
queue=queue or DEFAULT_QUEUE,
concurrency_slots=task_instance.pool_slots,
command=str(command),
)
)
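
For context on the fix: re-queuing a task previously produced a second INSERT for the same logical row, which the database rejects. A minimal, self-contained sketch of that failure mode, assuming (as the lookup above implies) that the five key columns uniquely identify a job; JobSketch is a hypothetical stand-in, not the real EdgeJobModel:

# Illustrative only: a stand-in table keyed the way the filter_by above suggests.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class JobSketch(Base):  # hypothetical stand-in for EdgeJobModel
    __tablename__ = "job_sketch"
    dag_id = Column(String, primary_key=True)
    task_id = Column(String, primary_key=True)
    run_id = Column(String, primary_key=True)
    map_index = Column(Integer, primary_key=True)
    try_number = Column(Integer, primary_key=True)
    command = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    row = dict(dag_id="d", task_id="t", run_id="r", map_index=-1, try_number=1)
    session.add(JobSketch(**row, command="old"))
    session.commit()
    session.add(JobSketch(**row, command="new"))  # same key, plain INSERT
    try:
        session.commit()
    except IntegrityError:  # the duplicate-key failure the lookup above avoids
        session.rollback()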

@provide_session
def queue_workload(
self,
@@ -168,20 +188,40 @@ def queue_workload(

task_instance = workload.ti
key = task_instance.key

        # Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
        existing_job = (
            session.query(EdgeJobModel)
            .filter_by(
                dag_id=key.dag_id,
                task_id=key.task_id,
                run_id=key.run_id,
                map_index=key.map_index,
                try_number=key.try_number,
            )
            .first()
        )

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
existing_job.queue = task_instance.queue
existing_job.concurrency_slots = task_instance.pool_slots
existing_job.command = workload.model_dump_json()
else:
session.add(
EdgeJobModel(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.QUEUED,
queue=task_instance.queue,
concurrency_slots=task_instance.pool_slots,
command=workload.model_dump_json(),
)
)
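
A design note: since the lookup filters on what amounts to the row's identity, SQLAlchemy's session.merge() could express the same SELECT-then-UPDATE-or-INSERT in one call. A sketch only, assuming those five columns form EdgeJobModel's mapped primary key, and reusing session, key, task_instance, and workload from queue_workload above:

        # Hedged alternative to the explicit branch above (not what this PR does):
        # merge() fetches by primary key, copies the set attributes onto the
        # existing row, or pends an INSERT when no row exists.
        session.merge(
            EdgeJobModel(
                dag_id=key.dag_id,
                task_id=key.task_id,
                run_id=key.run_id,
                map_index=key.map_index,
                try_number=key.try_number,
                state=TaskInstanceState.QUEUED,
                queue=task_instance.queue,
                concurrency_slots=task_instance.pool_slots,
                command=workload.model_dump_json(),
            )
        )

The explicit query-then-branch form keeps the intent obvious and avoids merge()'s subtler attribute-copy semantics, which is a reasonable trade.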

def _check_worker_liveness(self, session: Session) -> bool:
"""Reset worker state if heartbeat timed out."""
changed = False
@@ -247,6 +247,11 @@ def test_sync_active_worker(self):

# Prepare some data
with create_session() as session:
# Clear existing workers to avoid unique constraint violation
session.query(EdgeWorkerModel).delete()
session.commit()

# Add workers with different states
for worker_name, state, last_heartbeat in [
(
"inactive_timed_out_worker",
@@ -338,3 +343,95 @@ def test_queue_workload(self):
with create_session() as session:
jobs = session.query(EdgeJobModel).all()
assert len(jobs) == 1

@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="API only available in Airflow <3.0")
def test_execute_async_updates_existing_job(self):
executor, key = self.get_test_executor()

# First insert a job with the same key
with create_session() as session:
session.add(
EdgeJobModel(
dag_id=key.dag_id,
run_id=key.run_id,
task_id=key.task_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.SCHEDULED,
queue="default",
concurrency_slots=1,
command="old-command",
last_update=timezone.utcnow(),
)
)
session.commit()

# Trigger execute_async which should update the existing job
executor.edge_queued_tasks = deepcopy(executor.queued_tasks)
executor.execute_async(key=key, command=["airflow", "tasks", "run", "new", "command"])

with create_session() as session:
jobs = session.query(EdgeJobModel).all()
assert len(jobs) == 1
job = jobs[0]
assert job.state == TaskInstanceState.QUEUED
assert job.command != "old-command"
assert "new" in job.command

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="API only available in Airflow 3.0+")
def test_queue_workload_updates_existing_job(self):
from uuid import uuid4

from airflow.executors.workloads import ExecuteTask, TaskInstance

executor = self.get_test_executor()[0]

key = TaskInstanceKey(dag_id="mock", run_id="mock", task_id="mock", map_index=-1, try_number=1)

# Insert an existing job
with create_session() as session:
session.add(
EdgeJobModel(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.SCHEDULED,
queue="default",
command="old-command",
concurrency_slots=1,
last_update=timezone.utcnow(),
)
)
session.commit()

# Queue a workload with same key
workload = ExecuteTask(
token="mock",
ti=TaskInstance(
id=uuid4(),
task_id=key.task_id,
dag_id=key.dag_id,
run_id=key.run_id,
try_number=key.try_number,
map_index=key.map_index,
pool_slots=1,
queue="updated-queue",
priority_weight=1,
start_date=timezone.utcnow(),
dag_version_id=uuid4(),
),
dag_rel_path="mock.py",
log_path="mock.log",
bundle_info={"name": "n/a", "version": "no matter"},
)

executor.queue_workload(workload=workload)

with create_session() as session:
jobs = session.query(EdgeJobModel).all()
assert len(jobs) == 1
job = jobs[0]
assert job.queue == "updated-queue"
assert job.command != "old-command"
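
One caveat these tests don't cover: the query-then-insert sequence is not atomic, so two schedulers queuing the same key at once could both miss the lookup and both take the insert path. A hedged hardening sketch, not part of this PR, again reusing the names from queue_workload above, that degrades that race into a recoverable error:

        # Sketch only: flush early and fall back to the update path when a
        # concurrent writer has already inserted the same key.
        from sqlalchemy.exc import IntegrityError

        try:
            session.add(
                EdgeJobModel(
                    dag_id=key.dag_id,
                    task_id=key.task_id,
                    run_id=key.run_id,
                    map_index=key.map_index,
                    try_number=key.try_number,
                    state=TaskInstanceState.QUEUED,
                    queue=task_instance.queue,
                    concurrency_slots=task_instance.pool_slots,
                    command=workload.model_dump_json(),
                )
            )
            session.flush()
        except IntegrityError:
            session.rollback()
            # The competing writer won; re-run the filter_by lookup and update instead.

The caveat with this sketch is that rollback() discards the whole session's pending state, so real code would scope the attempt in a nested transaction via session.begin_nested().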