Skip to content

feat: Add support for path in map fn #582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/litdata/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def prepare_item(self, item_metadata: Any) -> Any:
def map(
fn: Callable[[str, Any], None],
inputs: Union[Sequence[Any], StreamingDataLoader],
output_dir: Union[str, Dir],
input_dir: Optional[str] = None,
output_dir: Union[str, Path, Dir],
input_dir: Optional[Union[str, Path]] = None,
weights: Optional[List[int]] = None,
num_workers: Optional[int] = None,
fast_dev_run: Union[bool, int] = False,
Expand Down
39 changes: 27 additions & 12 deletions src/litdata/streaming/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,28 @@ class Dir:
url: Optional[str] = None


def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir:
def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
if isinstance(dir_path, Dir):
return Dir(path=str(dir_path.path) if dir_path.path else None, url=str(dir_path.url) if dir_path.url else None)

if dir_path is None:
return Dir()

if not isinstance(dir_path, str):
raise ValueError(f"`dir_path` must be a `Dir` or a string, got: {dir_path}")
if not isinstance(dir_path, (str, Path)):
raise ValueError(f"`dir_path` must be either a string, Path, or Dir, got: {dir_path}")

assert isinstance(dir_path, str)
if isinstance(dir_path, str):
cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
if dir_path.startswith(cloud_prefixes):
return Dir(path=None, url=dir_path)

cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
if dir_path.startswith(cloud_prefixes):
return Dir(path=None, url=dir_path)
if dir_path.startswith("local:"):
return Dir(path=None, url=dir_path)

if dir_path.startswith("local:"):
return Dir(path=None, url=dir_path)

dir_path = _resolve_time_template(dir_path)
dir_path = _resolve_time_template(dir_path)

dir_path_absolute = str(Path(dir_path).absolute().resolve())
dir_path = str(dir_path) # Convert to string if it was a Path object

if dir_path_absolute.startswith("/teamspace/studios/this_studio"):
return Dir(path=dir_path_absolute, url=None)
Expand Down Expand Up @@ -345,7 +345,22 @@ def _get_lightning_cloud_url() -> str:


def _resolve_time_template(path: str) -> str:
match = re.search("^.*{%.*}$", path)
"""Resolves a datetime pattern in the given path string.

If the path contains a placeholder in the form `{%Y-%m-%d}`, it replaces it
with the current date/time formatted using the specified `strftime` pattern.

Example:
Input: "/logs/log_{%Y-%m-%d}.txt"
Output (on May 5, 2025): "/logs/log_2025-05-05.txt"

Args:
path (str): The file path containing an optional datetime placeholder.

Returns:
str: The path with the datetime placeholder replaced by the current timestamp.
"""
match = re.search("^.*{%.*}.*$", path)
if match is None:
return path

Expand Down
40 changes: 39 additions & 1 deletion tests/processing/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from PIL import Image

from litdata import StreamingDataset, merge_datasets, optimize, walk
from litdata import StreamingDataset, map, merge_datasets, optimize, walk
from litdata.processing.functions import _get_input_dir, _resolve_dir
from litdata.streaming.cache import Cache
from litdata.utilities.encryption import FernetEncryption, RSAEncryption
Expand Down Expand Up @@ -58,6 +58,44 @@ def test_get_input_dir_with_s3_path():
assert input_dir.url == "s3://my_bucket/my_folder"


def update_msg(file_path: Path, output_dir: Path):
    """Write the text ``Bonjour!`` into ``output_dir`` under ``file_path``'s name.

    Only the final component of ``file_path`` is used to name the output file;
    the input file itself is never opened or read.

    Args:
        file_path: Path whose ``name`` becomes the output file name.
        output_dir: Directory in which the output file is created.
    """
    destination = os.path.join(output_dir, file_path.name)
    with open(destination, "w") as handle:
        handle.write("Bonjour!")


def test_map_with_path(tmpdir):
    """Check that ``map`` accepts ``pathlib.Path`` values for inputs and ``output_dir``.

    Five text files are created in an input directory, ``map`` is run over them
    with ``update_msg``, and the output directory is then expected to hold one
    file per input, each containing ``Bonjour!``.
    """
    input_dir = Path(tmpdir) / "input_dir"
    output_dir = Path(tmpdir) / "output_dir"

    input_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    for i in range(5):
        (input_dir / f"{i}.txt").write_text("hello world!")

    inputs = list(input_dir.iterdir())  # every file created above, as Path objects

    # Sanity-check the fixture before handing it to `map`.
    for source in inputs:
        assert source.read_text() == "hello world!"

    map(
        fn=update_msg,
        inputs=inputs,
        output_dir=output_dir,
    )

    produced = sorted(output_dir.iterdir())

    # Without this check the content loop below would pass vacuously if `map`
    # wrote nothing into the output directory.
    assert [result.name for result in produced] == sorted(source.name for source in inputs)

    for result in produced:
        assert result.read_text() == "Bonjour!"


def compress(index):
    """Return ``index`` paired with its square as a 2-tuple."""
    squared = index**2
    return (index, squared)

Expand Down
15 changes: 15 additions & 0 deletions tests/streaming/test_resolver.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import sys
from pathlib import Path
from unittest import mock
Expand Down Expand Up @@ -388,3 +389,17 @@ def test_resolve_dir_absolute(tmp_path, monkeypatch):
link.symlink_to(src)
assert link.resolve() == src
assert resolver._resolve_dir(str(link)).path == str(src)


def test_resolve_time_template():
    """Exercise ``_resolve_time_template`` on three kinds of path.

    Covers a ``{%Y-%m}`` placeholder at the end of the path, a path with no
    placeholder (expected back unchanged), and a placeholder followed by a
    further path component.
    """
    now = datetime.datetime.now()
    stamp = f"{now.year}-{now.month:02d}"

    # Maps each input path to the value the resolver is expected to return.
    expectations = {
        "/logs/log_{%Y-%m}": f"/logs/log_{stamp}",
        "/logs/my_logfile": "/logs/my_logfile",
        "/logs/log_{%Y-%m}/important": f"/logs/log_{stamp}/important",
    }

    for template, expected in expectations.items():
        assert resolver._resolve_time_template(template) == expected
Loading