diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index b62a78fabec3..72b6447da2fa 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -14,7 +14,6 @@ import psutil import gc from math import sqrt -from math import floor from bisect import bisect_left from packaging import version as pkg_version @@ -552,6 +551,7 @@ def prefix_sum_inc(weights): def partition_uniform(num_items, num_parts): + import numpy parts = [0] * (num_parts + 1) # First check for the trivial edge case if num_items <= num_parts: @@ -559,10 +559,15 @@ def partition_uniform(num_items, num_parts): parts[p] = min(p, num_items) return parts - chunksize = floor(num_items / num_parts) - for p in range(num_parts): - parts[p] = min(chunksize * p, num_items) - parts[num_parts] = num_items + chunksize = num_items // num_parts + residual = num_items - (chunksize * num_parts) + + parts = numpy.arange(0, (num_parts + 1) * chunksize, chunksize) + + for i in range(residual): + parts[i + 1:] += 1 + parts = parts.tolist() + return parts