
feat: Add support for multi-cloud using fsspec #469

Closed

3 changes: 2 additions & 1 deletion requirements.txt
@@ -2,6 +2,7 @@ torch
lightning-utilities
filelock
numpy
boto3
requests
tifffile
fsspec
fsspec[s3] # aws s3
2 changes: 2 additions & 0 deletions requirements/extras.txt
@@ -7,3 +7,5 @@ lightning-sdk==0.1.46 # Must be pinned to ensure compatibility
google-cloud-storage
polars
fsspec
fsspec[gs] # google cloud storage
fsspec[abfs] # azure blob
1 change: 1 addition & 0 deletions src/litdata/constants.py
@@ -93,3 +93,4 @@
_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
_IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None))
_ENABLE_STATUS = bool(int(os.getenv("ENABLE_STATUS_REPORT", "0")))
_SUPPORTED_CLOUD_PROVIDERS = ("s3", "gs", "azure", "abfs")
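For orientation (not part of the diff): a minimal sketch of how the new _SUPPORTED_CLOUD_PROVIDERS tuple and the fsspec extras declared above fit together. The helper name, bucket, and credential keys below are placeholders; the actual dispatch in litdata lives in the streaming downloader rather than here.

from urllib.parse import urlparse

import fsspec

_SUPPORTED_CLOUD_PROVIDERS = ("s3", "gs", "azure", "abfs")

def _open_remote(path: str, storage_options: dict = {}):
    # Reject anything that is not one of the supported URL schemes.
    scheme = urlparse(path).scheme
    if scheme not in _SUPPORTED_CLOUD_PROVIDERS:
        raise ValueError(f"Unsupported scheme {scheme!r}, expected one of {_SUPPORTED_CLOUD_PROVIDERS}")
    # fsspec dispatches to s3fs, gcsfs, or adlfs based on the scheme, which is
    # why the fsspec[s3], fsspec[gs], and fsspec[abfs] extras are required.
    return fsspec.open(path, "rb", **storage_options)

# Placeholder usage: provider credentials are passed straight through as storage_options.
# with _open_remote("s3://my-bucket/index.json", {"key": "...", "secret": "..."}) as f:
#     data = f.read()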
164 changes: 86 additions & 78 deletions src/litdata/processing/data_processor.py

Large diffs are not rendered by default.

75 changes: 45 additions & 30 deletions src/litdata/processing/functions.py
@@ -28,7 +28,7 @@
import torch

from litdata import __version__
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS
from litdata.helpers import _check_version_and_prompt_upgrade
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, MapRecipe
from litdata.processing.readers import BaseReader
@@ -38,8 +38,8 @@
optimize_dns_context,
read_index_file_content,
)
from litdata.streaming.client import S3Client
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.downloader import copy_file_or_directory, upload_file_or_directory
from litdata.streaming.item_loader import BaseItemLoader
from litdata.streaming.resolver import (
Dir,
@@ -55,7 +55,7 @@

def _is_remote_file(path: str) -> bool:
obj = parse.urlparse(path)
return obj.scheme in ["s3", "gcs"]
return obj.scheme in _SUPPORTED_CLOUD_PROVIDERS


def _get_indexed_paths(data: Any) -> Dict[int, str]:
@@ -153,8 +153,15 @@ def __init__(
compression: Optional[str],
encryption: Optional[Encryption] = None,
existing_index: Optional[Dict[str, Any]] = None,
storage_options: Optional[Dict] = {},
):
super().__init__(chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression, encryption=encryption)
super().__init__(
chunk_size=chunk_size,
chunk_bytes=chunk_bytes,
compression=compression,
encryption=encryption,
storage_options=storage_options,
)
self._fn = fn
self._inputs = inputs
self.is_generator = False
@@ -202,6 +209,7 @@ def map(
reader: Optional[BaseReader] = None,
batch_size: Optional[int] = None,
start_method: Optional[str] = None,
storage_options: Optional[Dict] = {},
) -> None:
"""Maps a callable over a collection of inputs, possibly in a distributed way.

@@ -225,6 +233,7 @@
batch_size: Group the inputs into batches of batch_size length.
start_method: The start method used by python multiprocessing package. Default to spawn unless running
inside an interactive shell like Ipython.
storage_options: The storage options used by the cloud provider.
"""
_check_version_and_prompt_upgrade(__version__)

@@ -264,7 +273,7 @@ def map(
)

if error_when_not_empty:
_assert_dir_is_empty(_output_dir)
_assert_dir_is_empty(_output_dir, storage_options=storage_options)

if not isinstance(inputs, StreamingDataLoader):
input_dir = input_dir or _get_input_dir(inputs)
@@ -289,6 +298,7 @@
weights=weights,
reader=reader,
start_method=start_method,
storage_options=storage_options,
)
with optimize_dns_context(True):
return data_processor.run(LambdaMapRecipe(fn, inputs))
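Usage illustration (not part of the diff) for the new storage_options argument to map. The process function, input and output paths, and credential keys are placeholders; the top-level import and gcsfs-style token key mirror common litdata/fsspec usage and are assumptions here.

from litdata import map  # map is defined in this file and re-exported by the package

def process(input_path, output_dir):
    ...  # user-defined work on a single input (placeholder)

map(
    fn=process,
    inputs=["gs://my-bucket/raw/a.tiff", "gs://my-bucket/raw/b.tiff"],  # placeholder inputs
    output_dir="gs://my-bucket/processed",                              # placeholder output
    num_workers=4,
    storage_options={"token": "service-account.json"},  # gcsfs-style credentials (assumption)
)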
@@ -322,6 +332,7 @@ def optimize(
use_checkpoint: bool = False,
item_loader: Optional[BaseItemLoader] = None,
start_method: Optional[str] = None,
storage_options: Optional[Dict] = {},
) -> None:
"""This function converts a dataset into chunks, possibly in a distributed way.

@@ -356,6 +367,7 @@
the format in which the data is stored and optimized for loading.
start_method: The start method used by python multiprocessing package. Default to spawn unless running
inside an interactive shell like Ipython.
storage_options: The storage options used by the cloud provider.

"""
_check_version_and_prompt_upgrade(__version__)
@@ -411,7 +423,9 @@ def optimize(
"\n HINT: You can either use `/teamspace/s3_connections/...` or `/teamspace/datasets/...`."
)

_assert_dir_has_index_file(_output_dir, mode=mode, use_checkpoint=use_checkpoint)
_assert_dir_has_index_file(
_output_dir, mode=mode, use_checkpoint=use_checkpoint, storage_options=storage_options
)

if not isinstance(inputs, StreamingDataLoader):
resolved_dir = _resolve_dir(input_dir or _get_input_dir(inputs))
@@ -427,7 +441,9 @@
num_workers = num_workers or _get_default_num_workers()
state_dict = {rank: 0 for rank in range(num_workers)}

existing_index_file_content = read_index_file_content(_output_dir) if mode == "append" else None
existing_index_file_content = (
read_index_file_content(_output_dir, storage_options=storage_options) if mode == "append" else None
)

if existing_index_file_content is not None:
for chunk in existing_index_file_content["chunks"]:
@@ -461,6 +477,7 @@
compression=compression,
encryption=encryption,
existing_index=existing_index_file_content,
storage_options=storage_options,
)
)
return None
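Likewise for optimize (illustration, not part of the diff): the inputs, output path, and credentials are placeholders, and the abfs scheme with adlfs-style account keys is an assumption about how an Azure target would be addressed.

from litdata import optimize  # optimize is defined in this file and re-exported by the package

def tokenize(index):
    return {"index": index, "tokens": list(range(index % 7))}  # placeholder sample

optimize(
    fn=tokenize,
    inputs=list(range(1_000)),                # placeholder inputs
    output_dir="abfs://container/optimized",  # placeholder Azure Blob path
    chunk_bytes="64MB",
    num_workers=4,
    storage_options={"account_name": "...", "account_key": "..."},  # adlfs-style credentials (assumption)
)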
@@ -529,13 +546,19 @@ class CopyInfo:
new_filename: str


def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional[int] = os.cpu_count()) -> None:
def merge_datasets(
input_dirs: List[str],
output_dir: str,
max_workers: Optional[int] = os.cpu_count(),
storage_options: Optional[Dict] = {},
) -> None:
"""Enables to merge multiple existing optimized datasets into a single optimized dataset.

Args:
input_dirs: A list of directories pointing to the existing optimized datasets.
output_dir: The directory where the merged dataset would be stored.
max_workers: Number of workers for multithreading
storage_options: A dictionary of storage options to be passed to the fsspec library.

"""
if len(input_dirs) == 0:
@@ -551,12 +574,14 @@ def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional
if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs):
raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.")

input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs]
input_dirs_file_content = [
read_index_file_content(input_dir, storage_options=storage_options) for input_dir in resolved_input_dirs
]

if any(file_content is None for file_content in input_dirs_file_content):
raise ValueError("One of the provided input_dir doesn't have an index file.")

output_dir_file_content = read_index_file_content(resolved_output_dir)
output_dir_file_content = read_index_file_content(resolved_output_dir, storage_options=storage_options)

if output_dir_file_content is not None:
raise ValueError("The output_dir already contains an optimized dataset")
@@ -593,16 +618,16 @@ def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures: List[concurrent.futures.Future] = []
for copy_info in copy_infos:
future = executor.submit(_apply_copy, copy_info, resolved_output_dir)
future = executor.submit(_apply_copy, copy_info, resolved_output_dir, storage_options)
futures.append(future)

for future in _tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
future.result()

_save_index(index_json, resolved_output_dir)
_save_index(index_json, resolved_output_dir, storage_options)


def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None:
def _apply_copy(copy_info: CopyInfo, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None:
if output_dir.url is None and copy_info.input_dir.url is None:
assert copy_info.input_dir.path
assert output_dir.path
@@ -612,20 +637,15 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None:
shutil.copyfile(input_filepath, output_filepath)

elif output_dir.url and copy_info.input_dir.url:
input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename))
output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename))

s3 = S3Client()
s3.client.copy(
{"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")},
output_obj.netloc,
output_obj.path.lstrip("/"),
)
input_obj = os.path.join(copy_info.input_dir.url, copy_info.old_filename)
output_obj = os.path.join(output_dir.url, copy_info.new_filename)

copy_file_or_directory(input_obj, output_obj, storage_options=storage_options)
else:
raise NotImplementedError
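copy_file_or_directory is imported from litdata/streaming/downloader.py, which is not shown in this diff. A rough sketch of what an fsspec-based implementation could look like, purely as an assumption about that helper:

import fsspec

def copy_file_or_directory(src: str, dst: str, storage_options: dict = {}) -> None:
    # Assumption: src and dst live on the same provider, so one filesystem
    # instance (resolved from the URL scheme) can perform the copy.
    fs, _ = fsspec.core.url_to_fs(src, **(storage_options or {}))
    fs.copy(src, dst, recursive=True)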


def _save_index(index_json: Dict, output_dir: Dir) -> None:
def _save_index(index_json: Dict, output_dir: Dir, storage_options: Optional[Dict] = {}) -> None:
if output_dir.url is None:
assert output_dir.path
with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f:
@@ -636,11 +656,6 @@ def _save_index(index_json: Dict, output_dir: Dir) -> None:

f.flush()

obj = parse.urlparse(os.path.join(output_dir.url, _INDEX_FILENAME))

s3 = S3Client()
s3.client.upload_file(
f.name,
obj.netloc,
obj.path.lstrip("/"),
upload_file_or_directory(
f.name, os.path.join(output_dir.url, _INDEX_FILENAME), storage_options=storage_options
)
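Usage illustration (not part of the diff) for the extended merge_datasets signature; the dataset paths and credential keys are placeholders.

from litdata.processing.functions import merge_datasets

merge_datasets(
    input_dirs=["s3://my-bucket/dataset-a", "s3://my-bucket/dataset-b"],  # placeholder datasets
    output_dir="s3://my-bucket/dataset-merged",                           # placeholder destination
    storage_options={"key": "...", "secret": "..."},                      # s3fs-style credentials (assumption)
)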
49 changes: 23 additions & 26 deletions src/litdata/processing/utilities.py
@@ -21,11 +21,9 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib import parse

import boto3
import botocore

from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _SUPPORTED_CLOUD_PROVIDERS
from litdata.streaming.cache import Dir
from litdata.streaming.downloader import download_file_or_directory


def _create_dataset(
@@ -183,7 +181,7 @@ def _get_work_dir() -> str:
return f"s3://{bucket_name}/projects/{project_id}/lightningapps/{app_id}/artifacts/{work_id}/content/"


def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]:
def read_index_file_content(output_dir: Dir, storage_options: Optional[Dict] = {}) -> Optional[Dict[str, Any]]:
"""Read the index file content."""
if not isinstance(output_dir, Dir):
raise ValueError("The provided output_dir should be a Dir object.")
@@ -201,27 +199,26 @@ def read_index_file_content(output_dir: Dir) -> Optional[Dict[str, Any]]:
# download the index file from s3, and read it
obj = parse.urlparse(output_dir.url)

if obj.scheme != "s3":
raise ValueError(f"The provided folder should start with s3://. Found {output_dir.path}.")

# TODO: Add support for all cloud providers
s3 = boto3.client("s3")

prefix = obj.path.lstrip("/").rstrip("/") + "/"
if obj.scheme not in _SUPPORTED_CLOUD_PROVIDERS:
raise ValueError(
f"The provided folder should start with {_SUPPORTED_CLOUD_PROVIDERS}. Found {output_dir.path}."
)

# Check the index file exists
try:
# Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
temp_file_name = temp_file.name
s3.download_file(obj.netloc, os.path.join(prefix, _INDEX_FILENAME), temp_file_name)
download_file_or_directory(
os.path.join(output_dir.url, _INDEX_FILENAME), temp_file_name, storage_options=storage_options
)
# Read data from the temporary file
with open(temp_file_name) as temp_file:
data = json.load(temp_file)
# Delete the temporary file
os.remove(temp_file_name)
return data
except botocore.exceptions.ClientError:
except Exception as _e:
return None


@@ -258,19 +255,19 @@ def remove_uuid_from_filename(filepath: str) -> str:
return filepath[:-38] + ".json"


def download_directory_from_S3(bucket_name: str, remote_directory_name: str, local_directory_name: str) -> str:
s3_resource = boto3.resource("s3")
bucket = s3_resource.Bucket(bucket_name)
# def download_directory_from_S3(bucket_name: str, remote_directory_name: str, local_directory_name: str) -> str:
# s3_resource = boto3.resource("s3")
# bucket = s3_resource.Bucket(bucket_name)

saved_file_dir = "."
# saved_file_dir = "."

for obj in bucket.objects.filter(Prefix=remote_directory_name):
local_filename = os.path.join(local_directory_name, obj.key)
# for obj in bucket.objects.filter(Prefix=remote_directory_name):
# local_filename = os.path.join(local_directory_name, obj.key)

if not os.path.exists(os.path.dirname(local_filename)):
os.makedirs(os.path.dirname(local_filename))
with open(local_filename, "wb") as f:
s3_resource.meta.client.download_fileobj(bucket_name, obj.key, f)
saved_file_dir = os.path.dirname(local_filename)
# if not os.path.exists(os.path.dirname(local_filename)):
# os.makedirs(os.path.dirname(local_filename))
# with open(local_filename, "wb") as f:
# s3_resource.meta.client.download_fileobj(bucket_name, obj.key, f)
# saved_file_dir = os.path.dirname(local_filename)

return saved_file_dir
# return saved_file_dir
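For reference (not part of the diff): if a multi-cloud replacement for the commented-out boto3 helper above were ever needed, an fsspec-based sketch could look like the following. The function name and behaviour are assumptions, not litdata API.

import os

import fsspec

def download_directory(remote_dir: str, local_dir: str, storage_options: dict = {}) -> str:
    # Resolve the filesystem from the URL scheme (s3, gs, abfs, ...) and
    # recursively mirror the remote prefix into the local directory.
    fs, remote_path = fsspec.core.url_to_fs(remote_dir, **(storage_options or {}))
    os.makedirs(local_dir, exist_ok=True)
    fs.get(remote_path, local_dir, recursive=True)
    return local_dir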