Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 59 additions & 110 deletions airflow/io/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,20 @@
from __future__ import annotations

import contextlib
import functools
import os
import shutil
import typing
from pathlib import PurePath
from typing import Any, Mapping
from urllib.parse import urlsplit

from fsspec.core import split_protocol
from fsspec.utils import stringify_path
from upath.implementations.cloud import CloudPath, _CloudAccessor
from upath.implementations.cloud import CloudPath
from upath.registry import get_upath_class

from airflow.io.store import attach
from airflow.io.utils.stat import stat_result

if typing.TYPE_CHECKING:
from urllib.parse import SplitResult

from fsspec import AbstractFileSystem


Expand All @@ -43,124 +39,68 @@
default = "file"


class _AirflowCloudAccessor(_CloudAccessor):
__slots__ = ("_store",)

def __init__(
self,
parsed_url: SplitResult | None,
conn_id: str | None = None,
**kwargs: typing.Any,
) -> None:
# warning: we are not calling super().__init__ here
# as it will try to create a new fs from a different
# set if registered filesystems
if parsed_url and parsed_url.scheme:
self._store = attach(parsed_url.scheme, conn_id)
else:
self._store = attach("file", conn_id)

@property
def _fs(self) -> AbstractFileSystem:
return self._store.fs

def __eq__(self, other):
return isinstance(other, _AirflowCloudAccessor) and self._store == other._store


class ObjectStoragePath(CloudPath):
"""A path-like object for object storage."""

_accessor: _AirflowCloudAccessor

__version__: typing.ClassVar[int] = 1

_default_accessor = _AirflowCloudAccessor
_protocol_dispatch = False

sep: typing.ClassVar[str] = "/"
root_marker: typing.ClassVar[str] = "/"

_bucket: str
_key: str
_protocol: str
_hash: int | None

__slots__ = (
"_bucket",
"_key",
"_conn_id",
"_protocol",
"_hash",
)

def __new__(
cls: type[PT],
*args: str | os.PathLike,
scheme: str | None = None,
conn_id: str | None = None,
**kwargs: typing.Any,
) -> PT:
args_list = list(args)

if args_list:
other = args_list.pop(0) or "."
else:
other = "."

if isinstance(other, PurePath):
_cls: typing.Any = type(other)
drv, root, parts = _cls._parse_args(args_list)
drv, root, parts = _cls._flavour.join_parsed_parts(
other._drv, # type: ignore[attr-defined]
other._root, # type: ignore[attr-defined]
other._parts, # type: ignore[attr-defined]
drv,
root,
parts, # type: ignore
)

_kwargs = getattr(other, "_kwargs", {})
_url = getattr(other, "_url", None)
other_kwargs = _kwargs.copy()
if _url and _url.scheme:
other_kwargs["url"] = _url
new_kwargs = _kwargs.copy()
new_kwargs.update(kwargs)

return _cls(_cls._format_parsed_parts(drv, root, parts, **other_kwargs), **new_kwargs)

url = stringify_path(other)
parsed_url: SplitResult = urlsplit(url)

if scheme: # allow override of protocol
parsed_url = parsed_url._replace(scheme=scheme)

if not parsed_url.path: # ensure path has root
parsed_url = parsed_url._replace(path="/")

if not parsed_url.scheme and not split_protocol(url)[0]:
args_list.insert(0, url)
else:
args_list.insert(0, parsed_url.path)
__slots__ = ("_hash_cached",)

@classmethod
def _transform_init_args(
cls,
args: tuple[str | os.PathLike, ...],
protocol: str,
storage_options: dict[str, Any],
) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]:
"""Extract conn_id from the URL and set it as a storage option."""
if args:
arg0 = args[0]
parsed_url = urlsplit(stringify_path(arg0))
userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
if have_info:
storage_options.setdefault("conn_id", userinfo or None)
parsed_url = parsed_url._replace(netloc=hostinfo)
args = (parsed_url.geturl(),) + args[1:]
protocol = protocol or parsed_url.scheme
return args, protocol, storage_options

# This matches the parsing logic in urllib.parse; see:
# https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75a/Lib/urllib/parse.py#L194-L203
userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
if have_info:
conn_id = conn_id or userinfo or None
parsed_url = parsed_url._replace(netloc=hostinfo)
@classmethod
def _parse_storage_options(
cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
) -> dict[str, Any]:
fs = attach(protocol or "file", conn_id=storage_options.get("conn_id")).fs
pth_storage_options = type(fs)._get_kwargs_from_urls(urlpath)
return {**pth_storage_options, **storage_options}

return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) # type: ignore
@classmethod
def _fs_factory(
cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
) -> AbstractFileSystem:
return attach(protocol or "file", storage_options.get("conn_id")).fs

@functools.lru_cache
def __hash__(self) -> int:
return hash(str(self))
self._hash_cached: int
try:
return self._hash_cached
except AttributeError:
self._hash_cached = hash(str(self))
return self._hash_cached

def __eq__(self, other: typing.Any) -> bool:
return self.samestore(other) and str(self) == str(other)

def samestore(self, other: typing.Any) -> bool:
return isinstance(other, ObjectStoragePath) and self._accessor == other._accessor
return (
isinstance(other, ObjectStoragePath)
and self.protocol == other.protocol
and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
)

@property
def container(self) -> str:
Expand All @@ -186,12 +126,17 @@ def key(self) -> str:
def namespace(self) -> str:
return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol

def open(self, mode="r", **kwargs):
"""Open the file pointed to by this path."""
kwargs.setdefault("block_size", kwargs.pop("buffering", None))
return self.fs.open(self.path, mode=mode, **kwargs)

def stat(self) -> stat_result: # type: ignore[override]
"""Call ``stat`` and return the result."""
return stat_result(
self._accessor.stat(self),
self.fs.stat(self.path),
protocol=self.protocol,
conn_id=self._accessor._store.conn_id,
conn_id=self.storage_options.get("conn_id"),
)

def samefile(self, other_path: typing.Any) -> bool:
Expand Down Expand Up @@ -368,7 +313,11 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
if path == self.path:
continue

src_obj = ObjectStoragePath(path, conn_id=self._accessor._store.conn_id)
src_obj = ObjectStoragePath(
path,
protocol=self.protocol,
conn_id=self.storage_options.get("conn_id"),
)

# skip directories, empty directories will not be created
if src_obj.is_dir():
Expand Down Expand Up @@ -401,7 +350,7 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs)
self.unlink()

def serialize(self) -> dict[str, typing.Any]:
_kwargs = self._kwargs.copy()
_kwargs = {**self.storage_options}
conn_id = _kwargs.pop("conn_id", None)

return {
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/common/io/xcom/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def serialize_value(
if not p.parent.exists():
p.parent.mkdir(parents=True, exist_ok=True)

with p.open("wb", compression=compression) as f:
with p.open(mode="wb", compression=compression) as f:
f.write(s_val)

return BaseXCom.serialize_value(str(p))
Expand All @@ -152,7 +152,7 @@ def deserialize_value(

try:
p = ObjectStoragePath(path) / XComObjectStoreBackend._get_key(data)
return json.load(p.open("rb", compression="infer"), cls=XComDecoder)
return json.load(p.open(mode="rb", compression="infer"), cls=XComDecoder)
except TypeError:
return data
except ValueError:
Expand Down
9 changes: 1 addition & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,7 @@ dependencies = [
# We should also remove "licenses/LICENSE-unicodecsv.txt" file when we remove this dependency
"unicodecsv>=0.14.1",
# The Universal Pathlib provides Pathlib-like interface for FSSPEC
# In 0.1. *It was not very well defined for extension, so the way how we use it for 0.1.*
# so we used a lot of private methods and attributes that were not defined in the interface
# an they are broken with version 0.2.0 which is much better suited for extension and supports
# Python 3.12. We should limit it, unti we migrate to 0.2.0
# See: https://github.com/fsspec/universal_pathlib/pull/173#issuecomment-1937090528
# This is prerequistite to make Airflow compatible with Python 3.12
# Tracked in https://github.com/apache/airflow/pull/36755
"universal-pathlib>=0.1.4,<0.2.0",
"universal-pathlib>=0.2.1",
# Werkzug 3 breaks Flask-Login 0.6.2, also connexion needs to be updated to >= 3.0
# we should remove this limitation when FAB supports Flask 2.3 and we migrate connexion to 3+
"werkzeug>=2.0,<3",
Expand Down
Loading