forked from facebookresearch/maskrcnn-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_data_samplers.py
153 lines (122 loc) · 5.4 KB
/
test_data_samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import itertools
import random
import unittest
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import Sampler
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data.sampler import RandomSampler
from maskrcnn_benchmark.data.samplers import GroupedBatchSampler
from maskrcnn_benchmark.data.samplers import IterationBasedBatchSampler
class SubsetSampler(Sampler):
    """Sampler that replays a fixed collection of indices in the order given.

    The tests in this file use it to feed ``GroupedBatchSampler`` a subset
    (optionally permuted) of the dataset indices.
    """

    def __init__(self, indices):
        """Keep a reference to ``indices``.

        ``indices`` must support both ``iter()`` and ``len()``; it is stored
        as-is (not copied).
        """
        self.indices = indices

    def __iter__(self):
        """Return a fresh iterator over the stored indices."""
        return iter(self.indices)

    def __len__(self):
        """Return how many indices are stored."""
        return len(self.indices)
class TestGroupedBatchSampler(unittest.TestCase):
    """Checks of ``GroupedBatchSampler`` batch contents, ordering and length.

    Every test builds the sampler as
    ``GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)`` and
    compares the batches it yields against hand-written expectations.
    """

    def test_respect_order_simple(self):
        """Flattened batches equal the sequential order for contiguous groups."""
        drop_uneven = False
        dataset = list(range(40))
        # Four groups, each made of ten consecutive indices.
        group_ids = [i // 10 for i in dataset]
        sampler = SequentialSampler(dataset)
        for batch_size in [1, 3, 5, 6]:
            batch_sampler = GroupedBatchSampler(
                sampler, group_ids, batch_size, drop_uneven
            )
            result = list(batch_sampler)
            merged_result = list(itertools.chain.from_iterable(result))
            self.assertEqual(merged_result, dataset)

    def test_respect_order(self):
        """Interleaved groups are batched per group, in the listed order."""
        drop_uneven = False
        dataset = list(range(10))
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SequentialSampler(dataset)
        # One expected batch list per batch size tried below (1, 3, 4).
        expected = [
            [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
            [[0, 1, 3], [2, 4, 5], [6, 9], [7, 8]],
            [[0, 1, 3, 6], [2, 4, 5, 7], [8], [9]],
        ]
        for idx, batch_size in enumerate([1, 3, 4]):
            batch_sampler = GroupedBatchSampler(
                sampler, group_ids, batch_size, drop_uneven
            )
            result = list(batch_sampler)
            self.assertEqual(result, expected[idx])

    def test_respect_order_drop_uneven(self):
        """With ``drop_uneven=True`` only the full-size batches are expected."""
        batch_size = 3
        drop_uneven = True
        dataset = list(range(10))
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SequentialSampler(dataset)
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)
        # Compare with test_respect_order: the short batches [6, 9] and
        # [7, 8] are absent here.
        expected = [[0, 1, 3], [2, 4, 5]]
        self.assertEqual(result, expected)

    def test_subset_sampler(self):
        """Only the indices produced by the sampler appear in the batches."""
        batch_size = 3
        drop_uneven = False
        # group_ids covers ten items; the sampler visits six of them.
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SubsetSampler([0, 3, 5, 6, 7, 8])
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)
        expected = [[0, 3, 6], [5, 7, 8]]
        self.assertEqual(result, expected)

    def test_permute_subset_sampler(self):
        """A permuted subset keeps the sampler's order inside each batch."""
        batch_size = 3
        drop_uneven = False
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SubsetSampler([5, 0, 6, 1, 3, 8])
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)
        # Index 5 is sampled first, so its (group 1) batch comes first.
        expected = [[5, 8], [0, 6, 1], [3]]
        self.assertEqual(result, expected)

    def test_permute_subset_sampler_drop_uneven(self):
        """Permuted subset with ``drop_uneven=True`` keeps only the full batch."""
        batch_size = 3
        drop_uneven = True
        group_ids = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
        sampler = SubsetSampler([5, 0, 6, 1, 3, 8])
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)
        expected = [[0, 6, 1]]
        self.assertEqual(result, expected)

    def test_len(self):
        """``len(batch_sampler)`` matches the number of batches yielded.

        Group ids and sampling order are random here, so only the
        consistency between ``len()`` and iteration is checked.
        """
        batch_size = 3
        drop_uneven = True
        dataset = list(range(10))
        group_ids = [random.randint(0, 1) for _ in dataset]
        sampler = RandomSampler(dataset)

        # Case 1: iterate first, then query the length.
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        result = list(batch_sampler)
        self.assertEqual(len(result), len(batch_sampler))
        # Deliberately asserted a second time so that a repeated len() call
        # is exercised as well.
        self.assertEqual(len(result), len(batch_sampler))

        # Case 2: query the length before iterating, then again afterwards.
        batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven)
        batch_sampler_len = len(batch_sampler)
        result = list(batch_sampler)
        self.assertEqual(len(result), batch_sampler_len)
        self.assertEqual(len(result), len(batch_sampler))
class TestIterationBasedBatchSampler(unittest.TestCase):
    """Checks of ``IterationBasedBatchSampler`` wrapped around ``BatchSampler``."""

    def test_number_of_iters_and_elements(self):
        """Reported length equals ``num_iterations`` and batches wrap around.

        For every combination of batch size, iteration count and
        ``drop_last``, each yielded batch is compared with the slice of the
        sequential dataset expected at that position.
        """
        dataset = list(range(10))
        # itertools.product iterates in the same order as the equivalent
        # nested for-loops (rightmost argument varies fastest).
        for batch_size, num_iterations, drop_last in itertools.product(
            [2, 3, 4], [4, 10, 20], [False, True]
        ):
            sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last=drop_last)
            iter_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations)
            # assertEqual rather than a bare assert: it is not stripped under
            # ``python -O`` and reports both values on failure.
            self.assertEqual(len(iter_sampler), num_iterations)
            for i, batch in enumerate(iter_sampler):
                # The position wraps modulo the number of batches in one
                # pass over the underlying batch sampler.
                start = (i % len(batch_sampler)) * batch_size
                # The final batch of a pass may be shorter than batch_size.
                end = min(start + batch_size, len(dataset))
                expected = list(range(start, end))
                self.assertEqual(batch, expected)
if __name__ == "__main__":
    # Run the test cases defined in this module when executed as a script.
    unittest.main()