Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/asset_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from datetime import datetime

from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse


class DagRunAssetReference(StrictBaseModel):
"""DagRun serializer for asset responses."""

run_id: str
dag_id: str
logical_date: datetime | None
start_date: datetime
end_date: datetime | None
state: str
data_interval_start: datetime | None
data_interval_end: datetime | None


class AssetEventResponse(BaseModel):
"""Asset event schema with fields that are needed for Runtime."""

id: int
timestamp: datetime
extra: dict | None = None

asset: AssetResponse
created_dagruns: list[DagRunAssetReference]

source_task_id: str | None = None
source_dag_id: str | None = None
source_run_id: str | None = None
source_map_index: int = -1


class AssetEventsResponse(BaseModel):
"""Collection of AssetEventResponse."""

asset_events: list[AssetEventResponse]
2 changes: 2 additions & 0 deletions airflow/api_fastapi/execution_api/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.routes import (
asset_events,
assets,
connections,
health,
Expand All @@ -28,6 +29,7 @@

execution_api_router = AirflowRouter()
execution_api_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
execution_api_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
execution_api_router.include_router(health.router, tags=["Health"])
execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
Expand Down
111 changes: 111 additions & 0 deletions airflow/api_fastapi/execution_api/routes/asset_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Annotated

from fastapi import HTTPException, Query, status
from sqlalchemy import and_, select

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
from airflow.api_fastapi.execution_api.datamodels.asset_event import (
AssetEventResponse,
AssetEventsResponse,
)
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel

# TODO: Add dependency on JWT token
router = AirflowRouter(
responses={
status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)


def _get_asset_events_through_sql_clauses(
*, join_clause, where_clause, session: SessionDep
) -> AssetEventsResponse:
asset_events = session.scalars(
select(AssetEvent).join(join_clause).where(where_clause).order_by(AssetEvent.timestamp)
)
return AssetEventsResponse.model_validate(
{
"asset_events": [
AssetEventResponse(
id=event.id,
timestamp=event.timestamp,
extra=event.extra,
asset=AssetResponse(
name=event.asset.name,
uri=event.asset.uri,
group=event.asset.group,
extra=event.asset.extra,
),
created_dagruns=event.created_dagruns,
source_task_id=event.source_task_id,
source_dag_id=event.source_dag_id,
source_run_id=event.source_run_id,
source_map_index=event.source_map_index,
)
for event in asset_events
]
}
)


@router.get("/by-asset")
def get_asset_event_by_asset_name_uri(
name: Annotated[str | None, Query(description="The name of the Asset")],
uri: Annotated[str | None, Query(description="The URI of the Asset")],
session: SessionDep,
) -> AssetEventsResponse:
if name and uri:
where_clause = and_(AssetModel.name == name, AssetModel.uri == uri)
elif uri:
where_clause = and_(AssetModel.uri == uri, AssetModel.active.has())
elif name:
where_clause = and_(AssetModel.name == name, AssetModel.active.has())
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Missing parameter",
"message": "name and uri cannot both be None",
},
)

return _get_asset_events_through_sql_clauses(
join_clause=AssetEvent.asset,
where_clause=where_clause,
session=session,
)


@router.get("/by-asset-alias")
def get_asset_event_by_asset_alias(
name: Annotated[str, Query(description="The name of the Asset Alias")],
session: SessionDep,
) -> AssetEventsResponse:
return _get_asset_events_through_sql_clauses(
join_clause=AssetEvent.source_aliases,
where_clause=(AssetAliasModel.name == name),
session=session,
)
4 changes: 2 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.param import process_params
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.sdk.execution_time.context import InletEventsAccessors
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
Expand All @@ -119,7 +120,6 @@
from airflow.utils.context import (
ConnectionAccessor,
Context,
InletEventsAccessors,
OutletEventAccessors,
VariableAccessor,
context_get_outlet_events,
Expand Down Expand Up @@ -973,7 +973,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]:
context.update(
{
"outlet_events": OutletEventAccessors(),
"inlet_events": InletEventsAccessors(task.inlets, session=session),
"inlet_events": InletEventsAccessors(task.inlets),
"macros": macros,
"params": validated_params,
"prev_data_interval_start_success": get_prev_data_interval_start_success(),
Expand Down
125 changes: 2 additions & 123 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,50 +21,29 @@

from collections.abc import (
Container,
Iterator,
Mapping,
)
from typing import (
TYPE_CHECKING,
Any,
Union,
cast,
)

import attrs
from sqlalchemy import and_, select
from sqlalchemy import select

from airflow.models.asset import (
AssetAliasModel,
AssetEvent,
AssetModel,
fetch_active_assets_by_name,
fetch_active_assets_by_uri,
)
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasUniqueKey,
AssetNameRef,
AssetRef,
AssetUniqueKey,
AssetUriRef,
)
from airflow.sdk.definitions.context import Context
from airflow.sdk.execution_time.context import (
ConnectionAccessor as ConnectionAccessorSDK,
OutletEventAccessors as OutletEventAccessorsSDK,
VariableAccessor as VariableAccessorSDK,
)
from airflow.utils.db import LazySelectSequence
from airflow.utils.session import create_session
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from sqlalchemy.engine import Row
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import Select, TextClause

from airflow.sdk.definitions.asset import Asset
from airflow.sdk.types import OutletEventAccessorsProtocol

# NOTE: Please keep this in sync with the following:
Expand Down Expand Up @@ -170,106 +149,6 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset
return asset.to_public()


class LazyAssetEventSelectSequence(LazySelectSequence[AssetEvent]):
"""
List-like interface to lazily access AssetEvent rows.

:meta private:
"""

@staticmethod
def _rebuild_select(stmt: TextClause) -> Select:
return select(AssetEvent).from_statement(stmt)

@staticmethod
def _process_row(row: Row) -> AssetEvent:
return row[0]


@attrs.define(init=False)
class InletEventsAccessors(Mapping[Union[int, Asset, AssetAlias, AssetRef], LazyAssetEventSelectSequence]):
"""
Lazy mapping for inlet asset events accessors.

:meta private:
"""

_inlets: list[Any]
_assets: dict[AssetUniqueKey, Asset]
_asset_aliases: dict[AssetAliasUniqueKey, AssetAlias]
_session: Session

def __init__(self, inlets: list, *, session: Session) -> None:
self._inlets = inlets
self._session = session
self._assets = {}
self._asset_aliases = {}

_asset_ref_names: list[str] = []
_asset_ref_uris: list[str] = []
for inlet in inlets:
if isinstance(inlet, Asset):
self._assets[AssetUniqueKey.from_asset(inlet)] = inlet
elif isinstance(inlet, AssetAlias):
self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(inlet)] = inlet
elif isinstance(inlet, AssetNameRef):
_asset_ref_names.append(inlet.name)
elif isinstance(inlet, AssetUriRef):
_asset_ref_uris.append(inlet.uri)

if _asset_ref_names:
for _, asset in fetch_active_assets_by_name(_asset_ref_names, self._session).items():
self._assets[AssetUniqueKey.from_asset(asset)] = asset
if _asset_ref_uris:
for _, asset in fetch_active_assets_by_uri(_asset_ref_uris, self._session).items():
self._assets[AssetUniqueKey.from_asset(asset)] = asset

def __iter__(self) -> Iterator[Asset | AssetAlias]:
return iter(self._inlets)

def __len__(self) -> int:
return len(self._inlets)

def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> LazyAssetEventSelectSequence:
if isinstance(key, int): # Support index access; it's easier for trivial cases.
obj = self._inlets[key]
if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
raise IndexError(key)
else:
obj = key

if isinstance(obj, Asset):
asset = self._assets[AssetUniqueKey.from_asset(obj)]
join_clause = AssetEvent.asset
where_clause = and_(AssetModel.name == asset.name, AssetModel.uri == asset.uri)
elif isinstance(obj, AssetAlias):
asset_alias = self._asset_aliases[AssetAliasUniqueKey.from_asset_alias(obj)]
join_clause = AssetEvent.source_aliases
where_clause = AssetAliasModel.name == asset_alias.name
elif isinstance(obj, AssetNameRef):
try:
asset = next(a for k, a in self._assets.items() if k.name == obj.name)
except StopIteration:
raise KeyError(obj) from None
join_clause = AssetEvent.asset
where_clause = and_(AssetModel.name == asset.name, AssetModel.active.has())
elif isinstance(obj, AssetUriRef):
try:
asset = next(a for k, a in self._assets.items() if k.uri == obj.uri)
except StopIteration:
raise KeyError(obj) from None
join_clause = AssetEvent.asset
where_clause = and_(AssetModel.uri == asset.uri, AssetModel.active.has())
else:
raise ValueError(key)

return LazyAssetEventSelectSequence.from_select(
select(AssetEvent).join(join_clause).where(where_clause),
order_by=[AssetEvent.timestamp],
session=self._session,
)


def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
"""
Merge parameters into an existing context.
Expand Down
Loading