@@ -85,37 +85,65 @@ def check(starting_step, ds):
85
85
dict (
86
86
sources = [slice (0 , 10 , 2 ), slice (1 , 5 , 2 )],
87
87
weights = [1 , 1 ],
88
+ is_iter_dataset = [False , False ],
88
89
take = 10 ,
89
90
expected = [0 , 1 , 2 , 3 , 4 , 1 , 6 , 3 , 8 , 1 ],
90
91
),
91
92
dict (
92
93
sources = [slice (0 , 10 , 2 ), slice (1 , 5 , 2 )],
93
94
weights = [2 , 1 ],
95
+ is_iter_dataset = [False , False ],
94
96
take = 10 ,
95
97
expected = [0 , 2 , 1 , 4 , 6 , 3 , 8 , 0 , 2 , 1 ],
96
98
),
97
99
dict (
98
100
sources = [slice (0 , 10 , 2 ), slice (1 , 5 , 2 )],
99
101
weights = [1 , 1e-9 ],
102
+ is_iter_dataset = [False , False ],
100
103
take = 10 ,
101
104
expected = [0 , 2 , 4 , 6 , 8 , 0 , 2 , 4 , 6 , 8 ],
102
105
),
106
+ # IterDataset
107
+ dict (
108
+ sources = [slice (0 , 10 , 2 ), slice (1 , 5 , 2 )],
109
+ weights = [1 , 1 ],
110
+ is_iter_dataset = [True , True ],
111
+ take = 10 ,
112
+ expected = [0 , 1 , 2 , 3 , 4 , 1 , 6 , 3 , 8 , 1 ],
113
+ ),
114
+ # Mixture of IterDataset and MapDataset.
115
+ dict (
116
+ sources = [slice (0 , 10 , 2 ), slice (1 , 5 , 2 )],
117
+ weights = [1 , 1 ],
118
+ is_iter_dataset = [True , False ],
119
+ take = 10 ,
120
+ expected = [0 , 1 , 2 , 3 , 4 , 1 , 6 , 3 , 8 , 1 ],
121
+ ),
103
122
)
104
123
def test_sample_from_datasets (
105
124
self ,
106
125
sources : list [slice ],
107
126
weights : list [int ],
127
+ is_iter_dataset : list [bool ],
108
128
take : Optional [int ],
109
129
expected : list [int ],
110
130
):
131
+ sources = [
132
+ range_dataset (start = src .start , stop = src .stop , step = src .step ).repeat () for src in sources
133
+ ]
134
+ sources = [
135
+ source .to_iter_dataset () if should_convert else source
136
+ for source , should_convert in zip (sources , is_iter_dataset )
137
+ ]
111
138
ds = sample_from_datasets (
112
- sources = [
113
- range_dataset (start = src .start , stop = src .stop , step = src .step ) for src in sources
114
- ],
139
+ sources = sources ,
115
140
weights = weights ,
116
141
)
117
- ds = ds .slice (slice (0 , take ))
118
- self .assertCountEqual (expected , list (ds ))
142
+ ds_iter = iter (ds )
143
+ result = []
144
+ for _ in range (take ):
145
+ result .append (next (ds_iter ))
146
+ self .assertCountEqual (expected , list (result ))
119
147
120
148
def test_sample_from_datasets_errors (self ):
121
149
ds = range_dataset (start = 0 , stop = 2 )
@@ -124,17 +152,12 @@ def test_sample_from_datasets_errors(self):
124
152
repeated_ds = sample_from_datasets (sources = [ds ], weights = [1 ]).slice (slice (0 , 4 ))
125
153
self .assertEqual ([0 , 1 , 0 , 1 ], list (repeated_ds ))
126
154
127
- # Make sure that non-map dataset raises.
128
- with self .assertRaisesRegex (ValueError , "MapDataset" ):
129
- ds = ds .to_iter_dataset ()
130
- sample_from_datasets (sources = [ds ], weights = [1 ])
131
-
132
155
def test_shuffle_dataset (self ):
133
156
# Test without repeat.
134
157
ds = sample_from_datasets (
135
158
sources = [
136
- range_dataset (start = 0 , stop = 10 , step = 2 ),
137
- range_dataset (start = 1 , stop = 5 , step = 2 ),
159
+ range_dataset (start = 0 , stop = 10 , step = 2 ). repeat () ,
160
+ range_dataset (start = 1 , stop = 5 , step = 2 ). repeat () ,
138
161
],
139
162
weights = [2 , 1 ],
140
163
)
@@ -174,7 +197,7 @@ def test_slice_dataset(self, s: slice, expected: list[int]):
174
197
175
198
def test_batch (self ):
176
199
# [0, 1, 2, 3, 4].
177
- ds = range_dataset (start = 0 , stop = 5 , seed = 123 )
200
+ ds = range_dataset (start = 0 , stop = 5 , seed = 123 ). repeat ()
178
201
# [1, 2, 3, 4, 5].
179
202
other_ds = ds .map (_PlusOne ())
180
203
# [0, 1, 2, 1, 3, 4, 2, 5, 1, 3, ...].
0 commit comments