Setting sort_within_batch of pool to True when called in the method create_batches of BucketIterator? #641
Description
Hi,
I am working with torchtext and I have a question about the `pool` function in the iterator module.
I have a train_dataset, valid_dataset and test_dataset. For training, I want to create an iterator with minibatches of samples of similar lengths, with a random internal order, and possibly shuffle the order of the minibatches. For the validation and test sets, I want to keep their original order and create batches sequentially based on that order.
I found the `splits` / `__init__` method of `BucketIterator` rather counterintuitive to use, and I agree with this post:
> While Torchtext is brilliant, its sort_key based batching leaves a little to be desired. Often the sentences aren't of the same length at all, and you end up feeding a lot of padding into your network
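To make the quoted padding cost concrete, here is a quick self-contained sketch (plain Python, not torchtext code) comparing the padding wasted by randomly ordered batches versus length-sorted batches:

```python
import random

def padding_fraction(batches):
    """Fraction of the padded tensor that is padding: each batch is
    padded to its longest sequence, so the padded area is
    len(batch) * max_len, and anything beyond a sequence's own
    length is wasted."""
    padded = sum(len(b) * max(len(s) for s in b) for b in batches)
    real = sum(len(s) for b in batches for s in b)
    return 1 - real / padded

random.seed(0)
seqs = [[0] * random.randint(1, 50) for _ in range(1000)]

# Random batching: sequences of very different lengths share a batch.
rand_batches = [seqs[i:i + 32] for i in range(0, len(seqs), 32)]

# Length-sorted batching: neighbouring sequences have similar lengths.
sorted_seqs = sorted(seqs, key=len)
sorted_batches = [sorted_seqs[i:i + 32] for i in range(0, len(sorted_seqs), 32)]

print(padding_fraction(rand_batches), padding_fraction(sorted_batches))
```

The length-sorted batches consistently waste far less padding, which is exactly what bucketing by length buys you.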
I'm a bit confused about why the `sort_within_batch` argument of `pool` is set to `self.sort_within_batch` when `pool` is called in the `create_batches` method of `BucketIterator`.
My issue is that if I want to effectively create minibatches of similar lengths, I have to set `sort_within_batch` to `True` when I call `data.BucketIterator.splits`:
```python
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_sizes=(train_batch_size, valid_batch_size, test_batch_size),
    sort_key=lambda x: len(x.text),
    sort=False,
    sort_within_batch=True)
```

But then, in addition to sorting the samples within the chunks / buckets, it also sorts the samples within the created minibatches (in `__iter__`), which is not needed. As a side effect it also sorts both the chunks / buckets and the minibatches of the validation and test iterators, which I don't want.
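The bucketing behaviour discussed above can be sketched in a few lines of plain Python (a simplified illustration, not the torchtext source; the `chunk_factor` parameter stands in for the hard-coded factor of 100):

```python
import random

def pool(data, batch_size, key, chunk_factor=100,
         sort_within_chunk=True, shuffle_batches=True):
    """Simplified sketch of BucketIterator-style bucketing.

    Reads chunks of chunk_factor * batch_size examples, optionally
    sorts each chunk by `key` so neighbours have similar lengths,
    slices the chunk into minibatches, and optionally shuffles the
    order of the resulting minibatches.
    """
    batches = []
    chunk_size = batch_size * chunk_factor
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        if sort_within_chunk:
            chunk = sorted(chunk, key=key)
        for b in range(0, len(chunk), batch_size):
            batches.append(chunk[b:b + batch_size])
    if shuffle_batches:
        random.shuffle(batches)
    return batches
```

With `sort_within_chunk=True` and no additional per-batch sort, each minibatch already groups samples of similar length without a second reordering step afterwards, which is the behaviour asked for above.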
Wouldn't it be more intuitive to set the `sort_within_batch` argument of `pool` to `True` when it is called in the `create_batches` method of `BucketIterator`?
In that case, if you don't want to reorder the samples within the minibatches of the train, validation, or test set, you would do:
```python
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_sizes=(train_batch_size, valid_batch_size, test_batch_size),
    sort_key=lambda x: len(x.text),
    sort=False,
    sort_within_batch=False)
```

If you want to sort the samples within the minibatches of the validation and test set, you would do:
```python
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_sizes=(train_batch_size, valid_batch_size, test_batch_size),
    sort_key=lambda x: len(x.text))
```

And if you want to sort the samples within the minibatches of the train, validation, and test set, you would do:
```python
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_sizes=(train_batch_size, valid_batch_size, test_batch_size),
    sort_key=lambda x: len(x.text),
    sort_within_batch=True)
```

Besides, you may want to expose the factor 100 (the chunk / bucket size) as a parameter, because it can be useful to tune it when working with toy datasets.
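To see why the hard-coded factor matters on toy data: the chunk size is `batch_size * 100`, so a small dataset fits entirely into one chunk and the chunk-level sort degenerates into a single global sort. A hypothetical `num_chunks` helper (just for illustration, not part of torchtext) makes this concrete:

```python
def num_chunks(dataset_size, batch_size, chunk_factor=100):
    """Number of chunks (buckets) the dataset is split into."""
    chunk_size = batch_size * chunk_factor
    return -(-dataset_size // chunk_size)  # ceiling division

# A 50-example toy dataset with batch_size 4 fits in a single
# 400-example chunk, so there is effectively no bucketing.
print(num_chunks(50, 4))                  # → 1
# A smaller factor restores several chunks, hence real bucketing.
print(num_chunks(50, 4, chunk_factor=2))  # → 7
```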
Please tell me if I am missing something.
Thanks.