@@ -150,6 +150,50 @@ def test_save_indexes(tmpdir):
150
150
)
151
151
152
152
153
+ def test_advanced_indexing_step (dataset ):
154
+ scenario = SegmentationClassIncremental (
155
+ dataset ,
156
+ nb_classes = 4 ,
157
+ increment = 1 ,
158
+ mode = "overlap"
159
+ )
160
+
161
+ with pytest .raises (ValueError ):
162
+ task_set = scenario [0 :4 :2 ]
163
+
164
+
165
+ @pytest .mark .parametrize ("mode,start,end,classes,train" , [
166
+ ("overlap" , 0 , 4 , [1 , 2 , 3 , 4 ], False ),
167
+ ("overlap" , 0 , 4 , [1 , 2 , 3 , 4 ], True ),
168
+ ("overlap" , 3 , 4 , [4 ], True ),
169
+ ("overlap" , 1 , 3 , [2 , 3 ], True ),
170
+ ("disjoint" , 0 , 4 , [1 , 2 , 3 , 4 ], True ),
171
+ ("disjoint" , 3 , 4 , [4 ], True ),
172
+ ("disjoint" , 1 , 3 , [2 , 3 ], True ),
173
+ ("sequential" , 0 , 4 , [1 , 2 , 3 , 4 ], True ),
174
+ ("sequential" , 3 , 4 , [3 , 4 ], True ),
175
+ ("sequential" , 1 , 3 , [1 , 2 , 3 ], True ),
176
+ ])
177
+ def test_advanced_indexing (dataset , dataset_test , mode , start , end , classes , train ):
178
+ scenario = SegmentationClassIncremental (
179
+ dataset if train else dataset_test ,
180
+ nb_classes = 4 ,
181
+ increment = 1 ,
182
+ mode = mode
183
+ )
184
+
185
+ task_set = scenario [start :end ]
186
+ loader = DataLoader (task_set , batch_size = 200 , drop_last = False )
187
+ _ , y , t = next (iter (loader ))
188
+
189
+ t = torch .unique (t )
190
+ y = torch .unique (y )
191
+
192
+ assert len (t ) == 1 and t [0 ] == end - 1
193
+ assert set (y .numpy ().tolist ()) - set ([0 , 255 ]) == set (classes )
194
+
195
+
196
+
153
197
@pytest .mark .parametrize ("mode,all_seen_tasks" , [
154
198
("overlap" , False ),
155
199
("overlap" , True ),
0 commit comments