Skip to content

example async download directory #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 82 additions & 13 deletions src/entitysdk/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Identifiable SDK client."""

import asyncio
import io
import logging
import os
from pathlib import Path
from typing import Any, cast
Expand All @@ -20,10 +22,12 @@
from entitysdk.util import (
build_api_url,
create_intermediate_directories,
get_request_headers,
validate_filename_extension_consistency,
)
from entitysdk.utils.asset import filter_assets

L = logging.getLogger(__name__)

class Client:
"""Client for entitysdk."""
Expand Down Expand Up @@ -348,13 +352,14 @@ def download_directory(
project_context: ProjectContext | None = None,
ignore_directory_name: bool = False,
) -> list[Path]:
"""List directory existing entity's endpoint from a directory path."""
"""Download a directory to directory path."""
output_path = Path(output_path)

if output_path.exists() and output_path.is_file():
raise EntitySDKError(f"{output_path} exists and is a file")
output_path.mkdir(parents=True, exist_ok=True)

token = self._token_manager.get_token()
context = self._optional_user_context(override_context=project_context)

asset = None
Expand All @@ -370,7 +375,7 @@ def download_directory(
entity_type=Asset,
project_context=context,
http_client=self._http_client,
token=self._token_manager.get_token(),
token=token,
)

output_path /= asset.path
Expand All @@ -382,21 +387,85 @@ def download_directory(
project_context=project_context,
)

paths = []
for path in contents.files:
paths.append(
self.download_file(
entity_id=entity_id,
entity_type=entity_type,
asset_id=asset if asset else asset_id,
output_path=output_path / path,
asset_path=path,
project_context=context,
)
asset_endpoint = route.get_assets_endpoint(
api_url=self.api_url,
entity_type=entity_type,
entity_id=entity_id,
asset_id=asset_id
)

paths = [
self.download_file(
entity_id=entity_id,
entity_type=entity_type,
asset_id=asset if asset else asset_id,
output_path=output_path / path,
asset_path=path,
project_context=context,
)
for path in contents.files
]

return paths

async def async_download_directory(
    self,
    *,
    entity_id: ID,
    entity_type: type[Identifiable],
    asset_id: ID,
    output_path: os.PathLike,
    project_context: ProjectContext | None = None,
    ignore_directory_name: bool = False,
    max_concurrent: int = 8,
) -> list[Path]:
    """Download all files of a directory asset concurrently.

    The directory listing is obtained with ``self.list_directory`` (a
    synchronous call), then every listed file is streamed from the asset's
    ``/download`` endpoint into ``output_path`` using a single shared
    ``httpx.AsyncClient``.

    Args:
        entity_id: Id of the entity owning the asset.
        entity_type: Type of the entity owning the asset.
        asset_id: Id of the directory asset.
        output_path: Local directory the files are written into. It is
            created if missing.
        project_context: Optional context overriding the client's default.
        ignore_directory_name: NOTE(review): accepted for signature parity
            with ``download_directory`` but not used here; files are always
            written directly under ``output_path``. Confirm intended
            behaviour.
        max_concurrent: Maximum number of simultaneous downloads.

    Returns:
        The local paths of the downloaded files, in the order reported by
        the directory listing.

    Raises:
        ValueError: If ``max_concurrent`` is smaller than 1.
        EntitySDKError: If ``output_path`` exists and is a file.
        httpx.HTTPStatusError: If any download returns an error status.
    """
    if max_concurrent < 1:
        # Semaphore(0) would make every download wait forever.
        raise ValueError("max_concurrent must be at least 1")

    # Same normalisation and validation as the synchronous variant, so that
    # plain strings are accepted and a file is never mistaken for a directory.
    output_path = Path(output_path)
    if output_path.exists() and output_path.is_file():
        raise EntitySDKError(f"{output_path} exists and is a file")
    output_path.mkdir(parents=True, exist_ok=True)

    contents = self.list_directory(
        entity_id=entity_id,
        entity_type=entity_type,
        asset_id=asset_id,
        project_context=project_context,
    )

    asset_endpoint = route.get_assets_endpoint(
        api_url=self.api_url,
        entity_type=entity_type,
        entity_id=entity_id,
        asset_id=asset_id,
    )
    url = f"{asset_endpoint}/download"

    context = self._optional_user_context(override_context=project_context)
    token = self._token_manager.get_token()
    headers = get_request_headers(project_context=context, token=token)

    semaphore = asyncio.Semaphore(max_concurrent)

    async def fetch(client, asset_path):
        """Stream one file of the directory to disk and return its path."""
        target = output_path / asset_path
        target.parent.mkdir(parents=True, exist_ok=True)
        async with semaphore, client.stream(
            "GET",
            url,
            headers=headers,
            params={"asset_path": str(asset_path)},
            follow_redirects=True,
        ) as response:
            # Check the status before opening the file so that a failed
            # request does not leave an empty file behind.
            response.raise_for_status()
            with target.open("wb") as fd:
                async for chunk in response.aiter_bytes():
                    fd.write(chunk)
        return target

    # The context manager guarantees the connection pool is released.
    async with httpx.AsyncClient() as client:
        tasks = [
            asyncio.create_task(fetch(client, asset_path))
            for asset_path in contents.files
        ]
        try:
            return list(await asyncio.gather(*tasks))
        except BaseException:
            # Stop the remaining downloads before the client is closed,
            # otherwise they would keep running against a closed client.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

def download_content(
self,
*,
Expand Down
21 changes: 16 additions & 5 deletions src/entitysdk/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@
from entitysdk.types import DeploymentEnvironment


def get_request_headers(
    *,
    project_context: ProjectContext | None = None,
    token: str,
):
    """Build the HTTP headers for an authenticated request.

    The ``Authorization`` bearer header is always present. When a truthy
    ``project_context`` is given, its ``project_id`` and ``virtual_lab_id``
    are added, stringified, as ``project-id`` and ``virtual-lab-id``.
    """
    auth_only = {"Authorization": f"Bearer {token}"}

    # Without a project context only the bearer token is sent.
    if not project_context:
        return auth_only

    return {
        **auth_only,
        "project-id": str(project_context.project_id),
        "virtual-lab-id": str(project_context.virtual_lab_id),
    }


def make_db_api_request(
url: str,
*,
Expand All @@ -30,11 +45,7 @@ def make_db_api_request(
if http_client is None:
http_client = httpx.Client()

headers = {"Authorization": f"Bearer {token}"}

if project_context:
headers["project-id"] = str(project_context.project_id)
headers["virtual-lab-id"] = str(project_context.virtual_lab_id)
headers = get_request_headers(project_context=project_context, token=token)

try:
response = http_client.request(
Expand Down