Skip to content

Commit 877c1e2

Browse files
Add tests for advanced indexing for segmentation.
1 parent 2a6380a commit 877c1e2

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/test_segmentation.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,50 @@ def test_save_indexes(tmpdir):
150150
)
151151

152152

153+
def test_advanced_indexing_step(dataset):
154+
scenario = SegmentationClassIncremental(
155+
dataset,
156+
nb_classes=4,
157+
increment=1,
158+
mode="overlap"
159+
)
160+
161+
with pytest.raises(ValueError):
162+
task_set = scenario[0:4:2]
163+
164+
165+
@pytest.mark.parametrize("mode,start,end,classes,train", [
166+
("overlap", 0, 4, [1, 2, 3, 4], False),
167+
("overlap", 0, 4, [1, 2, 3, 4], True),
168+
("overlap", 3, 4, [4], True),
169+
("overlap", 1, 3, [2, 3], True),
170+
("disjoint", 0, 4, [1, 2, 3, 4], True),
171+
("disjoint", 3, 4, [4], True),
172+
("disjoint", 1, 3, [2, 3], True),
173+
("sequential", 0, 4, [1, 2, 3, 4], True),
174+
("sequential", 3, 4, [3, 4], True),
175+
("sequential", 1, 3, [1, 2, 3], True),
176+
])
177+
def test_advanced_indexing(dataset, dataset_test, mode, start, end, classes, train):
178+
scenario = SegmentationClassIncremental(
179+
dataset if train else dataset_test,
180+
nb_classes=4,
181+
increment=1,
182+
mode=mode
183+
)
184+
185+
task_set = scenario[start:end]
186+
loader = DataLoader(task_set, batch_size=200, drop_last=False)
187+
_, y, t = next(iter(loader))
188+
189+
t = torch.unique(t)
190+
y = torch.unique(y)
191+
192+
assert len(t) == 1 and t[0] == end - 1
193+
assert set(y.numpy().tolist()) - set([0, 255]) == set(classes)
194+
195+
196+
153197
@pytest.mark.parametrize("mode,all_seen_tasks", [
154198
("overlap", False),
155199
("overlap", True),

0 commit comments

Comments
 (0)