Skip to content

Commit 53f958e

Browse files
Preserve formatting in concatenated IterableDataset (#7522)
* Preserve formatting in concatenated iterable dataset when the inputs have consistent formatting * style * If `dset._formatting` is None for any of the datasets, set the concatenated dataset format to None. Add log line for inputs with inconsistent format. * fix incorrect grouping * Reset output formatting if any of the inputs has formatting not set * log unset format also in case `formatting` is set, but `format_type` is None * Update src/datasets/iterable_dataset.py * Update src/datasets/iterable_dataset.py * Update src/datasets/iterable_dataset.py * Update iterable_dataset.py --------- Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
1 parent e38e6c2 commit 53f958e

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

src/datasets/iterable_dataset.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3430,6 +3430,25 @@ def _concatenate_iterable_datasets(
34303430
else:
34313431
_check_column_names([col_name for dset in dsets for col_name in dset.features])
34323432

3433+
# Check format is consistent; if so, will set format for concatenated dataset
3434+
if all(dset._formatting is None for dset in dsets):
3435+
formatting = None
3436+
elif any(dset._formatting is None for dset in dsets):
3437+
formatting = None
3438+
logger.info(
3439+
"Some of the datasets have disparate format or format not set. Resetting the format of the concatenated dataset."
3440+
)
3441+
else:
3442+
format_type_set = {dset._formatting.format_type for dset in dsets}
3443+
if len(format_type_set) == 1:
3444+
format_type = format_type_set.pop()
3445+
formatting = FormattingConfig(format_type=format_type)
3446+
else:
3447+
formatting = None
3448+
logger.info(
3449+
"Some of the datasets have disparate format or format not set. Resetting the format of the concatenated dataset."
3450+
)
3451+
34333452
# TODO: improve this to account for a mix of ClassLabel and Value for example
34343453
# right now it would keep the type of the first dataset in the list
34353454
features = Features(
@@ -3451,7 +3470,13 @@ def _concatenate_iterable_datasets(
34513470
# Get all the auth tokens per repository - in case the datasets come from different private repositories
34523471
token_per_repo_id = {repo_id: token for dataset in dsets for repo_id, token in dataset._token_per_repo_id.items()}
34533472
# Return new daset
3454-
return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
3473+
return IterableDataset(
3474+
ex_iterable=ex_iterable,
3475+
info=info,
3476+
split=split,
3477+
token_per_repo_id=token_per_repo_id,
3478+
formatting=formatting,
3479+
)
34553480

34563481

34573482
def _interleave_iterable_datasets(

tests/test_iterable_dataset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,6 +2128,18 @@ def test_concatenate_datasets_axis_1_with_different_lengths():
21282128
assert list(concatenated_dataset) == [{**x, **y} for x, y in zip(extended_dataset2_list, dataset1)]
21292129

21302130

2131+
@require_torch
2132+
@require_tf
2133+
@require_jax
2134+
@pytest.mark.parametrize(
2135+
"format_type", [None, "torch", "python", "tf", "tensorflow", "np", "numpy", "jax", "arrow", "pd", "pandas"]
2136+
)
2137+
def test_concatenate_datasets_with_format(dataset: IterableDataset, format_type):
2138+
formatted_dataset = dataset.with_format(format_type)
2139+
concatenated_dataset = concatenate_datasets([formatted_dataset])
2140+
assert concatenated_dataset._formatting.format_type == get_format_type_from_alias(format_type)
2141+
2142+
21312143
@pytest.mark.parametrize(
21322144
"probas, seed, expected_length, stopping_strategy",
21332145
[

0 commit comments

Comments
 (0)