96 changes: 77 additions & 19 deletions airflow/datasets/__init__.py
@@ -63,19 +63,35 @@ def _get_normalized_scheme(uri: str) -> str:
return parsed.scheme.lower()


class _IdentifierValidator:
@staticmethod
def validate(kind: str, value: str | None, *, optional: bool = False) -> None:
if optional and value is None:
return
if not value:
raise ValueError(f"{kind} cannot be empty")
if len(value) > 3000:
raise ValueError(f"{kind} must be at most 3000 characters")
if value.isspace():
raise ValueError(f"{kind} cannot be just whitespace")
if not value.isascii():
raise ValueError(f"{kind} must only consist of ASCII characters")

def __call__(self, inst: Dataset | DatasetAlias, attribute: attr.Attribute, value: str | None) -> None:
self.validate(
f"{type(inst).__name__} {attribute.name}",
value,
optional=attribute.default is None,
)
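
A minimal sketch (not part of the diff) of how the validator above behaves; the identifier values are made up:

# Illustration only, based on the _IdentifierValidator defined above.
_IdentifierValidator.validate("Dataset name", "raw_orders")            # passes
_IdentifierValidator.validate("Dataset name", None, optional=True)     # None is fine when optional
try:
    _IdentifierValidator.validate("Dataset name", "   ")
except ValueError as err:
    print(err)  # Dataset name cannot be just whitespace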


def _sanitize_uri(uri: str) -> str:
"""
Sanitize a dataset URI.

This checks for URI validity, and normalizes the URI if needed. A fully
normalized URI is returned.
"""
if not uri:
raise ValueError("Dataset URI cannot be empty")
if uri.isspace():
raise ValueError("Dataset URI cannot be just whitespace")
if not uri.isascii():
raise ValueError("Dataset URI must only consist of ASCII characters")
parsed = urllib.parse.urlsplit(uri)
if not parsed.scheme and not parsed.netloc: # Does not look like a URI.
return uri
@@ -133,10 +149,10 @@ def extract_event_key(value: str | Dataset | DatasetAlias) -> str:
"""
if isinstance(value, DatasetAlias):
return value.name

if isinstance(value, Dataset):
return value.uri
return _sanitize_uri(str(value))
_IdentifierValidator.validate("Dataset event key", uri := str(value))
return _sanitize_uri(uri)
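
A hedged usage sketch for the updated extract_event_key; the identifiers below are illustrative:

# Illustration only -- mirrors the three branches of extract_event_key above.
from airflow.datasets import Dataset, DatasetAlias, extract_event_key

extract_event_key(DatasetAlias("nightly_exports"))     # returns the alias name: "nightly_exports"
extract_event_key(Dataset(uri="s3://bucket/key"))      # returns the dataset URI: "s3://bucket/key"
extract_event_key("s3://bucket/key")                   # plain strings are validated, then sanitized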


@internal_api_call
@@ -210,16 +226,13 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
class DatasetAlias(BaseDataset):
"""A represeation of dataset alias which is used to create dataset during the runtime."""

name: str
name: str = attr.field(validator=_IdentifierValidator())

def __eq__(self, other: Any) -> bool:
if isinstance(other, DatasetAlias):
return self.name == other.name
return NotImplemented

def __hash__(self) -> int:
return hash(self.name)

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
"""
Iterate a dataset alias as dag dependency.
@@ -241,28 +254,73 @@ class DatasetAliasEvent(TypedDict):
dest_dataset_uri: str


@attr.define()
class NoComparison(ArithmeticError):
"""Exception for when two datasets cannot be compared directly."""

a: Dataset
b: Dataset

def __str__(self) -> str:
return f"Can not compare {self.a} and {self.b}"


@attr.define()
class Dataset(os.PathLike, BaseDataset):
"""A representation of data dependencies between workflows."""

name: str = attr.field(default=None, validator=_IdentifierValidator())
uri: str = attr.field(
default=None,
kw_only=True,
converter=_sanitize_uri,
validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
validator=_IdentifierValidator(),
)
extra: dict[str, Any] | None = None
extra: dict[str, Any] | None = attr.field(kw_only=True, default=None)

__version__: ClassVar[int] = 1

def __attrs_post_init__(self) -> None:
if self.name is None and self.uri is None:
raise TypeError("Dataset requires either name or URI")

def __fspath__(self) -> str:
return self.uri

def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
"""
Check equality of two datasets.

Since either *name* or *uri* is required, and we ensure integrity when
DAG files are parsed, we only need to consider the following combos:

* Both datasets have name and uri defined: Both fields must match.
* One dataset has only one field (name or uri) defined: The field
defined by both must match.
* Both datasets define only the same one field: That field must match.
* Each dataset defines only the field the other lacks (e.g. *self*
defines only *name*, but *other* only *uri*): The two cannot be
reliably compared, and (a subclass of) *ArithmeticError* is raised.

In the last case, we could still check dataset equality by querying the
database, but we do not do that here since it has too great a
performance impact. The call site should handle the possibility instead.

Note that *Dataset* objects created from the meta-database (e.g. those
in the task execution context) have both concrete name and URI values
filled in by the DAG parser. Non-comparability only happens if the user
accesses dataset objects that aren't created from the database, say
globally in a DAG file, which is discouraged anyway.
"""
if not isinstance(other, self.__class__):
return NotImplemented
if self.name is not None and other.name is not None:
if self.uri is None or other.uri is None:
return self.name == other.name
return self.name == other.name and self.uri == other.uri
if self.uri is not None and other.uri is not None:
return self.uri == other.uri
return NotImplemented

def __hash__(self) -> int:
return hash(self.uri)
raise NoComparison(self, other)

@property
def normalized_uri(self) -> str | None:
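A short illustrative sketch of the comparison rules documented in the new __eq__ above. Names and URIs are made up, and it assumes a Dataset may be built with only one of the two identifiers, as the new __attrs_post_init__ implies:

from airflow.datasets import Dataset, NoComparison

a = Dataset(name="orders", uri="postgres://warehouse/orders")
b = Dataset(name="orders")                       # name only
c = Dataset(uri="postgres://warehouse/orders")   # URI only

assert a == b    # only one side defines a URI, so names decide
assert a == c    # only one side defines a name, so URIs decide
try:
    b == c       # b defines only a name, c only a URI
except NoComparison:
    pass         # not comparable without a database lookup
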
10 changes: 7 additions & 3 deletions airflow/datasets/manager.py
@@ -25,7 +25,6 @@

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.listeners.listener import get_listener_manager
from airflow.models.dagbag import DagPriorityParsingRequest
from airflow.models.dataset import (
@@ -43,6 +42,7 @@
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.datasets import Dataset
from airflow.models.dag import DagModel
from airflow.models.taskinstance import TaskInstance

@@ -58,14 +58,18 @@ class DatasetManager(LoggingMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
def create_datasets(self, dataset_models: Iterable[DatasetModel], session: Session) -> None:
"""Create new datasets."""
for dataset_model in dataset_models:
if not dataset_model.name:
dataset_model.name = dataset_model.uri
elif not dataset_model.uri:
dataset_model.uri = dataset_model.name
session.add(dataset_model)
session.flush()

for dataset_model in dataset_models:
self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
self.notify_dataset_created(dataset=dataset_model.as_public())

@classmethod
@internal_api_call
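A rough sketch of the mirrored backfill that create_datasets now performs before notifying listeners; the model construction and session are simplified assumptions:

from airflow.datasets.manager import DatasetManager
from airflow.models.dataset import DatasetModel

models = [
    DatasetModel(uri="s3://bucket/raw"),   # name will be backfilled from uri before insert
    DatasetModel(name="raw_orders"),       # uri will be backfilled from name before insert
]
DatasetManager().create_datasets(models, session=session)  # 'session' is an assumed SQLAlchemy session
# Each model is then announced via notify_dataset_created(dataset_model.as_public()).
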
122 changes: 122 additions & 0 deletions airflow/datasets/references.py
@@ -0,0 +1,122 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Dataset reference objects.

These are intermediate representations of DAG- and task-level references to
Dataset and DatasetAlias. These are meant only for Airflow internals so the DAG
processor can collect information from DAGs without the "full picture", which is
only available when it updates the database.
"""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
from airflow.datasets import Dataset, DatasetAlias
from airflow.models.dataset import (
DagScheduleDatasetAliasReference,
DagScheduleDatasetReference,
DatasetAliasModel,
DatasetModel,
TaskOutletDatasetReference,
)

DatasetReference = Union["DatasetNameReference", "DatasetURIReference"]

DatasetOrAliasReference = Union[DatasetReference, "DatasetAliasReference"]


def create_dag_dataset_reference(source: Dataset) -> DatasetReference:
"""Create reference to a dataset."""
if source.name:
return DatasetNameReference(source.name)
return DatasetURIReference(source.uri)


def create_dag_dataset_alias_reference(source: DatasetAlias) -> DatasetAliasReference:
"""Create reference to a dataset or dataset alias."""
return DatasetAliasReference(source.name)


@dataclasses.dataclass
class DatasetNameReference:
"""Reference to a dataset by name."""

name: str

def __hash__(self) -> int:
return hash((self.__class__.__name__, self.name))


@dataclasses.dataclass
class DatasetURIReference:
"""Reference to a dataset by URI."""

uri: str

def __hash__(self) -> int:
return hash((self.__class__.__name__, self.uri))


@dataclasses.dataclass
class DatasetAliasReference:
"""Reference to a dataset alias."""

name: str

def __hash__(self) -> int:
return hash((self.__class__.__name__, self.name))


def resolve_dag_schedule_reference(
ref: DatasetOrAliasReference,
*,
dag_id: str,
dataset_names: dict[str, DatasetModel],
dataset_uris: dict[str, DatasetModel],
alias_names: dict[str, DatasetAliasModel],
) -> DagScheduleDatasetReference | DagScheduleDatasetAliasReference:
"""Create database representation from DAG-level references."""
from airflow.models.dataset import DagScheduleDatasetAliasReference, DagScheduleDatasetReference

if isinstance(ref, DatasetNameReference):
return DagScheduleDatasetReference(dataset_id=dataset_names[ref.name].id, dag_id=dag_id)
elif isinstance(ref, DatasetURIReference):
return DagScheduleDatasetReference(dataset_id=dataset_uris[ref.uri].id, dag_id=dag_id)
return DagScheduleDatasetAliasReference(alias_id=alias_names[ref.name].id, dag_id=dag_id)


def resolve_task_outlet_reference(
ref: DatasetReference,
*,
dag_id: str,
task_id: str,
dataset_names: dict[str, DatasetModel],
dataset_uris: dict[str, DatasetModel],
) -> TaskOutletDatasetReference:
"""Create database representation from task-level references."""
from airflow.models.dataset import TaskOutletDatasetReference

if isinstance(ref, DatasetURIReference):
dataset = dataset_uris[ref.uri]
else:
dataset = dataset_names[ref.name]
return TaskOutletDatasetReference(dataset_id=dataset.id, dag_id=dag_id, task_id=task_id)
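
An end-to-end sketch (not part of the diff) of how the DAG processor might use these helpers; the ORM rows are stand-ins built with SimpleNamespace:

from types import SimpleNamespace

from airflow.datasets import Dataset, DatasetAlias
from airflow.datasets.references import (
    create_dag_dataset_alias_reference,
    create_dag_dataset_reference,
    resolve_dag_schedule_reference,
)

# Collected while a DAG file is parsed, before any database access.
ref = create_dag_dataset_reference(Dataset(name="orders"))              # DatasetNameReference("orders")
alias_ref = create_dag_dataset_alias_reference(DatasetAlias("latest"))  # DatasetAliasReference("latest")

# Later, with ORM rows loaded, resolve into database representations.
orders_row = SimpleNamespace(id=1)   # stand-in for a DatasetModel row
latest_row = SimpleNamespace(id=7)   # stand-in for a DatasetAliasModel row
resolve_dag_schedule_reference(
    ref, dag_id="example_dag",
    dataset_names={"orders": orders_row}, dataset_uris={}, alias_names={},
)
resolve_dag_schedule_reference(
    alias_ref, dag_id="example_dag",
    dataset_names={}, dataset_uris={}, alias_names={"latest": latest_row},
)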
@@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add name field to DatasetModel.

This also renames two indexes. Index names are scoped to the entire database.
Airflow generally includes the table's name to manually scope the index, but
``idx_uri_unique`` (on DatasetModel) and ``idx_name_unique`` (on
DatasetAliasModel) do not do this. They are renamed here so we can create a
unique index on DatasetModel as well.

Revision ID: 0d9e73a75ee4
Revises: 0bfc26bc256e
Create Date: 2024-08-13 09:45:32.213222
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm import Session

# revision identifiers, used by Alembic.
revision = "0d9e73a75ee4"
down_revision = "0bfc26bc256e"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"

_NAME_COLUMN_TYPE = sa.String(length=3000).with_variant(
sa.String(length=3000, collation="latin1_general_cs"),
dialect_name="mysql",
)


def upgrade():
# Fix index name on DatasetAlias.
with op.batch_alter_table("dataset_alias", schema=None) as batch_op:
batch_op.drop_index("idx_name_unique")
batch_op.create_index("idx_dataset_alias_name_unique", ["name"], unique=True)
# Fix index name (of 'uri') on Dataset.
# Add 'name' column. Set it to nullable for now.
with op.batch_alter_table("dataset", schema=None) as batch_op:
batch_op.drop_index("idx_uri_unique")
batch_op.create_index("idx_dataset_uri_unique", ["uri"], unique=True)
batch_op.add_column(sa.Column("name", _NAME_COLUMN_TYPE))
# Fill name from uri column.
Session(bind=op.get_bind()).execute(sa.text("update dataset set name=uri"))
# Set the name column non-nullable.
# Now with values in there, we can create the unique constraint and index.
with op.batch_alter_table("dataset", schema=None) as batch_op:
batch_op.alter_column("name", existing_type=_NAME_COLUMN_TYPE, nullable=False)
batch_op.create_index("idx_dataset_name_unique", ["name"], unique=True)


def downgrade():
with op.batch_alter_table("dataset", schema=None) as batch_op:
batch_op.drop_index("idx_dataset_name_unique")
batch_op.drop_column("name")
batch_op.drop_index("idx_dataset_uri_unique")
batch_op.create_index("idx_uri_unique", ["uri"], unique=True)
with op.batch_alter_table("dataset_alias", schema=None) as batch_op:
batch_op.drop_index("idx_dataset_alias_name_unique")
batch_op.create_index("idx_name_unique", ["name"], unique=True)