Skip to content

Commit 1b8522e

Browse files
authored
[AUTOTVM] Fix a bug in generating the search space (#4779)
- Do not use numpy.prod which ignores integer (64 bits) overflows. This leads to an incorrect number of points in the search space.
1 parent 3827ccb commit 1b8522e

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

python/tvm/autotvm/task/space.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,9 @@ def __init__(self, axes, policy, **kwargs):
226226
def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
227227
"""Generate space by DFS"""
228228
if now == self.num_output - 1:
229-
prod = np.prod(tmp_stack, dtype=np.int64)
229+
prod = functools.reduce(lambda x, y: x * y, tmp_stack)
230+
if prod > self.product:
231+
return
230232
if self.product % prod == 0 or (not enforce_no_tail and prod < self.product):
231233
self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
232234
else:

tests/python/unittest/test_autotvm_space.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@ def test_split():
6262
cfg.define_split('tile_c', cfg.axis(224), policy='verbose', num_outputs=3)
6363
assert len(cfg.space_map['tile_c']) == 84
6464

65+
# Count the number of non-negative integer solutions of a + b + c + d = n
66+
def count4(n):
67+
cnt = 0
68+
for a in range(0, n + 1):
69+
for b in range(0, n - a + 1):
70+
cnt += n - a - b + 1
71+
return cnt
72+
73+
# test overflow
74+
n = 25
75+
cfg = ConfigSpace()
76+
cfg.define_split('x', cfg.axis(2**n), policy='factors', num_outputs=4)
77+
# count4(25) is 3276.
78+
assert len(cfg.space_map['x']) == count4(n)
79+
6580
# test fallback
6681
cfg = FallbackConfigEntity()
6782
cfg.define_split('tile_n', cfg.axis(128), num_outputs=3)

0 commit comments

Comments
 (0)