Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
sxjscience committed Apr 23, 2018
1 parent 55f4bd0 commit 4a54ff7
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions gluonnlp/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,15 @@ def __init__(self, lengths, batch_size, num_buckets=10, bucket_keys=None,
assert num_buckets > 0, 'num_buckets must be set when bucket_keys is None. Received ' \
'num_buckets=%d' % num_buckets
if not self._single_element:
bucket_width_l = [max((max_len - min_len) // max(num_buckets - 1, 1), 1)
bucket_width_l = [max((max_len - min_len) // num_buckets, 1)
for max_len, min_len in
zip(max_lengths, min_lengths)]
bucket_keys =\
[tuple(max(max_len - i * width, min_len) for max_len, min_len, width in
zip(max_lengths, min_lengths, bucket_width_l))
for i in range(num_buckets)]
else:
bucket_width = max((max_lengths - min_lengths) // max(num_buckets - 1, 1), 1)
bucket_width = max((max_lengths - min_lengths) // num_buckets, 1)
bucket_keys = [max(max_lengths - i * bucket_width, min_lengths)
for i in range(num_buckets)]
else:
Expand Down Expand Up @@ -208,6 +208,8 @@ def __init__(self, lengths, batch_size, num_buckets=10, bucket_keys=None,
def __iter__(self):
if self._shuffle:
np.random.shuffle(self._batch_infos)
for bucket_id in range(len(self._bucket_keys)):
np.random.shuffle(self._bucket_sample_ids[bucket_id])
for bucket_id, batch_begin in self._batch_infos:
batch_size = self._bucket_batch_sizes[bucket_id]
batch_end = min(batch_begin + batch_size, len(self._bucket_sample_ids[bucket_id]))
Expand Down

0 comments on commit 4a54ff7

Please sign in to comment.