@@ -1185,73 +1185,81 @@ def sort(r):
11851185 assert r1 == ds .take ()
11861186
11871187
@pytest.mark.parametrize(
    "block_sizes,batch_size,drop_last",
    [
        # Single block, batch smaller than block, keep partial
        ([10], 3, False),
        # Single block, batch smaller than block, drop partial
        ([10], 3, True),
        # Single block, exact division
        ([10], 5, False),
        # Multiple equal-sized blocks, batch doesn't divide evenly, keep partial
        ([5, 5, 5], 7, False),
        # Multiple equal-sized blocks, batch doesn't divide evenly, drop partial
        ([5, 5, 5], 7, True),
        # Multiple unequal-sized blocks, keep partial
        ([1, 5, 10], 4, False),
        # Multiple unequal-sized blocks, drop partial
        ([1, 5, 10], 4, True),
        # Edge case: batch_size = 1
        ([5, 3, 7], 1, False),
        # Edge case: batch larger than total rows
        ([2, 3, 4], 100, False),
        # Exact division across multiple blocks
        ([6, 12, 18], 6, False),
    ],
)
def test_iter_batches_grid(
    ray_start_regular_shared,
    block_sizes,
    batch_size,
    drop_last,
):
    """Exercise slicing, batch combining, and partial-batch dropping logic
    over specific dataset, batching, and dropping configurations."""
    # Build one DataFrame per requested block; "value" holds globally
    # increasing row ids so concatenation order can be checked exactly.
    offsets = [0]
    for size in block_sizes:
        offsets.append(offsets[-1] + size)
    frames = [
        pd.DataFrame({"value": list(range(lo, hi))})
        for lo, hi in zip(offsets[:-1], offsets[1:])
    ]
    num_rows = offsets[-1]
    ds = ray.data.from_blocks(frames)

    batches = list(
        ds.iter_batches(
            batch_size=batch_size,
            drop_last=drop_last,
            batch_format="pandas",
        )
    )
    full_batches, remainder = divmod(num_rows, batch_size)
    if remainder == 0 or not drop_last:
        # Every row is yielded, so the batch count is num_rows / batch_size
        # rounded up, and the concatenation reproduces the whole dataset.
        assert len(batches) == math.ceil(num_rows / batch_size)
        assert pd.concat(batches, ignore_index=True).equals(ds.to_pandas())
    else:
        # The trailing partial batch is dropped: count rounds down, and the
        # concatenation equals the dataset with the remainder sliced off.
        assert len(batches) == full_batches
        assert pd.concat(batches, ignore_index=True).equals(
            ds.to_pandas()[: batch_size * full_batches]
        )
    if remainder == 0 or drop_last:
        # No partial batch survives — all batches are exactly batch_size.
        assert all(len(batch) == batch_size for batch in batches)
    else:
        # All but the last batch are full; the last carries the remainder.
        assert all(len(batch) == batch_size for batch in batches[:-1])
        assert len(batches[-1]) == remainder
12551263
12561264
12571265def test_union (ray_start_regular_shared ):
0 commit comments