75 changes: 50 additions & 25 deletions airflow/jobs/scheduler_job_runner.py
@@ -1504,9 +1504,16 @@ def _start_queued_dagruns(self, session: Session) -> None:
# added all() to save runtime, otherwise query is executed more than once
dag_runs: Collection[DagRun] = DagRun.get_queued_dag_runs_to_set_running(session).all()

active_runs_of_dags = Counter(
DagRun.active_runs_of_dags((dr.dag_id for dr in dag_runs), only_running=True, session=session),
query = (
select(
DagRun.dag_id,
DagRun.backfill_id,
func.count(DagRun.id).label("num_running"),
)
.where(DagRun.state == DagRunState.RUNNING)
.group_by(DagRun.dag_id, DagRun.backfill_id)
)
active_runs_of_dags = Counter({(dag_id, br_id): num for dag_id, br_id, num in session.execute(query)})

@add_span
def _update_state(dag: DAG, dag_run: DagRun):
@@ -1548,33 +1555,51 @@ def _update_state(dag: DAG, dag_run: DagRun):

span = Trace.get_current_span()
for dag_run in dag_runs:
dag = dag_run.dag = cached_get_dag(dag_run.dag_id)

dag_id = dag_run.dag_id
run_id = dag_run.run_id
backfill_id = dag_run.backfill_id
backfill = dag_run.backfill
dag = dag_run.dag = cached_get_dag(dag_id)
if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
continue
active_runs = active_runs_of_dags[dag_run.dag_id]

if dag.max_active_runs and active_runs >= dag.max_active_runs:
self.log.debug(
"DAG %s already has %d active runs, not moving any more runs to RUNNING state %s",
dag.dag_id,
active_runs,
dag_run.execution_date,
)
else:
if span.is_recording():
span.add_event(
name="dag_run",
attributes={
"run_id": dag_run.run_id,
"dag_id": dag_run.dag_id,
"conf": str(dag_run.conf),
},
active_runs = active_runs_of_dags[(dag_id, backfill_id)]
if backfill_id is not None:
if active_runs >= backfill.max_active_runs:
# todo: delete all "candidate dag runs" from list for this dag right now
self.log.info(
"dag cannot be started due to backfill max_active_runs constraint; "
"active_runs=%s max_active_runs=%s dag_id=%s run_id=%s",
active_runs,
backfill.max_active_runs,
dag_id,
run_id,
)
active_runs_of_dags[dag_run.dag_id] += 1
_update_state(dag, dag_run)
dag_run.notify_dagrun_state_changed()
continue
elif dag.max_active_runs:
if active_runs >= dag.max_active_runs:
# todo: delete all candidate dag runs for this dag from list right now
self.log.info(
"dag cannot be started due to dag max_active_runs constraint; "
"active_runs=%s max_active_runs=%s dag_id=%s run_id=%s",
active_runs,
                        dag.max_active_runs,
                        dag_id,
                        run_id,
)
continue
if span.is_recording():
span.add_event(
name="dag_run",
attributes={
"run_id": dag_run.run_id,
"dag_id": dag_run.dag_id,
"conf": str(dag_run.conf),
},
)
            active_runs_of_dags[(dag_id, backfill_id)] += 1
_update_state(dag, dag_run)
dag_run.notify_dagrun_state_changed()

@retry_db_transaction
def _schedule_all_dag_runs(
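Note on the scheduler change above: active-run counts are now bucketed per (dag_id, backfill_id) pair instead of per dag_id alone, so backfill runs are throttled by backfill.max_active_runs while ordinary runs (backfill_id is NULL) still honor dag.max_active_runs. A stand-alone sketch of the grouped-count-into-Counter pattern, against a toy table rather than Airflow's real schema:

```python
# Sketch only: same query shape as _start_queued_dagruns, on an illustrative table.
from collections import Counter

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
runs = sa.Table(
    "runs", meta,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("dag_id", sa.String),
    sa.Column("backfill_id", sa.Integer, nullable=True),  # NULL = not a backfill run
    sa.Column("state", sa.String),
)
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(runs.insert(), [
        {"dag_id": "d1", "backfill_id": None, "state": "running"},
        {"dag_id": "d1", "backfill_id": 7, "state": "running"},
        {"dag_id": "d1", "backfill_id": 7, "state": "running"},
    ])
    query = (
        sa.select(runs.c.dag_id, runs.c.backfill_id, sa.func.count(runs.c.id))
        .where(runs.c.state == "running")
        .group_by(runs.c.dag_id, runs.c.backfill_id)
    )
    # Same shape as the scheduler code: a Counter keyed by (dag_id, backfill_id),
    # so a missing key reads as 0 active runs.
    active = Counter({(d, b): n for d, b, n in conn.execute(query)})

print(active[("d1", 7)])     # 2 -> compared against backfill.max_active_runs
print(active[("d1", None)])  # 1 -> compared against dag.max_active_runs
```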
54 changes: 54 additions & 0 deletions (new Alembic migration file under airflow/migrations/versions/)
@@ -0,0 +1,54 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add backfill to dag run model.

Revision ID: c3389cd7793f
Revises: 0d9e73a75ee4
Create Date: 2024-09-21 07:52:29.869725

"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "c3389cd7793f"
down_revision = "0d9e73a75ee4"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"


def upgrade():
"""Apply Add backfill to dag run model."""
with op.batch_alter_table("dag_run", schema=None) as batch_op:
batch_op.add_column(sa.Column("backfill_id", sa.Integer(), nullable=True))
batch_op.create_foreign_key(
batch_op.f("dag_run_backfill_id_fkey"), "backfill", ["backfill_id"], ["id"]
)


def downgrade():
"""Unapply Add backfill to dag run model."""
with op.batch_alter_table("dag_run", schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f("dag_run_backfill_id_fkey"), type_="foreignkey")
batch_op.drop_column("backfill_id")
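
For orientation, a minimal sketch of how the column added by this migration maps onto the ORM layer. This is an assumption for illustration, not part of the diff; the real DagRun model lives in airflow/models/dagrun.py and is more elaborate:

```python
# Sketch of the ORM-side counterpart of migration c3389cd7793f.
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class DagRun(Base):
    __tablename__ = "dag_run"
    id = sa.Column(sa.Integer, primary_key=True)
    # Nullable FK added by the migration: NULL means "not created by a backfill",
    # so existing rows remain valid after upgrade.
    backfill_id = sa.Column(sa.Integer, sa.ForeignKey("backfill.id"), nullable=True)
```

Keeping the column nullable is what lets the scheduler treat backfill_id=None as the bucket for ordinary, non-backfill runs in the Counter above.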
9 changes: 8 additions & 1 deletion airflow/models/backfill.py
@@ -27,7 +27,7 @@
from typing import TYPE_CHECKING

from sqlalchemy import Boolean, Column, ForeignKeyConstraint, Integer, UniqueConstraint, func, select, update
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, validates
from sqlalchemy_jsonfield import JSONField

from airflow.api_connexion.exceptions import Conflict, NotFound
@@ -113,6 +113,12 @@ class BackfillDagRun(Base):
),
)

@validates("sort_ordinal")
def validate_sort_ordinal(self, key, val):
if val < 1:
raise ValueError("sort_ordinal must be >= 1")
return val


def _create_backfill(
*,
@@ -175,6 +181,7 @@ def _create_backfill(
run_type=DagRunType.BACKFILL_JOB,
creating_job_id=None,
session=session,
backfill_id=br.id,
)
except Exception:
dag.log.exception(
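A quick illustration of what the new @validates hook buys: an out-of-range sort_ordinal now fails at the ORM layer, before anything reaches the database. A minimal sketch using a stand-alone model with the same validator, since BackfillDagRun itself needs Airflow's metadata to instantiate:

```python
# Sketch: exercising a SQLAlchemy @validates hook like the one added above.
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, validates

Base = declarative_base()

class Demo(Base):
    __tablename__ = "demo"
    id = sa.Column(sa.Integer, primary_key=True)
    sort_ordinal = sa.Column(sa.Integer)

    @validates("sort_ordinal")
    def validate_sort_ordinal(self, key, val):
        # Mirrors BackfillDagRun.validate_sort_ordinal: ordinals are 1-based.
        if val < 1:
            raise ValueError("sort_ordinal must be >= 1")
        return val

Demo(sort_ordinal=1)        # fine; validator runs on attribute assignment
try:
    Demo(sort_ordinal=0)    # raises ValueError from the validator
except ValueError as e:
    print(e)                # -> sort_ordinal must be >= 1
```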
9 changes: 8 additions & 1 deletion airflow/models/dag.py
@@ -305,6 +305,7 @@ def _create_orm_dagrun(
dag_hash,
creating_job_id,
data_interval,
backfill_id,
session,
triggered_by,
):
@@ -321,6 +322,7 @@
creating_job_id=creating_job_id,
data_interval=data_interval,
triggered_by=triggered_by,
backfill_id=backfill_id,
)
# Load defaults into the following two fields to ensure result can be serialized detached
run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
@@ -2545,6 +2547,7 @@ def create_dagrun(
dag_hash: str | None = None,
creating_job_id: int | None = None,
data_interval: tuple[datetime, datetime] | None = None,
backfill_id: int | None = None,
):
"""
Create a dag run from this dag including the tasks associated with this dag.
@@ -2563,6 +2566,7 @@
:param session: database session
:param dag_hash: Hash of Serialized DAG
:param data_interval: Data interval of the DagRun
:param backfill_id: id of the backfill run if one exists
"""
logical_date = timezone.coerce_datetime(execution_date)

@@ -2612,6 +2616,8 @@
f"the configured pattern: '{regex}' or the built-in pattern: '{RUN_ID_REGEX}'"
)

# todo: AIP-78 add verification that if run type is backfill then we have a backfill id

# create a copy of params before validating
copied_params = copy.deepcopy(self.params)
copied_params.update(conf or {})
@@ -2629,6 +2635,7 @@
run_type=run_type,
dag_hash=dag_hash,
creating_job_id=creating_job_id,
backfill_id=backfill_id,
data_interval=data_interval,
session=session,
triggered_by=triggered_by,
@@ -2947,7 +2954,7 @@ class DagModel(Base):
)

max_active_tasks = Column(Integer, nullable=False)
max_active_runs = Column(Integer, nullable=True)
max_active_runs = Column(Integer, nullable=True) # todo: should not be nullable if we have a default
max_consecutive_failed_dag_runs = Column(Integer, nullable=False)

has_task_concurrency_limits = Column(Boolean, nullable=False)
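End to end, backfill creation passes the persisted Backfill row's id through DAG.create_dagrun, which threads it into _create_orm_dagrun and onto the dag_run row. A hedged usage sketch mirroring the _create_backfill call in this diff; `dag`, `session`, `logical_date`, and `br` are assumed to exist, and the exact triggered_by enum member is an assumption, not taken from this diff:

```python
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

dag_run = dag.create_dagrun(
    execution_date=logical_date,
    data_interval=dag.timetable.infer_manual_data_interval(run_after=logical_date),
    state=DagRunState.QUEUED,
    run_type=DagRunType.BACKFILL_JOB,
    creating_job_id=None,
    backfill_id=br.id,  # new parameter: links the run to its backfill
    triggered_by=DagRunTriggeredByType.BACKFILL,  # assumed member name
    session=session,
)
```

Once the run row carries backfill_id, the scheduler change at the top of this diff can bucket it separately and enforce backfill.max_active_runs independently of dag.max_active_runs.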