Skip to content

Commit 8161c1e

Browse files
authored
Optimization when there is no data to download (#301)
1 parent 518a1c3 commit 8161c1e

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

src/litdata/processing/data_processor.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,22 @@ def _is_path(input_dir: Optional[str], element: Any) -> bool:
379379
return os.path.isfile(element)
380380

381381

382+
class FakeQueue:
383+
"""This class enables us to replace multiprocessing Queue when not required and avoid serializing data."""
384+
385+
def __init__(self) -> None:
386+
self._items: List[Any] = []
387+
388+
def add_items(self, items: List[Any]) -> None:
389+
self._items.extend(items)
390+
391+
def get(self) -> None:
392+
try:
393+
return self._items.pop(0)
394+
except IndexError:
395+
return None
396+
397+
382398
class BaseWorker:
383399
def __init__(
384400
self,
@@ -422,7 +438,8 @@ def __init__(
422438
self.to_download_queues: List[Queue] = []
423439
self.to_upload_queues: List[Queue] = []
424440
self.stop_queue = stop_queue
425-
self.ready_to_process_queue: Queue = Queue()
441+
self.no_downloaders = self.input_dir.path is None or self.reader is not None
442+
self.ready_to_process_queue: Union[Queue, FakeQueue] = FakeQueue() if self.no_downloaders else Queue()
426443
self.remove_queue: Queue = Queue()
427444
self.progress_queue: Queue = progress_queue
428445
self.error_queue: Queue = error_queue
@@ -554,11 +571,12 @@ def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
554571
self.to_upload_queues[self._counter % self.num_uploaders].put(data)
555572

556573
def _collect_paths(self) -> None:
557-
if self.input_dir.path is None or self.reader is not None:
558-
for index in range(len(self.items)):
559-
self.ready_to_process_queue.put(index)
560-
for _ in range(self.num_downloaders):
561-
self.ready_to_process_queue.put(None)
574+
if self.no_downloaders:
575+
if isinstance(self.ready_to_process_queue, FakeQueue):
576+
self.ready_to_process_queue.add_items(list(range(len(self.items))))
577+
else:
578+
for index in range(len(self.items)):
579+
self.ready_to_process_queue.put(index)
562580
return
563581

564582
items = []
@@ -582,7 +600,11 @@ def _collect_paths(self) -> None:
582600
paths = []
583601
for index, path in indexed_paths.items():
584602
paths.append(path)
585-
if self.input_dir and not self.input_dir.path.startswith("/teamspace/studios/this_studio"):
603+
if (
604+
self.input_dir
605+
and isinstance(self.input_dir.path, str)
606+
and not self.input_dir.path.startswith("/teamspace/studios/this_studio")
607+
):
586608
path = path.replace(self.input_dir.path, self.cache_data_dir)
587609
flattened_item[index] = path
588610

@@ -593,7 +615,7 @@ def _collect_paths(self) -> None:
593615
self.items = items
594616

595617
def _start_downloaders(self) -> None:
596-
if self.input_dir.path is None or self.reader is not None:
618+
if self.no_downloaders:
597619
return
598620

599621
for _ in range(self.num_downloaders):

0 commit comments

Comments
 (0)