[fs] support hfs.ls on a bucket #14176

Merged: 14 commits, Feb 6, 2024
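Judging from the diffs below, this PR makes a bare bucket or container URL a legal target for listing, while object-level operations (open, statfile, remove, and friends) now fail fast with a new IsABucketError instead of misbehaving on an empty object path. A minimal usage sketch of the intended behavior, not taken from the PR itself; the bucket name is illustrative and hfs is hailtop.fs:

    import hailtop.fs as hfs

    # Listing the top level of a bucket, which this PR sets out to support:
    for entry in hfs.ls('gs://my-bucket/'):
        print(entry)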
hail/python/hail/backend/local_backend.py (7 changes: 5 additions & 2 deletions)
@@ -1,4 +1,5 @@
from typing import Optional, Union, Tuple, List
from contextlib import ExitStack
import os
import sys

@@ -31,6 +32,7 @@ def __init__(
gcs_requester_pays_project: Optional[str] = None,
gcs_requester_pays_buckets: Optional[str] = None,
):
self._exit_stack = ExitStack()
assert gcs_requester_pays_project is not None or gcs_requester_pays_buckets is None

spark_home = find_spark_home()
@@ -59,6 +61,7 @@ def __init__(
die_on_exit=True,
)
self._gateway = JavaGateway(gateway_parameters=GatewayParameters(port=port, auto_convert=True))
self._exit_stack.callback(self._gateway.shutdown)

hail_package = getattr(self._gateway.jvm, 'is').hail

@@ -75,7 +78,7 @@ def __init__(

super(LocalBackend, self).__init__(self._gateway.jvm, jbackend, jhc)

self._fs = RouterFS()
self._fs = self._exit_stack.enter_context(RouterFS())
self._logger = None

self._initialize_flags({})
@@ -108,7 +111,7 @@ def register_ir_function(

def stop(self):
super().stop()
self._gateway.shutdown()
self._exit_stack.close()
uninstall_exception_handler()

@property
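The net effect of the local_backend.py changes: each resource the backend acquires (the Py4J gateway, the RouterFS) registers its cleanup on a single ExitStack at acquisition time, and stop() simply closes the stack. A runnable sketch of that pattern; Resource and Backend are stand-ins, not Hail classes:

    from contextlib import ExitStack

    class Resource:
        # Stand-in for a gateway or filesystem handle (illustrative only).
        def __init__(self, name: str):
            self.name = name

        def __enter__(self):
            return self

        def __exit__(self, *exc):
            self.close()

        def close(self):
            print(f'closed {self.name}')

    class Backend:
        def __init__(self):
            self._exit_stack = ExitStack()
            self._gateway = Resource('gateway')
            self._exit_stack.callback(self._gateway.close)  # plain cleanup callback
            self._fs = self._exit_stack.enter_context(Resource('fs'))  # context manager

        def stop(self):
            # Cleanups run in reverse registration order: fs first, then gateway.
            self._exit_stack.close()

    Backend().stop()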
hail/python/hail/backend/service_backend.py (10 changes: 7 additions & 3 deletions)
@@ -207,6 +207,7 @@ async def create(
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
gcs_bucket_allow_list: Optional[List[str]] = None,
):
async_exit_stack = AsyncExitStack()
billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None)
if billing_project is None:
raise ValueError(
@@ -221,9 +222,11 @@
gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration},
gcs_bucket_allow_list=gcs_bucket_allow_list,
)
async_exit_stack.push_async_callback(async_fs.close)
sync_fs = RouterFS(async_fs)
if batch_client is None:
batch_client = await BatchClient.create(billing_project, _token=credentials_token)
async_exit_stack.push_async_callback(batch_client.close)
batch_attributes: Dict[str, str] = dict()
remote_tmpdir = get_remote_tmpdir('ServiceBackend', remote_tmpdir=remote_tmpdir)

@@ -288,6 +291,7 @@
worker_cores=worker_cores,
worker_memory=worker_memory,
regions=regions,
async_exit_stack=async_exit_stack,
)
sb._initialize_flags(flags)
return sb
@@ -308,6 +312,7 @@ def __init__(
worker_cores: Optional[Union[int, str]],
worker_memory: Optional[str],
regions: List[str],
async_exit_stack: AsyncExitStack,
):
super(ServiceBackend, self).__init__()
self.billing_project = billing_project
@@ -329,6 +334,7 @@ def __init__(
self.regions = regions

self._batch: Batch = self._create_batch()
self._async_exit_stack = async_exit_stack

def _create_batch(self) -> Batch:
return self._batch_client.create_batch(attributes=self.batch_attributes)
@@ -362,9 +368,7 @@ def stop(self):
hail_event_loop().run_until_complete(self._stop())

async def _stop(self):
async with AsyncExitStack() as stack:
stack.push_async_callback(self._async_fs.close)
stack.push_async_callback(self._batch_client.close)
await self._async_exit_stack.aclose()
self.functions = []
self._registered_ir_function_names = set()

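service_backend.py applies the same idea asynchronously: create() pushes a closer onto an AsyncExitStack as soon as each client is built, the stack is handed through __init__, and _stop() just awaits aclose(). A runnable sketch with stand-in clients (Client, create, and main are illustrative names):

    import asyncio
    from contextlib import AsyncExitStack

    class Client:
        # Stand-in for an async filesystem or batch client (illustrative only).
        def __init__(self, name: str):
            self.name = name

        async def close(self):
            print(f'closed {self.name}')

    async def create() -> AsyncExitStack:
        stack = AsyncExitStack()
        fs = Client('fs')
        stack.push_async_callback(fs.close)  # registered the moment it exists
        batch = Client('batch')
        stack.push_async_callback(batch.close)
        return stack

    async def main():
        stack = await create()
        await stack.aclose()  # awaits closers in reverse order: batch, then fs

    asyncio.run(main())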
hail/python/hailtop/aiocloud/aioaws/fs.py (47 changes: 32 additions & 15 deletions)
@@ -35,6 +35,7 @@
AsyncFSURL,
MultiPartCreate,
FileAndDirectoryError,
IsABucketError,
)
from hailtop.aiotools.fs.exceptions import UnexpectedEOFError
from hailtop.aiotools.fs.stream import (
@@ -325,6 +326,9 @@ def __init__(self, bucket: str, path: str):
self._bucket = bucket
self._path = path

def __repr__(self):
return f'S3AsyncFSURL({self._bucket}, {self._path})'

@property
def bucket_parts(self) -> List[str]:
return [self._bucket]
@@ -344,6 +348,9 @@ def scheme(self) -> str:
def with_path(self, path) -> 'S3AsyncFSURL':
return S3AsyncFSURL(self._bucket, path)

def with_root_path(self) -> 'S3AsyncFSURL':
return self.with_path('')

def __str__(self) -> str:
return f's3://{self._bucket}/{self._path}'

@@ -399,8 +406,11 @@ def valid_url(url: str) -> bool:
return url.startswith('s3://')

@staticmethod
def parse_url(url: str) -> S3AsyncFSURL:
return S3AsyncFSURL(*S3AsyncFS.get_bucket_and_name(url))
def parse_url(url: str, *, error_if_bucket: bool = False) -> S3AsyncFSURL:
fsurl = S3AsyncFSURL(*S3AsyncFS.get_bucket_and_name(url))
if error_if_bucket and fsurl._path == '':
raise IsABucketError
return fsurl

@staticmethod
def get_bucket_and_name(url: str) -> Tuple[str, str]:
@@ -423,22 +433,24 @@ def get_bucket_and_name(url: str) -> Tuple[str, str]:
return (bucket, name)

async def open(self, url: str) -> ReadableStream:
bucket, name = self.get_bucket_and_name(url)
fsurl = self.parse_url(url, error_if_bucket=True)
try:
resp = await blocking_to_async(self._thread_pool, self._s3.get_object, Bucket=bucket, Key=name)
resp = await blocking_to_async(
self._thread_pool, self._s3.get_object, Bucket=fsurl._bucket, Key=fsurl._path
)
return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body']))
except self._s3.exceptions.NoSuchKey as e:
raise FileNotFoundError(url) from e

async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream:
bucket, name = self.get_bucket_and_name(url)
fsurl = self.parse_url(url, error_if_bucket=True)
range_str = f'bytes={start}-'
if length is not None:
assert length >= 1
range_str += str(start + length - 1)
try:
resp = await blocking_to_async(
self._thread_pool, self._s3.get_object, Bucket=bucket, Key=name, Range=range_str
self._thread_pool, self._s3.get_object, Bucket=fsurl._bucket, Key=fsurl._path, Range=range_str
)
return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body']))
except self._s3.exceptions.NoSuchKey as e:
@@ -489,12 +501,12 @@ async def create(self, url: str, *, retry_writes: bool = True) -> S3CreateManager:
# interface. This has the disadvantage that the read must
# complete before the write can begin (unlike the current
# code, that copies 128MB parts in 256KB chunks).
bucket, name = self.get_bucket_and_name(url)
return S3CreateManager(self, bucket, name)
fsurl = self.parse_url(url, error_if_bucket=True)
return S3CreateManager(self, fsurl._bucket, fsurl._path)

async def multi_part_create(self, sema: asyncio.Semaphore, url: str, num_parts: int) -> MultiPartCreate:
bucket, name = self.get_bucket_and_name(url)
return S3MultiPartCreate(sema, self, bucket, name, num_parts)
fsurl = self.parse_url(url, error_if_bucket=True)
return S3MultiPartCreate(sema, self, fsurl._bucket, fsurl._path, num_parts)

async def mkdir(self, url: str) -> None:
pass
@@ -503,9 +515,11 @@ async def makedirs(self, url: str, exist_ok: bool = False) -> None:
pass

async def statfile(self, url: str) -> FileStatus:
bucket, name = self.get_bucket_and_name(url)
fsurl = self.parse_url(url, error_if_bucket=True)
try:
resp = await blocking_to_async(self._thread_pool, self._s3.head_object, Bucket=bucket, Key=name)
resp = await blocking_to_async(
self._thread_pool, self._s3.head_object, Bucket=fsurl._bucket, Key=fsurl._path
)
return S3HeadObjectFileStatus(resp, url)
except botocore.exceptions.ClientError as e:
if e.response['ResponseMetadata']['HTTPStatusCode'] == 404:
@@ -579,8 +593,10 @@ async def staturl(self, url: str) -> str:
return await self._staturl_parallel_isfile_isdir(url)

async def isfile(self, url: str) -> bool:
bucket, name = self.get_bucket_and_name(url)
if name == '':
return False
try:
bucket, name = self.get_bucket_and_name(url)
await blocking_to_async(self._thread_pool, self._s3.head_object, Bucket=bucket, Key=name)
return True
except botocore.exceptions.ClientError as e:
@@ -589,6 +605,7 @@ async def isfile(self, url: str) -> bool:
raise e

async def isdir(self, url: str) -> bool:
self.parse_url(url, error_if_bucket=True)
try:
async for _ in await self.listfiles(url, recursive=True):
return True
@@ -597,9 +614,9 @@ async def isdir(self, url: str) -> bool:
return False

async def remove(self, url: str) -> None:
fsurl = self.parse_url(url, error_if_bucket=True)
try:
bucket, name = self.get_bucket_and_name(url)
await blocking_to_async(self._thread_pool, self._s3.delete_object, Bucket=bucket, Key=name)
await blocking_to_async(self._thread_pool, self._s3.delete_object, Bucket=fsurl._bucket, Key=fsurl._path)
except self._s3.exceptions.NoSuchKey as e:
raise FileNotFoundError(url) from e

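Two small additions in this file are worth calling out: with_root_path() is just with_path(''), so stringifying the result yields the bucket root, and parse_url(..., error_if_bucket=True) raises before any S3 call is made. A self-contained illustration of the URL behavior, mirroring the relevant parts of S3AsyncFSURL from this diff:

    class URL:
        # Mirrors S3AsyncFSURL's with_path/with_root_path/__str__ in this diff.
        def __init__(self, bucket: str, path: str):
            self._bucket, self._path = bucket, path

        def with_path(self, path: str) -> 'URL':
            return URL(self._bucket, path)

        def with_root_path(self) -> 'URL':
            return self.with_path('')  # same definition as the diff's

        def __str__(self) -> str:
            return f's3://{self._bucket}/{self._path}'

    print(URL('my-bucket', 'a/b/c').with_root_path())  # s3://my-bucket/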
hail/python/hailtop/aiocloud/aioazure/fs.py (31 changes: 24 additions & 7 deletions)
@@ -30,6 +30,7 @@
FileStatus,
FileAndDirectoryError,
UnexpectedEOFError,
IsABucketError,
)

from .credentials import AzureCredentials
@@ -298,6 +299,9 @@ def __init__(self, account: str, container: str, path: str, query: Optional[str]
self._path = path
self._query = query

def __repr__(self):
return f'AzureAsyncFSURL({self._account}, {self._container}, {self._path}, {self._query})'

@property
def bucket_parts(self) -> List[str]:
return [self._account, self._container]
@@ -326,6 +330,9 @@ def base(self) -> str:
def with_path(self, path) -> 'AzureAsyncFSURL':
return self.__class__(self._account, self._container, path, self._query)

def with_root_path(self) -> 'AzureAsyncFSURL':
return self.with_path('')

def __str__(self) -> str:
return self.base if not self._query else f'{self.base}?{self._query}'

@@ -440,7 +447,14 @@ async def generate_sas_token(
return token

@staticmethod
def parse_url(url: str) -> AzureAsyncFSURL:
def parse_url(url: str, *, error_if_bucket: bool = False) -> AzureAsyncFSURL:
fsurl = AzureAsyncFS._parse_url(url)
if error_if_bucket and fsurl._path == '':
raise IsABucketError
return fsurl

@staticmethod
def _parse_url(url: str) -> AzureAsyncFSURL:
colon_index = url.find(':')
if colon_index == -1:
raise ValueError(f'invalid URL: {url}')
@@ -513,21 +527,23 @@ def get_container_client(self, url: AzureAsyncFSURL) -> ContainerClient:

@handle_public_access_error
async def open(self, url: str) -> ReadableStream:
parsed_url = self.parse_url(url, error_if_bucket=True)
if not await self.exists(url):
raise FileNotFoundError
client = self.get_blob_client(self.parse_url(url))
client = self.get_blob_client(parsed_url)
return AzureReadableStream(client, url)

@handle_public_access_error
async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream:
assert length is None or length >= 1
if not await self.exists(url):
raise FileNotFoundError
client = self.get_blob_client(self.parse_url(url))
client = self.get_blob_client(self.parse_url(url, error_if_bucket=True))
return AzureReadableStream(client, url, offset=start, length=length)

async def create(self, url: str, *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]: # pylint: disable=unused-argument
return AzureCreateManager(self.get_blob_client(self.parse_url(url)))
parsed_url = self.parse_url(url, error_if_bucket=True)
return AzureCreateManager(self.get_blob_client(parsed_url))

async def multi_part_create(self, sema: asyncio.Semaphore, url: str, num_parts: int) -> MultiPartCreate:
client = self.get_blob_client(self.parse_url(url))
@@ -545,7 +561,7 @@ async def isfile(self, url: str) -> bool:

@handle_public_access_error
async def isdir(self, url: str) -> bool:
fs_url = self.parse_url(url)
fs_url = self.parse_url(url, error_if_bucket=True)
assert not fs_url.path or fs_url.path.endswith('/'), fs_url.path
client = self.get_container_client(fs_url)
async for _ in client.walk_blobs(name_starts_with=fs_url.path, include=['metadata'], delimiter='/'):
@@ -560,8 +576,8 @@ async def makedirs(self, url: str, exist_ok: bool = False) -> None:

@handle_public_access_error
async def statfile(self, url: str) -> FileStatus:
parsed_url = self.parse_url(url, error_if_bucket=True)
try:
parsed_url = self.parse_url(url)
blob_props = await self.get_blob_client(parsed_url).get_blob_properties()
return AzureFileStatus(blob_props, parsed_url)
except azure.core.exceptions.ResourceNotFoundError as e:
@@ -639,7 +655,8 @@ async def staturl(self, url: str) -> str:

async def remove(self, url: str) -> None:
try:
await self.get_blob_client(self.parse_url(url)).delete_blob()
parsed_url = self.parse_url(url, error_if_bucket=True)
await self.get_blob_client(parsed_url).delete_blob()
except azure.core.exceptions.ResourceNotFoundError as e:
raise FileNotFoundError(url) from e

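Across both providers the guard is identical in shape: parse the URL first, and if the caller asked for an object-level operation on a URL with an empty object path, raise IsABucketError before any network I/O happens. A toy sketch of that shape; the parser is deliberately simplified, and IsABucketError here stands in for the exception these diffs import alongside AsyncFSURL:

    class IsABucketError(ValueError):
        # Stand-in for the real exception (illustrative only).
        pass

    def parse_url(url: str, *, error_if_bucket: bool = False) -> tuple:
        # Toy parsing: 's3://bucket/key' -> ('bucket', 'key'); the real
        # parsers also handle schemes, accounts, containers, and queries.
        bucket, _, path = url.removeprefix('s3://').partition('/')
        if error_if_bucket and path == '':
            raise IsABucketError(url)
        return bucket, path

    print(parse_url('s3://my-bucket/key.txt', error_if_bucket=True))
    # parse_url('s3://my-bucket', error_if_bucket=True)  -> raises IsABucketError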