[WIP] Improve debugging of DDP hanging #456

Merged · 10 commits · Jan 24, 2025
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -116,6 +116,8 @@ files = [
# This section is for folders with "-" as they are not valid python modules
exclude = [
"src/litdata/utilities/_pytree.py",
"src/litdata/streaming/item_loader.py",
"src/litdata/utilities/breakpoint.py",
]
install_types = "True"
non_interactive = "True"
2 changes: 2 additions & 0 deletions src/litdata/__init__.py
@@ -18,6 +18,7 @@
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import TokensLoader
from litdata.utilities.breakpoint import breakpoint
from litdata.utilities.train_test_split import train_test_split

__all__ = [
@@ -30,6 +31,7 @@
"walk",
"train_test_split",
"merge_datasets",
"breakpoint",
]
if RequirementCache("lightning_sdk"):
from lightning_sdk import Machine # noqa: F401
4 changes: 4 additions & 0 deletions src/litdata/constants.py
@@ -36,6 +36,10 @@
_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
_TQDM_AVAILABLE = RequirementCache("tqdm")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
_DEBUG = bool(int(os.getenv("DEBUG", "1")))

_MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120"))
_FORCE_DOWNLOAD_TIME = int(os.getenv("FORCE_DOWNLOAD_TIME", "30"))

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
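
Note that the three new constants are read from environment variables once, at import time. A minimal sketch of overriding them from a launcher script (the values are illustrative, not recommendations):

    import os

    # These must be set before litdata is imported, because constants.py reads them at import time.
    os.environ["DEBUG"] = "1"                 # turn on the extra delete / force-download prints
    os.environ["MAX_WAIT_TIME"] = "180"       # seconds to wait for a chunk before raising FileNotFoundError
    os.environ["FORCE_DOWNLOAD_TIME"] = "30"  # seconds to wait before re-requesting the chunk download

    import litdata as ld  # noqa: E402  # imported only after the overrides are in place
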
6 changes: 3 additions & 3 deletions src/litdata/processing/data_processor.py
@@ -558,7 +558,7 @@ def _create_cache(self) -> None:
self.cache_data_dir = _get_cache_data_dir()
self.cache_chunks_dir = _get_cache_dir()

if isinstance(self.data_recipe, DataTransformRecipe):
if isinstance(self.data_recipe, MapRecipe):
return

self.cache = Cache(
@@ -737,7 +737,7 @@ def _handle_data_transform_recipe(self, index: int) -> None:
item_data = self.data_recipe.prepare_item(item, str(output_dir), len(self.items) - 1 == index)
if item_data is not None:
raise ValueError(
"When using a `DataTransformRecipe`, the `prepare_item` shouldn't return anything."
"When using a `MapRecipe`, the `prepare_item` shouldn't return anything."
" Simply store your files under the output_dir."
)
filepaths = []
@@ -902,7 +902,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
self._upload_index(output_dir, cache_dir, 1, None)


class DataTransformRecipe(DataRecipe):
class MapRecipe(DataRecipe):
@abstractmethod
def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
"""Return the structure of your data.
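
For context, a minimal user-defined recipe under the renamed class might look like the sketch below. The two abstract methods follow the call sites visible in this diff; the copy logic itself is purely illustrative.

    import os
    import shutil
    from typing import List, Optional

    from litdata.processing.data_processor import MapRecipe


    class CopyRecipe(MapRecipe):
        """Illustrative MapRecipe: copies every input file into the output directory."""

        def prepare_structure(self, input_dir: Optional[str]) -> List[str]:
            # Return one work item per input file; each item is later passed to prepare_item.
            assert input_dir is not None
            return [os.path.join(input_dir, name) for name in sorted(os.listdir(input_dir))]

        def prepare_item(self, item: str, output_dir: str, is_last: bool) -> None:
            # As the error message above states, prepare_item must not return anything;
            # results are written under output_dir instead.
            shutil.copy(item, os.path.join(output_dir, os.path.basename(item)))
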
6 changes: 3 additions & 3 deletions src/litdata/processing/functions.py
@@ -30,7 +30,7 @@
from litdata import __version__
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO
from litdata.helpers import _check_version_and_prompt_upgrade
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, MapRecipe
from litdata.processing.readers import BaseReader
from litdata.processing.utilities import (
_get_work_dir,
@@ -100,7 +100,7 @@ def _get_default_num_workers() -> int:
return os.cpu_count() or 1


class LambdaDataTransformRecipe(DataTransformRecipe):
class LambdaMapRecipe(MapRecipe):
def __init__(self, fn: Callable[[str, Any], None], inputs: Union[Sequence[Any], StreamingDataLoader]):
super().__init__()
self._fn = fn
@@ -291,7 +291,7 @@ def map(
start_method=start_method,
)
with optimize_dns_context(True):
return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
return data_processor.run(LambdaMapRecipe(fn, inputs))
return _execute(
f"litdata-map-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}",
num_nodes,
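
Since `map` now simply wraps the renamed `LambdaMapRecipe`, user code is unaffected. A hedged usage sketch (the paths and `num_workers` value are illustrative; the worker signature follows litdata's documented `fn(input, output_dir)` convention):

    import os
    import shutil

    from litdata import map


    def copy_file(input_path, output_dir):
        # Map-style worker: write results under output_dir and return nothing.
        shutil.copy(input_path, os.path.join(output_dir, os.path.basename(input_path)))


    if __name__ == "__main__":
        inputs = [f"data/file_{i}.txt" for i in range(4)]  # illustrative inputs
        map(fn=copy_file, inputs=inputs, output_dir="output_dir", num_workers=2)
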
34 changes: 31 additions & 3 deletions src/litdata/streaming/item_loader.py
@@ -17,13 +17,14 @@
from collections import namedtuple
from copy import deepcopy
from io import BytesIO, FileIO
from time import sleep
from multiprocessing import Queue
from time import sleep, time
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
from litdata.constants import _FORCE_DOWNLOAD_TIME, _MAX_WAIT_TIME, _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
from litdata.streaming.serializers import Serializer
from litdata.utilities._pytree import PyTree, tree_unflatten
from litdata.utilities.encryption import Encryption, EncryptionLevel
@@ -40,20 +41,26 @@ def setup(
chunks: List,
serializers: Dict[str, Serializer],
region_of_interest: Optional[List[Tuple[int, int]]] = None,
force_download_queue: Optional[Queue] = None,
) -> None:
self._config = config
self._chunks = chunks
self._serializers = {**serializers}
self._data_format = self._config["data_format"]
self._shift_idx = len(self._data_format) * 4
self.region_of_interest = region_of_interest
self._force_download_queue = force_download_queue

# setup the serializers on restart
for data_format in self._data_format:
serializer = deepcopy(self._serializers[self._data_format_to_key(data_format)])
serializer.setup(data_format)
self._serializers[data_format] = serializer

def force_download(self, chunk_index: int) -> None:
if self._force_download_queue:
self._force_download_queue.put(chunk_index)

@functools.lru_cache(maxsize=128)
def _data_format_to_key(self, data_format: str) -> str:
if ":" in data_format:
@@ -103,6 +110,7 @@ class PyTreeLoader(BaseItemLoader):
"""The Pytree Loader is the default loader of the Cache object."""

def __init__(self) -> None:
super().__init__()
self._chunk_filepaths: Dict[str, bool] = {}
self._decrypted_chunks: Dict[int, bytes] = {}

@@ -141,10 +149,20 @@ def load_item_from_chunk(
if chunk_filepath not in self._chunk_filepaths:
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes

start_time = time()
requested_force_download = False

while not exists:
sleep(0.1)
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes

if not requested_force_download and (time() - start_time) > _FORCE_DOWNLOAD_TIME:
self.force_download(chunk_index)
requested_force_download = True

if (time() - start_time) > _MAX_WAIT_TIME:
raise FileNotFoundError(f"The {chunk_filepath} hasn't been found.")

self._chunk_filepaths[chunk_filepath] = True

if self._config.get("encryption"):
@@ -347,9 +365,19 @@ def load_item_from_chunk(
if chunk_filepath not in self._chunk_filepaths:
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > filesize_bytes

start_time = time()
requested_force_download = False

while not exists:
sleep(0.1)
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > filesize_bytes
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes

if not requested_force_download and (time() - start_time) > _FORCE_DOWNLOAD_TIME:
self.force_download(chunk_index)
requested_force_download = True

if (time() - start_time) > _MAX_WAIT_TIME:
raise FileNotFoundError(f"The {chunk_filepath} hasn't been found.")

self._chunk_filepaths[chunk_filepath] = True

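
Distilled from the two loops above, the waiting policy is: poll the chunk file every 0.1 s, escalate once to a forced re-download after `_FORCE_DOWNLOAD_TIME` seconds, and give up with a `FileNotFoundError` after `_MAX_WAIT_TIME` seconds. A standalone sketch of that policy (the function name and arguments are illustrative):

    import os
    from time import sleep, time


    def wait_for_chunk(chunk_filepath: str, filesize_bytes: int, request_force_download) -> None:
        """Illustrative restatement of the polling loop used by the item loaders."""
        force_download_time = 30  # mirrors _FORCE_DOWNLOAD_TIME
        max_wait_time = 120       # mirrors _MAX_WAIT_TIME
        start_time = time()
        requested_force_download = False

        def chunk_ready() -> bool:
            return os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes

        while not chunk_ready():
            sleep(0.1)
            if not requested_force_download and (time() - start_time) > force_download_time:
                request_force_download()  # e.g. put the chunk index on the force-download queue
                requested_force_download = True
            if (time() - start_time) > max_wait_time:
                raise FileNotFoundError(f"The {chunk_filepath} hasn't been found.")
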
40 changes: 39 additions & 1 deletion src/litdata/streaming/reader.py
@@ -19,6 +19,7 @@
from threading import Event, Thread
from typing import Any, Dict, List, Optional, Tuple, Union

from litdata.constants import _DEBUG
from litdata.streaming.config import ChunksConfig, Interval
from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from litdata.streaming.sampler import ChunkedIndex
@@ -50,6 +51,7 @@ def __init__(
distributed_env: _DistributedEnv,
max_cache_size: Optional[int] = None,
max_pre_download: int = 2,
rank: Optional[int] = None,
) -> None:
super().__init__(daemon=True)
self._config = config
@@ -65,6 +67,11 @@ def __init__(
self._to_delete_queue: Queue = Queue()
self._force_stop_event = Event()

# TODO: Find a real fix to this problem
self._force_download_queue: Queue = Queue()

self._rank = rank

# Check whether a dataset slice fits on the node
num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes
self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False
@@ -84,8 +91,12 @@ def _apply_delete(self, chunk_index: int) -> None:
"""Inform the item loader of the chunk to delete."""
if self._config.can_delete(chunk_index):
chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]

self._item_loader.delete(chunk_index, chunk_filepath)

if _DEBUG:
print(f"Deleted {chunk_filepath} by {self._rank}")

try:
locak_chunk_path = chunk_filepath + ".lock"
if os.path.exists(locak_chunk_path):
@@ -130,11 +141,31 @@ def _pre_load_chunk(self, chunk_index: int) -> None:
chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
self._item_loader.pre_load_chunk(chunk_index, chunk_filepath)

def _force_download(self) -> None:
chunk_index = _get_from_queue(self._force_download_queue)
if chunk_index is not None:
if _DEBUG:
chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
print(f"Requested force download for {chunk_filepath} by {self._rank}")

self._config.download_chunk_from_index(chunk_index)

# Preload item if possible to gain some time but only
# if this is one of the pre-downloaded chunk
if self._pre_download_counter > 0:
self._pre_load_chunk(chunk_index)

# Avoid downloading too many chunks in advance at the risk of over using the disk space
self._pre_download_counter += 1

def run(self) -> None:
while True:
if self._force_stop_event.is_set():
self._has_exited = True
return

self._force_download()

if self._pre_download_counter < self._max_pre_download:
chunk_index = _get_from_queue(self._to_download_queue)
if chunk_index == _END_TOKEN:
@@ -266,8 +297,15 @@ def read(self, index: ChunkedIndex) -> Any:
# Create and start the prepare chunks thread
if self._prepare_thread is None and self._config:
self._prepare_thread = PrepareChunksThread(
self._config, self._item_loader, self._distributed_env, self._max_cache_size, self._max_pre_download
self._config,
self._item_loader,
self._distributed_env,
self._max_cache_size,
self._max_pre_download,
self._rank,
)
# Attach the force download queue
self._item_loader._force_download_queue = self._prepare_thread._force_download_queue # type: ignore
self._prepare_thread.start()
if index.chunk_indexes:
self._prepare_thread.download(index.chunk_indexes)
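
The thread now drains a second queue on every iteration of `run()`: when a chunk stays missing for too long, the item loader (running in a dataloader worker) puts the chunk index on `_force_download_queue`, and the `PrepareChunksThread` re-triggers the download. A reduced sketch of that handshake; the non-blocking helper below is an assumed stand-in for litdata's `_get_from_queue`:

    import queue
    from multiprocessing import Queue
    from typing import Optional


    def get_non_blocking(q: Queue, timeout: float = 0.01) -> Optional[int]:
        # Assumed equivalent of _get_from_queue: return an item if one arrives
        # within `timeout`, otherwise None so the thread's loop keeps spinning.
        try:
            return q.get(timeout=timeout)
        except queue.Empty:
            return None


    # Item-loader side (dataloader worker): request a re-download of chunk 7.
    force_download_queue: Queue = Queue()
    force_download_queue.put(7)

    # PrepareChunksThread side: checked once per run() iteration.
    chunk_index = get_non_blocking(force_download_queue)
    if chunk_index is not None:
        print(f"would re-download chunk {chunk_index}")  # the real code calls download_chunk_from_index
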
37 changes: 37 additions & 0 deletions src/litdata/utilities/breakpoint.py
@@ -0,0 +1,37 @@
import multiprocessing
import os
import pdb
import sys

_stdin = [None]
_stdin_lock = multiprocessing.Lock()
try:
_stdin_fd = sys.stdin.fileno()
except Exception:
_stdin_fd = None


# Taken from https://github.com/facebookresearch/metaseq/blob/main/metaseq/pdb.py
class MPPdb(pdb.Pdb):
"""A Pdb wrapper that works in a multiprocessing environment."""

def __init__(self) -> None:
pdb.Pdb.__init__(self, nosigint=True)

def _cmdloop(self) -> None:
stdin_back = sys.stdin
with _stdin_lock:
try:
if _stdin_fd is not None:
if not _stdin[0]:
_stdin[0] = os.fdopen(_stdin_fd)
sys.stdin = _stdin[0]
self.cmdloop()
finally:
sys.stdin = stdin_back


# This breakpoint can be used within the LitData workers.
def breakpoint() -> None:
pdb = MPPdb()
pdb.set_trace(sys._getframe().f_back)
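
An illustrative way to use this from worker code (the function below is a made-up map-style worker): drop `ld.breakpoint()` where the builtin `breakpoint()` would fail because the worker process does not own stdin; the module-level `_stdin_lock` ensures only one worker at a time attaches to the shared stdin.

    import litdata as ld


    def process(filepath, output_dir):
        ld.breakpoint()  # pauses this worker; pdb commands are read from the shared stdin
        ...  # inspect locals, step through the transform, then `c` to continue
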
18 changes: 16 additions & 2 deletions src/litdata/utilities/train_test_split.py
@@ -46,7 +46,7 @@ def train_test_split(

# we need subsampled chunk filenames, original chunk file, and subsampled_roi

dummy_streaming_dataset = deepcopy(streaming_dataset)
dummy_streaming_dataset = deepcopy_dataset(streaming_dataset)
dummy_subsampled_chunk_filename = dummy_streaming_dataset.subsampled_files
dummy_subsampled_roi = dummy_streaming_dataset.region_of_interest
subsampled_chunks: List[Dict[str, Any]] = []
@@ -65,7 +65,7 @@
else:
raise ValueError("Couldn't load original chunk file.")

new_datasets = [deepcopy(streaming_dataset) for _ in splits]
new_datasets = [deepcopy_dataset(streaming_dataset) for _ in splits]

dataset_length = sum([my_roi[1] - my_roi[0] for my_roi in dummy_subsampled_roi])

@@ -94,3 +94,17 @@ def train_test_split(
dummy_subsampled_roi = left_roi

return new_datasets


def deepcopy_dataset(dataset: Any) -> Any:
has_cache = dataset.cache is not None
if has_cache:
original_prepare_thread = dataset.cache._reader._prepare_thread
original_force_download_queue = dataset.cache._reader._item_loader._force_download_queue
dataset.cache._reader._prepare_thread = None
dataset.cache._reader._item_loader._force_download_queue = None
copied_dataset = deepcopy(dataset)
if has_cache:
dataset.cache._reader._prepare_thread = original_prepare_thread
dataset.cache._reader._item_loader._force_download_queue = original_force_download_queue
return copied_dataset
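
The helper exists because the reader's new state is not deep-copyable: a `multiprocessing.Queue` (and a running thread) cannot be pickled outside of process spawning, so both are detached before the copy and restored afterwards. A quick illustration of the underlying limitation:

    from copy import deepcopy
    from multiprocessing import Queue

    q = Queue()
    try:
        deepcopy(q)
    except Exception as err:
        # Typically: RuntimeError("Queue objects should only be shared between processes ...")
        print(f"deepcopy failed: {err!r}")
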
10 changes: 5 additions & 5 deletions tests/processing/test_data_processor.py
@@ -17,7 +17,7 @@
from litdata.processing.data_processor import (
DataChunkRecipe,
DataProcessor,
DataTransformRecipe,
MapRecipe,
_download_data_target,
_get_item_filesizes,
_is_path,
@@ -29,7 +29,7 @@
_wait_for_disk_usage_higher_than_threshold,
_wait_for_file_to_exist,
)
from litdata.processing.functions import LambdaDataTransformRecipe, map, optimize
from litdata.processing.functions import LambdaMapRecipe, map, optimize
from litdata.streaming import StreamingDataLoader, StreamingDataset, resolver
from litdata.streaming.cache import Cache, Dir

@@ -581,7 +581,7 @@ def test_data_processsor_nlp(tmpdir, monkeypatch):
data_processor_more_wokers.run(TextTokenizeRecipe(chunk_size=1024 * 11))


class ImageResizeRecipe(DataTransformRecipe):
class ImageResizeRecipe(MapRecipe):
def prepare_structure(self, input_dir: str):
filepaths = [os.path.join(input_dir, filename) for filename in os.listdir(input_dir)]
return [filepath for filepath in filepaths if os.path.isfile(filepath)]
@@ -836,7 +836,7 @@ def fn(output_dir, item, device):
assert device == "cuda:2"
called = True

data_recipe = LambdaDataTransformRecipe(fn, range(1))
data_recipe = LambdaMapRecipe(fn, range(1))

data_recipe.prepare_item(1, "", False)
assert called
@@ -857,7 +857,7 @@ def __call__(self, item, output_dir, device):
assert device == "cuda:2"
called = True

data_recipe = LambdaDataTransformRecipe(Transform(), range(1))
data_recipe = LambdaMapRecipe(Transform(), range(1))
data_recipe.prepare_item(1, "", False)
assert called
