@@ -81,10 +81,10 @@ def test_parquet_index_write(
81
81
82
82
for i , _ds in enumerate (ds ):
83
83
idx = i % 5
84
- assert len (_ds ) == 3
85
- assert _ds [0 ] == pq_data ["name" ][idx ]
86
- assert _ds [1 ] == pq_data ["weight" ][idx ]
87
- assert _ds [2 ] == pq_data ["height" ][idx ]
84
+ assert isinstance (_ds , dict )
85
+ assert _ds ["name" ] == pq_data ["name" ][idx ]
86
+ assert _ds ["weight" ] == pq_data ["weight" ][idx ]
87
+ assert _ds ["height" ] == pq_data ["height" ][idx ]
88
88
89
89
90
90
@pytest .mark .skipif (condition = sys .platform == "win32" , reason = "Fails on windows and test gets cancelled" )
@@ -168,7 +168,9 @@ def test_get_parquet_indexer_cls(pq_url, cls, expectation, monkeypatch, fsspec_m
168
168
@pytest .mark .usefixtures ("clean_pq_index_cache" )
169
169
@patch ("litdata.utilities.parquet._HF_HUB_AVAILABLE" , True )
170
170
@patch ("litdata.streaming.downloader._HF_HUB_AVAILABLE" , True )
171
- def test_stream_hf_parquet_dataset (monkeypatch , huggingface_hub_fs_mock , pq_data ):
171
+ @pytest .mark .parametrize (("pre_load_chunk" ), [False , True ])
172
+ @pytest .mark .parametrize (("low_memory" ), [False , True ])
173
+ def test_stream_hf_parquet_dataset (monkeypatch , huggingface_hub_fs_mock , pq_data , pre_load_chunk , low_memory ):
172
174
hf_url = "hf://datasets/some_org/some_repo/some_path"
173
175
174
176
# Test case 1: Invalid item_loader
@@ -180,27 +182,18 @@ def test_stream_hf_parquet_dataset(monkeypatch, huggingface_hub_fs_mock, pq_data
180
182
assert len (ds ) == 25 # 5 datasets for 5 loops
181
183
for i , _ds in enumerate (ds ):
182
184
idx = i % 5
183
- assert len (_ds ) == 3
184
- assert _ds [0 ] == pq_data ["name" ][idx ]
185
- assert _ds [1 ] == pq_data ["weight" ][idx ]
186
- assert _ds [2 ] == pq_data ["height" ][idx ]
187
-
188
- # Test case 3: Streaming with ParquetLoader as item_loader and low_memory=False
189
- ds = StreamingDataset (hf_url , item_loader = ParquetLoader (low_memory = False ))
190
- assert len (ds ) == 25
191
- for i , _ds in enumerate (ds ):
192
- idx = i % 5
193
- assert len (_ds ) == 3
194
- assert _ds [0 ] == pq_data ["name" ][idx ]
195
- assert _ds [1 ] == pq_data ["weight" ][idx ]
196
- assert _ds [2 ] == pq_data ["height" ][idx ]
197
-
198
- # Test case 4: Streaming with ParquetLoader and low_memory=True
199
- ds = StreamingDataset (hf_url , item_loader = ParquetLoader (low_memory = True ))
185
+ assert isinstance (_ds , dict )
186
+ assert _ds ["name" ] == pq_data ["name" ][idx ]
187
+ assert _ds ["weight" ] == pq_data ["weight" ][idx ]
188
+ assert _ds ["height" ] == pq_data ["height" ][idx ]
189
+
190
+ # Test case 3: Streaming with passing item_loader
191
+ print ("pre_load_chunk" , pre_load_chunk , "low_memory" , low_memory )
192
+ ds = StreamingDataset (hf_url , item_loader = ParquetLoader (pre_load_chunk , low_memory ))
200
193
assert len (ds ) == 25
201
194
for i , _ds in enumerate (ds ):
202
195
idx = i % 5
203
- assert len (_ds ) == 3
204
- assert _ds [0 ] == pq_data ["name" ][idx ]
205
- assert _ds [1 ] == pq_data ["weight" ][idx ]
206
- assert _ds [2 ] == pq_data ["height" ][idx ]
196
+ assert isinstance (_ds , dict )
197
+ assert _ds ["name" ] == pq_data ["name" ][idx ]
198
+ assert _ds ["weight" ] == pq_data ["weight" ][idx ]
199
+ assert _ds ["height" ] == pq_data ["height" ][idx ]
0 commit comments