Conversation

radulescupetru
Contributor

@radulescupetru radulescupetru commented Sep 24, 2025

Right now, the interleave_datasets function with probabilities samples with replacement. This PR adds the ability to sample without replacement.

import datasets

# Create datasets of different sizes to test exhaustion
data_a = [{"value": i, "source": "A"} for i in range(5)]
data_b = [{"value": i, "source": "B"} for i in range(10, 15)]

ds_a = datasets.Dataset.from_list(data_a).to_iterable_dataset()
ds_b = datasets.Dataset.from_list(data_b).to_iterable_dataset()

# Interleave with probabilities
ds_interleaved = datasets.interleave_datasets(
    [ds_a, ds_b],
    probabilities=[0.6, 0.4],
    seed=42,
    stopping_strategy="all_exhausted",
    sample_with_replacement=True,
)
for i, example in enumerate(ds_interleaved):
    print(f"Sample:{i}: value:{example['value']:02d} source:{example['source']}")

In this example, sample_with_replacement=True and it prints:

Sample:0: value:10 source:B
Sample:1: value:00 source:A
Sample:2: value:11 source:B
Sample:3: value:12 source:B
Sample:4: value:01 source:A
Sample:5: value:13 source:B
Sample:6: value:14 source:B
Sample:7: value:10 source:B
Sample:8: value:02 source:A
Sample:9: value:03 source:A
Sample:10: value:04 source:A

Note that the sample with value:10 source:B is drawn twice (Sample:0 and Sample:7).
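The repetition happens because, under the "all_exhausted" strategy with replacement, an exhausted dataset restarts from its beginning and keeps yielding until every dataset has been fully seen at least once. A minimal plain-Python sketch of that cycling behavior (a simplification independent of the datasets library; names are illustrative):

```python
import random
from itertools import cycle

def interleave_with_replacement(sources, probabilities, seed=0):
    """Sample a source per step; exhausted sources wrap around (items repeat).
    Stops once every source has been drawn at least as many times as its size."""
    rng = random.Random(seed)
    cycles = [cycle(s) for s in sources]   # wrap-around iterators
    draws = [0] * len(sources)             # draws so far per source
    sizes = [len(s) for s in sources]
    out = []
    while any(draws[i] < sizes[i] for i in range(len(sources))):
        i = rng.choices(range(len(sources)), weights=probabilities, k=1)[0]
        out.append(next(cycles[i]))
        draws[i] += 1
    return out

# Smaller source can repeat while waiting for the larger one to finish
out = interleave_with_replacement([[1, 2], [10, 20, 30]], [0.5, 0.5], seed=1)
```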

Re-running with sample_with_replacement=False prints:

Sample:0: value:10 source:B
Sample:1: value:00 source:A
Sample:2: value:11 source:B
Sample:3: value:12 source:B
Sample:4: value:01 source:A
Sample:5: value:13 source:B
Sample:6: value:14 source:B
Sample:7: value:02 source:A
Sample:8: value:03 source:A
Sample:9: value:04 source:A

Note that we don't see any repeated items.
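The without-replacement semantics can be simulated in plain Python (a minimal sketch, independent of the datasets library): pick a source according to the probabilities, yield that source's next unseen item, and once a source runs out, renormalize the weights over the sources that remain.

```python
import random

def interleave_without_replacement(sources, probabilities, seed=0):
    """Yield every item exactly once; within a source, items keep their order.
    Weights are implicitly renormalized as sources become exhausted."""
    rng = random.Random(seed)
    remaining = [list(s) for s in sources]  # items not yet yielded, in order
    out = []
    while any(remaining):
        alive = [i for i, r in enumerate(remaining) if r]
        weights = [probabilities[i] for i in alive]
        i = rng.choices(alive, weights=weights, k=1)[0]
        out.append(remaining[i].pop(0))
    return out

mixed = interleave_without_replacement(
    [[("A", v) for v in range(5)], [("B", v) for v in range(10, 15)]],
    probabilities=[0.6, 0.4],
    seed=42,
)
```

Every item appears exactly once, and items from the same source stay in their original order, matching the output above.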

@radulescupetru
Contributor Author

@lhoestq Continuing on the idea from #217
This doesn't add a new stopping criteria, but a new argument to interleave_datasets method. Let me know what you think and if you see a better way of doing this I'm open to suggestions.

@lhoestq
Member

lhoestq commented Sep 24, 2025

Great! This is a cool addition :)

IMO sample_with_replacement doesn't make sense as a new argument when the strategy is "first_exhausted" (the default). Since disabling replacement affects the stopping behavior anyway, I'd be in favor of expressing it as a new stopping strategy instead.

@radulescupetru
Contributor Author

Makes sense. Here's a revised implementation with that argument removed and a new stopping strategy added.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@radulescupetru
Contributor Author

@lhoestq Let me know if there's anything on my side that I can do!

@LTMeyer

LTMeyer commented Sep 29, 2025

Hi @radulescupetru, I'm commenting here after @lhoestq mentioned this PR in #7792. I'm facing a similar problem and was wondering if there is a common solution. Let me know if we share the same problem.

As described in the issue, my problem is that I want to mix unbalanced datasets and distribute the samples over multiple workers and ranks, without repeating samples and while retrieving as many samples as I can (i.e. without discarding samples that could actually be used). I also noticed that the current approaches, interleave_datasets and concatenate_datasets, do not leverage all the workers if the number of shards does not align with the number of workers.
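The shard/worker mismatch described here can be illustrated with a plain-Python sketch of round-robin shard assignment (a simplification; the library's actual distribution logic may differ):

```python
def assign_shards(n_shards, num_workers):
    """Round-robin shard assignment: worker w receives shards w, w + num_workers, ..."""
    return {w: list(range(w, n_shards, num_workers)) for w in range(num_workers)}

# With 2 shards and 4 dataloader workers, half of the workers receive no data
assignment = assign_shards(n_shards=2, num_workers=4)
idle_workers = [w for w, shards in assignment.items() if not shards]
```

Under this scheme, any worker whose index is at or beyond the shard count stays idle, which is why throughput suffers when shards don't divide evenly across workers.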

@lhoestq
Member

lhoestq commented Oct 1, 2025

I pushed a small update @radulescupetru related to @LTMeyer 's issue, I hope you don't mind.

The logic looks all good to me now :) could you also update _interleave_map_style_datasets() in arrow_dataset.py before we merge ? This way the Dataset objects will also benefit from this new stopping strategy.

@radulescupetru radulescupetru changed the title Sample without replacement option for iterable datasets Sample without replacement option when interleaving datasets Oct 3, 2025
@radulescupetru
Contributor Author

@lhoestq Thanks for that fix. I've pushed updates to support the new stopping strategy for map style datasets as well.
