@@ -282,9 +282,11 @@ class MultiSplitInfo(SplitInfo):
282
282
This should only be used to read data and not when producing data.
283
283
"""
284
284
285
- split_infos : list [SplitInfo ] = dataclasses .field (default_factory = list )
285
+ split_infos : list [SplitInfo | SubSplitInfo ] = dataclasses .field (
286
+ default_factory = list
287
+ )
286
288
287
- def __init__ (self , name : str , split_infos : list [SplitInfo ]):
289
+ def __init__ (self , name : str , split_infos : list [SplitInfo | SubSplitInfo ]):
288
290
if not split_infos :
289
291
raise ValueError ('Need to pass a non-empty list of SplitInfos' )
290
292
object .__setattr__ (self , 'split_infos' , split_infos )
@@ -315,6 +317,16 @@ def __repr__(self) -> str:
315
317
f'split_infos={ self .split_infos !r} )'
316
318
)
317
319
320
@property
def examples_in_shards(self) -> list[int]:
  """Returns the per-shard example counts of all wrapped split infos.

  The counts are concatenated in the order of `self.split_infos`. Entries
  that are `SubSplitInfo` or `MultiSplitInfo` contribute their own
  `examples_in_shards`; any other entry contributes its `shard_lengths`.
  """
  counts: list[int] = []
  for info in self.split_infos:
    # Only sub-splits and nested multi-splits expose `examples_in_shards`;
    # every other entry is read through `shard_lengths` instead.
    exposes_shard_counts = isinstance(info, (SubSplitInfo, MultiSplitInfo))
    counts.extend(
        info.examples_in_shards if exposes_shard_counts else info.shard_lengths
    )
  return counts
329
+
318
330
@property
319
331
def file_instructions (self ) -> list [shard_utils .FileInstruction ]:
320
332
result = []
@@ -361,6 +373,10 @@ class SubSplitInfo:
361
373
def shard_lengths (self ) -> list [int ]:
362
374
return [f .take for f in self .file_instructions ]
363
375
376
@property
def examples_in_shards(self) -> list[int]:
  """Returns `examples_in_shard` of every file instruction, in order."""
  per_shard_counts: list[int] = []
  for file_instruction in self.file_instructions:
    per_shard_counts.append(file_instruction.examples_in_shard)
  return per_shard_counts
379
+
364
380
@property
365
381
def num_examples (self ) -> int :
366
382
"""Returns the number of examples in the subsplit."""
@@ -526,7 +542,7 @@ def _make_absolute_instructions(
526
542
527
543
def _file_instructions_for_split (
528
544
instruction : _AbsoluteInstruction ,
529
- split_info : SplitInfo ,
545
+ split_info : SplitInfo | SubSplitInfo ,
530
546
) -> list [shard_utils .FileInstruction ]:
531
547
"""Returns the file instructions from the given instruction applied to the given split info."""
532
548
if not split_info .num_examples :
@@ -537,9 +553,7 @@ def _file_instructions_for_split(
537
553
return []
538
554
to = split_info .num_examples if instruction .to is None else instruction .to
539
555
if isinstance (split_info , (SubSplitInfo , MultiSplitInfo )):
540
- examples_in_shards = [
541
- f .examples_in_shard for f in split_info .file_instructions
542
- ]
556
+ examples_in_shards = split_info .examples_in_shards
543
557
else :
544
558
examples_in_shards = None
545
559
return shard_utils .get_file_instructions (
0 commit comments