Skip to content

Commit e744cc0

Browse files
author
The TensorFlow Datasets Authors
committed
Fix crash when taking a subset of a MultiSplitInfo with empty shard.
PiperOrigin-RevId: 750755958
1 parent dce2881 commit e744cc0

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

tensorflow_datasets/core/splits.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,11 @@ class MultiSplitInfo(SplitInfo):
282282
This should only be used to read data and not when producing data.
283283
"""
284284

285-
split_infos: list[SplitInfo] = dataclasses.field(default_factory=list)
285+
split_infos: list[SplitInfo | SubSplitInfo] = dataclasses.field(
286+
default_factory=list
287+
)
286288

287-
def __init__(self, name: str, split_infos: list[SplitInfo]):
289+
def __init__(self, name: str, split_infos: list[SplitInfo | SubSplitInfo]):
288290
if not split_infos:
289291
raise ValueError('Need to pass a non-empty list of SplitInfos')
290292
object.__setattr__(self, 'split_infos', split_infos)
@@ -315,6 +317,16 @@ def __repr__(self) -> str:
315317
f'split_infos={self.split_infos!r})'
316318
)
317319

320+
@property
def examples_in_shards(self) -> list[int]:
  """Returns the per-shard example counts of every wrapped split, flattened.

  Entries of `split_infos` that are `SubSplitInfo` or `MultiSplitInfo`
  instances contribute their own `examples_in_shards`; every other entry
  contributes its `shard_lengths`. The order of `split_infos` is preserved.
  """
  nested_types = (SubSplitInfo, MultiSplitInfo)
  counts: list[int] = []
  for info in self.split_infos:
    per_shard = (
        info.examples_in_shards
        if isinstance(info, nested_types)
        else info.shard_lengths
    )
    counts += per_shard
  return counts
329+
318330
@property
319331
def file_instructions(self) -> list[shard_utils.FileInstruction]:
320332
result = []
@@ -361,6 +373,10 @@ class SubSplitInfo:
361373
def shard_lengths(self) -> list[int]:
362374
return [f.take for f in self.file_instructions]
363375

376+
@property
def examples_in_shards(self) -> list[int]:
  """Returns the `examples_in_shard` of each file instruction, in order."""
  counts = []
  for instruction in self.file_instructions:
    counts.append(instruction.examples_in_shard)
  return counts
379+
364380
@property
365381
def num_examples(self) -> int:
366382
"""Returns the number of examples in the subsplit."""
@@ -526,7 +542,7 @@ def _make_absolute_instructions(
526542

527543
def _file_instructions_for_split(
528544
instruction: _AbsoluteInstruction,
529-
split_info: SplitInfo,
545+
split_info: SplitInfo | SubSplitInfo,
530546
) -> list[shard_utils.FileInstruction]:
531547
"""Returns the file instructions from the given instruction applied to the given split info."""
532548
if not split_info.num_examples:
@@ -537,9 +553,7 @@ def _file_instructions_for_split(
537553
return []
538554
to = split_info.num_examples if instruction.to is None else instruction.to
539555
if isinstance(split_info, (SubSplitInfo, MultiSplitInfo)):
540-
examples_in_shards = [
541-
f.examples_in_shard for f in split_info.file_instructions
542-
]
556+
examples_in_shards = split_info.examples_in_shards
543557
else:
544558
examples_in_shards = None
545559
return shard_utils.get_file_instructions(

tensorflow_datasets/core/splits_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,43 @@ def test_multi_split_sub_split(self):
255255
self.assertEqual(file_instruction.take, 2)
256256
self.assertEqual(file_instruction.examples_in_shard, 10)
257257

258+
def test_multi_split_empty_shard(self):
  """Checks sub-splitting a MultiSplitInfo whose split has a zero-length shard.

  The wrapped split has shard lengths [5, 0, 5]; taking 'train[:90%]' is
  expected to yield 9 examples read from the first and third shard files
  only.
  """
  train_info = splits.SplitInfo(
      name='train',
      shard_lengths=[5, 0, 5],
      num_bytes=0,
      filename_template=_filename_template(split='train', data_dir='/abc'),
  )
  multi_split = splits.MultiSplitInfo(name='train', split_infos=[train_info])
  sub_split = splits.SplitDict([multi_split])['train[:90%]']

  self.assertEqual(sub_split.name, 'train[:90%]')
  self.assertEqual(sub_split.num_examples, 9)
  self.assertEqual(sub_split.shard_lengths, [5, 4])

  actual_instructions = sub_split.file_instructions
  first_shard = shard_utils.FileInstruction(
      filename='/abc/ds_name-train.tfrecord-00000-of-00003',
      skip=0,
      take=5,
      examples_in_shard=5,
  )
  # The middle shard (00001) is empty, so no instruction is expected for it.
  last_shard = shard_utils.FileInstruction(
      filename='/abc/ds_name-train.tfrecord-00002-of-00003',
      skip=0,
      take=4,
      examples_in_shard=5,
  )
  self.assertEqual(actual_instructions, [first_shard, last_shard])
294+
258295

259296
class SplitsTest(testing.TestCase):
260297

0 commit comments

Comments
 (0)