Changes from all commits

25 commits
668e5c6
Implement round robin and separate ventilator and result queue
arushi297 Jul 11, 2025
3b02944
Fixed and tested locally
arushi297 Jul 15, 2025
ff076ec
Implement the alternate design to fix race condition for some tests
arushi297 Jul 18, 2025
555548b
do code cleanup
arushi297 Jul 18, 2025
e3bccb6
Restore imports
arushi297 Jul 18, 2025
a96fa32
Restore imports
arushi297 Jul 18, 2025
6708b9d
fix queue size
arushi297 Jul 23, 2025
0ff55cd
Change petastorm release version format to fix failure due to setupto…
arushi297 Jul 24, 2025
80e6d82
Add constraint on setuptools version to prevent issue with new versions
arushi297 Jul 24, 2025
7ab26c4
Fix lint issues
arushi297 Jul 24, 2025
ba726f6
fix some more lint issues
arushi297 Jul 24, 2025
1676459
Add logs for testing
arushi297 Jul 27, 2025
4712391
[Test] Remove -Y flag to force fresh dataset generation
arushi297 Jul 27, 2025
08e98b6
Update failing test
arushi297 Jul 28, 2025
18a0709
Fix test_stop_when_result_queue_is_full expected queue size as per th…
arushi297 Jul 28, 2025
dc05685
Empty commit to trigger build
arushi297 Jul 28, 2025
bafe06f
Empty commit to trigger build
arushi297 Jul 28, 2025
410e07b
[Revert] Adding debug logs
arushi297 Aug 6, 2025
43e3555
[Revert] Adding debug logs
arushi297 Aug 6, 2025
a234e67
[Revert] Adding debug logs
arushi297 Aug 6, 2025
30de3b5
[Revert] restrict test runs
arushi297 Aug 6, 2025
8ce4b04
[Revert] fix logger import
arushi297 Aug 6, 2025
b214ca1
Revert back to enable all tests
arushi297 Aug 6, 2025
acbddb3
[Revert] change logs to print
arushi297 Aug 6, 2025
516aa09
[Revert] Modify debug logs
arushi297 Aug 6, 2025
37 changes: 34 additions & 3 deletions petastorm/arrow_reader_worker.py
@@ -15,6 +15,7 @@

import hashlib
import operator
import logging

import numpy as np
import pandas as pd
@@ -26,6 +27,9 @@
from petastorm.workers_pool import EmptyResultError
from petastorm.workers_pool.worker_base import WorkerBase

# Initialize logger
logger = logging.getLogger(__name__)


class ArrowReaderWorkerResultsQueueReader(object):
def __init__(self):
@@ -91,6 +95,9 @@ class ArrowReaderWorker(WorkerBase):
def __init__(self, worker_id, publish_func, args):
super(ArrowReaderWorker, self).__init__(worker_id, publish_func, args)

# Add debug log in the constructor
print(f'DEBUG: Initializing ArrowReaderWorker with worker_id: {worker_id}')

self._filesystem = args[0]
self._dataset_path_or_paths = args[1]
self._schema = args[2]
@@ -101,7 +108,10 @@ def __init__(self, worker_id, publish_func, args):
self._transformed_schema = args[7]
self._arrow_filters = args[8]
self._shuffle_rows = args[9]
self._random_state = np.random.RandomState(seed=args[10])
self._random_seed = args[10]

# Initialize random number generator
self._rng = np.random.default_rng(self._random_seed)

if self._ngram:
raise NotImplementedError('ngrams are not supported by ArrowReaderWorker')
@@ -128,12 +138,18 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
:return:
"""

# Add debug log in the process method
print(f'DEBUG: Processing piece_index: {piece_index}')

if not self._dataset:
self._dataset = pq.ParquetDataset(
self._dataset_path_or_paths,
filesystem=self._filesystem,
validate_schema=False, filters=self._arrow_filters)

# Add debug log after dataset is initialized
print(f'DEBUG: ParquetDataset initialized with path: {self._dataset_path_or_paths}')

piece = self._split_pieces[piece_index]

# Create pyarrow file system
@@ -160,11 +176,16 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
path_str = self._dataset_path_or_paths
cache_key = '{}:{}:{}'.format(hashlib.md5(path_str.encode('utf-8')).hexdigest(),
piece.path, piece_index)

# Add debug log for cache key
print(f'DEBUG: Cache key generated: {cache_key}')

all_cols = self._local_cache.get(cache_key,
lambda: self._load_rows(parquet_file, piece, shuffle_row_drop_partition))

if all_cols:
self.publish_func(all_cols)
print(f'DEBUG: Published columns for piece_index: {piece_index}')

@staticmethod
def _check_shape_and_ravel(x, field):
@@ -289,9 +310,19 @@ def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_

# pyarrow would fail if we request a column names that the dataset is partitioned by
table = piece.read(columns=column_names - partition_names, partitions=self._dataset.partitions)

# Handle row shuffling based on shuffle_rows setting
if self._shuffle_rows:
indices = self._random_state.permutation(table.num_rows)
table = table.take(indices)
if self._random_seed is not None and self._random_seed != 0:
# Deterministic randomization: use provided seed
indices = self._rng.permutation(table.num_rows)
else:
# Non-deterministic randomization: use np.random directly
indices = np.random.permutation(table.num_rows)
else:
# Deterministic natural order: shuffle_rows=False
indices = np.arange(table.num_rows)
table = table.take(indices)

# Drop columns we did not explicitly request. This may happen when a table is partitioned. Besides columns
# requested, pyarrow will also return partition values. Having these unexpected fields will break some
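The hunk above replaces the unconditional self._random_state.permutation call with a three-way choice between seeded, unseeded, and natural row order. A minimal, self-contained sketch of that index-selection logic follows, assuming only numpy and pyarrow; the helper name select_indices and the toy table are illustrative and not part of this PR.

import numpy as np
import pyarrow as pa

def select_indices(num_rows, shuffle_rows, seed):
    # Sketch of the branch structure in _read_with_shuffle_row_drop above.
    if shuffle_rows:
        if seed is not None and seed != 0:
            # Non-zero seed: reproducible permutation from a dedicated generator.
            return np.random.default_rng(seed).permutation(num_rows)
        # No usable seed: fall back to the global numpy RNG.
        return np.random.permutation(num_rows)
    # shuffle_rows=False: keep the natural row order.
    return np.arange(num_rows)

table = pa.table({'id': list(range(6))})
print(table.take(select_indices(table.num_rows, shuffle_rows=True, seed=42)).column('id').to_pylist())
print(table.take(select_indices(table.num_rows, shuffle_rows=False, seed=None)).column('id').to_pylist())

The first print gives the same shuffled order on every run; the second preserves 0 through 5.
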
17 changes: 11 additions & 6 deletions petastorm/reader.py
@@ -38,6 +38,7 @@
from petastorm.workers_pool.thread_pool import ThreadPool
from petastorm.workers_pool.ventilator import ConcurrentVentilator

# Initialize logger
logger = logging.getLogger(__name__)

# Ventilator guarantees that no more than workers + _VENTILATE_EXTRA_ROWGROUPS are processed at a moment by a
Expand Down Expand Up @@ -159,7 +160,7 @@ def make_reader(dataset_url,
'To read from a non-Petastorm Parquet store use make_batch_reader')

if reader_pool_type == 'thread':
reader_pool = ThreadPool(workers_count, results_queue_size)
reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed)
elif reader_pool_type == 'process':
if pyarrow_serialize:
warnings.warn("pyarrow_serializer was deprecated and will be removed in future versions. "
@@ -315,7 +316,7 @@ def make_batch_reader(dataset_url_or_urls,
raise ValueError('Unknown cache_type: {}'.format(cache_type))

if reader_pool_type == 'thread':
reader_pool = ThreadPool(workers_count, results_queue_size)
reader_pool = ThreadPool(workers_count, results_queue_size, shuffle_rows=shuffle_rows, seed=seed)
elif reader_pool_type == 'process':
serializer = ArrowTableSerializer()
reader_pool = ProcessPool(workers_count, serializer, zmq_copy_buffers=zmq_copy_buffers)
@@ -400,6 +401,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
These will be applied when loading the parquet file with PyArrow. More information
here: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html
"""
print(f'DEBUG: Initializing Reader with dataset_path: {dataset_path}, num_epochs: {num_epochs}')
self.num_epochs = num_epochs

# 1. Open the parquet storage (dataset)
@@ -437,9 +439,11 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
raise NotImplementedError('Using timestamp_overlap=False is not implemented with'
' shuffle_options.shuffle_row_drop_partitions > 1')

print(f'DEBUG: Reader initialized with schema_fields: {schema_fields}')

cache = cache or NullCache()

self._workers_pool = reader_pool or ThreadPool(10)
self._workers_pool = reader_pool or ThreadPool(10, shuffle_rows=shuffle_rows, seed=seed)

# Make a schema view (a view is a Unischema containing only a subset of fields
# Will raise an exception if invalid schema fields are in schema_fields
@@ -483,7 +487,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
self.ngram, row_groups, cache, transform_spec,
self.schema, filters, shuffle_rows, seed),
ventilator=self.ventilator)
logger.debug('Workers pool started')
print('DEBUG: Workers pool started')

self.last_row_consumed = False
self.stopped = False
@@ -653,6 +657,7 @@ def _normalize_shuffle_options(shuffle_row_drop_partitions, dataset):

def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_drop_partitions,
num_epochs, worker_predicate, max_ventilation_queue_size, seed):
print(f'DEBUG: Creating ventilator with row_group_indexes: {row_group_indexes}')
items_to_ventilate = []
for piece_index in row_group_indexes:
for shuffle_row_drop_partition in range(shuffle_row_drop_partitions):
@@ -670,12 +675,12 @@ def _create_ventilator(self, row_group_indexes, shuffle_row_groups, shuffle_row_
random_seed=seed)

def stop(self):
"""Stops all worker threads/processes."""
print('DEBUG: Stopping Reader')
self._workers_pool.stop()
self.stopped = True

def join(self):
"""Joins all worker threads/processes. Will block until all worker workers have been fully terminated."""
print('DEBUG: Joining Reader')
self._workers_pool.join()

@property
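The reader.py changes thread shuffle_rows and seed through to the ThreadPool in make_reader, make_batch_reader, and the Reader constructor's default pool. A hedged usage sketch follows, assuming make_reader exposes shuffle_rows and seed as keyword arguments as this diff suggests; the dataset URL is a placeholder.

from petastorm import make_reader

# Hypothetical call exercising the arguments plumbed through in this PR.
with make_reader('file:///tmp/example_petastorm_dataset',
                 reader_pool_type='thread',  # the branch modified above
                 workers_count=4,
                 shuffle_rows=True,          # shuffle rows within each row group
                 seed=42,                    # assumed to make the per-worker shuffle reproducible
                 num_epochs=1) as reader:
    for row in reader:
        print(row)  # inspect the first sample only
        break
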
11 changes: 8 additions & 3 deletions petastorm/tests/test_tf_dataset.py
@@ -128,6 +128,7 @@ def test_with_dataset_repeat(synthetic_dataset, reader_factory):
def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory):
""" Check if ``tf.data.Dataset``'s ``repeat`` works after ``tf.data.Dataset``'s ``cache``."""
epochs = 3
print(f"Starting test_with_dataset_repeat_after_cache with {epochs} epochs")
with reader_factory(synthetic_dataset.url, schema_fields=[TestSchema.id]) as reader:
dataset = make_petastorm_dataset(reader)
dataset = dataset.cache()
@@ -138,18 +139,22 @@ def test_with_dataset_repeat_after_cache(synthetic_dataset, reader_factory):
with tf.Session() as sess:
with pytest.warns(None):
# Expect no warnings since cache() is called before repeat()
for _ in range(epochs):
for epoch in range(epochs):
print(f"Starting epoch {epoch}")
actual_res = []
for _, _ in enumerate(synthetic_dataset.data):
for i, _ in enumerate(synthetic_dataset.data):
actual = sess.run(it_op)._asdict()
actual_res.append(actual["id"])
print(f"iteration: {i} {actual['id']}")
expected_res = list(range(len(synthetic_dataset.data)))
print(f"Epoch: {epoch} actual {sorted(actual_res)}, expected {expected_res}")
# sort dataset output since row_groups are shuffled from reader.
np.testing.assert_equal(sorted(actual_res), expected_res)

print(f"Completed epoch {epoch}")
# Exhausted all epochs. Fetching next value should trigger OutOfRangeError
with pytest.raises(tf.errors.OutOfRangeError):
sess.run(it_op)
print("Completed test_with_dataset_repeat_after_cache")


@pytest.mark.forked
8 changes: 5 additions & 3 deletions petastorm/workers_pool/tests/test_workers_pool.py
@@ -141,15 +141,17 @@ def test_stop_when_result_queue_is_full(self):
SLEEP_DELTA = 0.01
TIMEOUT = 20
QUEUE_SIZE = 2
WORKERS_COUNT = 10

pool = ThreadPool(10, results_queue_size=QUEUE_SIZE)
pool = ThreadPool(WORKERS_COUNT, results_queue_size=QUEUE_SIZE)
pool.start(WorkerIdGeneratingWorker)

for _ in range(100):
for _ in range(1000):
pool.ventilate()

expected_queue_size = WORKERS_COUNT * max(5, QUEUE_SIZE // WORKERS_COUNT)
cumulative_wait = 0
while pool.results_qsize() != QUEUE_SIZE:
while pool.results_qsize() != expected_queue_size:
time.sleep(SLEEP_DELTA)
cumulative_wait += SLEEP_DELTA
# Make sure we wait no longer than the timeout. Otherwise, something is very wrong
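The updated test waits for expected_queue_size rather than the raw QUEUE_SIZE. A quick worked check of that expression with the constants above:

WORKERS_COUNT = 10
QUEUE_SIZE = 2
# max(5, 2 // 10) == max(5, 0) == 5, so the test expects 10 * 5 == 50 buffered results.
expected_queue_size = WORKERS_COUNT * max(5, QUEUE_SIZE // WORKERS_COUNT)
print(expected_queue_size)  # 50

The per-worker minimum of five results is taken from the test's own formula; the ThreadPool change that makes 50 the steady-state queue size is not shown in this excerpt.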