airflow-core/src/airflow/jobs/scheduler_job_runner.py (31 additions, 6 deletions)
@@ -42,6 +42,7 @@
 from airflow._shared.observability.metrics.stats import Stats
 from airflow._shared.timezones import timezone
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun as DRDataModel, TIRunContext
+from airflow.assets.evaluation import AssetEvaluator
 from airflow.callbacks.callback_requests import (
     DagCallbackRequest,
     DagRunContext,
Expand All @@ -65,6 +66,7 @@
AssetWatcherModel,
DagScheduleAssetAliasReference,
DagScheduleAssetReference,
PartitionedAssetKeyLog,
TaskInletAssetReference,
TaskOutletAssetReference,
)
Expand All @@ -82,6 +84,7 @@
from airflow.models.team import Team
from airflow.models.trigger import TRIGGER_FAIL_REPR, Trigger, TriggerFailureReason
from airflow.observability.trace import DebugTrace, Trace, add_debug_span
from airflow.serialization.definitions.assets import SerializedAssetUniqueKey
from airflow.serialization.definitions.notset import NOTSET
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.timetables.simple import AssetTriggeredTimetable
@@ -1686,18 +1689,40 @@ def _do_scheduling(self, session: Session) -> int:
 
     def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]:
         partition_dag_ids: set[str] = set()
-        apdrs: Iterable[AssetPartitionDagRun] = session.scalars(
+
+        evaluator = AssetEvaluator(session)
+        for apdr in session.scalars(
             select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None))
-        )
-        for apdr in apdrs:
+        ):
             if TYPE_CHECKING:
                 assert apdr.target_dag_id
-            partition_dag_ids.add(apdr.target_dag_id)
-            dag = _get_current_dag(dag_id=apdr.target_dag_id, session=session)
-            if not dag:
+
+            if not (dag := _get_current_dag(dag_id=apdr.target_dag_id, session=session)):
                 self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id)
                 continue
+
+            asset_models = session.scalars(
+                select(AssetModel).where(
+                    exists(
+                        select(1).where(
+                            PartitionedAssetKeyLog.asset_id == AssetModel.id,
+                            PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                        )
+                    )
+                )
+            )
+            statuses: dict[SerializedAssetUniqueKey, bool] = {
+                SerializedAssetUniqueKey.from_asset(a): True for a in asset_models
+            }
+            # TODO: AIP-76: this basically works when we only require one partition
+            # from each asset to be there, but we ultimately need rollup ability.
+            # That is, for many -> one partition mappings we need to ensure that all
+            # the required keys are present. One way to do this would be to figure
+            # out what the count should be.
+            if not evaluator.run(dag.timetable.asset_condition, statuses=statuses):
+                continue
+
+            partition_dag_ids.add(apdr.target_dag_id)
             run_after = timezone.utcnow()
             dag_run = dag.create_dagrun(
                 run_id=DagRun.generate_run_id(
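A note on the new readiness check: the loop marks an asset True in statuses as soon as any partition key has been logged for it, then asks AssetEvaluator whether the dag's asset condition is satisfied. As the TODO says, that is only correct when a single partition per asset suffices; for many -> one partition mappings, every required key must be present before the run fires. Below is a rough sketch of the rollup check the comment hints at, written against hypothetical stand-in names (asset_partitions_ready, build_statuses, plain sets of key strings), not the scheduler's actual API.

# Illustrative only: stand-in types and names, not Airflow's real models or API.
def asset_partitions_ready(required_keys: set[str], logged_keys: set[str]) -> bool:
    # A many -> one mapping is ready only when every required key was logged.
    # A subset test is safer than comparing counts, since duplicate log rows
    # for the same key would inflate a bare count.
    return required_keys <= logged_keys


def build_statuses(
    required: dict[str, set[str]],
    logged: dict[str, set[str]],
) -> dict[str, bool]:
    # Per-asset readiness map, analogous to the statuses dict fed to the evaluator.
    return {
        asset: asset_partitions_ready(keys, logged.get(asset, set()))
        for asset, keys in required.items()
    }


if __name__ == "__main__":
    required = {"sales": {"2024-01", "2024-02", "2024-03"}}
    logged = {"sales": {"2024-01", "2024-03"}}
    print(build_statuses(required, logged))  # {'sales': False}: 2024-02 still missing

Comparing key sets rather than raw counts, as above, is one way to realize the "figure out what the count should be" idea while also sidestepping double-counting when the same partition key is logged twice.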