Refactor Offline-ER to work with collate_fn
#390
Coverage report
Note: Coverage evolution is disabled because this PR targets a different branch.
Diff coverage details:
src/renate/memory/buffer.py
src/renate/utils/pytorch.py
src/renate/updaters/learner.py
src/renate/updaters/experimental/offline_er.py
Args:
    dataset_lengths: The length for the different datasets.
    batch_sizes: Batch sizes used for specific datasets.
    complete_dataset_iteration: Provide an index to indicate over which dataset to fully
Possibly rename?
Any suggestions?
    else num_batches[self.complete_dataset_iteration]
)

def __iter__(self) -> Iterator[List[int]]:
Can you add comments about the exact logic?
        yield [j for i in samples for j in i]
else:
    iterators = [iter(sampler) for sampler in self.subset_samplers]
    for s in iterators[self.complete_dataset_iteration]:
Is this optimized? Nested for-loops for each batch seems like a lot.
There is no nested loop per batch. Each step is a single pass over the iterators. In case 1 this is hidden inside zip, which also loops over each iterator and calls next on it.
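To make the reply above concrete, here is a minimal pure-Python sketch of the two iteration cases. It is not the PR's implementation: `concat_batches` is a hypothetical helper, the batches are plain lists rather than samplers, and restarting an exhausted iterator in the second case is an assumption about the intended behavior.

```python
from typing import Iterator, List, Optional, Sequence


def concat_batches(
    per_dataset_batches: Sequence[Sequence[List[int]]],
    complete_dataset_iteration: Optional[int] = None,
) -> Iterator[List[int]]:
    """Yield one merged list of indices per step (hypothetical helper)."""
    if complete_dataset_iteration is None:
        # Case 1: zip calls next() once on every iterator per step and
        # stops as soon as the shortest one is exhausted.
        for samples in zip(*per_dataset_batches):
            yield [j for i in samples for j in i]
    else:
        main = complete_dataset_iteration
        iterators = [iter(batches) for batches in per_dataset_batches]
        # Case 2: one loop over the batches of the chosen dataset ...
        for s in iterators[main]:
            merged: List[int] = []
            for k, batches in enumerate(per_dataset_batches):
                if k == main:
                    merged.extend(s)
                    continue
                # ... and one next() call per other iterator.
                # Assumption: an exhausted iterator is restarted.
                try:
                    part = next(iterators[k])
                except StopIteration:
                    iterators[k] = iter(batches)
                    part = next(iterators[k])
                merged.extend(part)
            yield merged


a = [[0, 1], [2, 3], [4, 5]]
b = [[10], [11]]
shortest = list(concat_batches([a, b]))
full_a = list(concat_batches([a, b], complete_dataset_iteration=0))
```

In both cases the work per yielded batch is one `next` call per dataset, so the cost grows with the number of datasets rather than with their sizes.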
Can you check that this works with distributed training? That uses something like a DistributedSampler, which also modifies the data to sample from.
src/renate/utils/pytorch.py (outdated)
    data_start_idx = data_end_idx
self.length = (
    min(num_batches)
    if complete_dataset_iteration is None
Why not self.complete_dataset_iteration here?
Done.
src/renate/utils/pytorch.py (outdated)
@@ -156,3 +156,76 @@ def complementary_indices(num_outputs: int, valid_classes: Set[int]) -> List[int
        valid_classes: A set of integers of valid classes.
    """
    return [class_idx for class_idx in range(num_outputs) if class_idx not in valid_classes]


class ConcatRandomSampler(BatchSampler):
Why inherit from BatchSampler?
Changed to Sampler.
start_idx = data_start_idx + round(dataset_length / num_replicas * rank)
end_idx = data_start_idx + round(dataset_length / num_replicas * (rank + 1))
subset_sampler = BatchSampler(
    SubsetRandomSampler(list(range(start_idx, end_idx)), generator),
Why BatchSampler of SubsetRandomSampler?
BatchSampler creates batches of indices (List[int]), while SubsetRandomSampler yields single random indices (int) drawn from the provided list.
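The difference described above can be seen in a small standalone sketch. The index range, seed, and batch size are arbitrary illustration values, not taken from the PR; `drop_last=True` is assumed here to match the implicit drop-last behavior discussed later in the review.

```python
import torch
from torch.utils.data import BatchSampler, SubsetRandomSampler

# Arbitrary illustration values.
generator = torch.Generator().manual_seed(0)
indices = list(range(10, 20))

# SubsetRandomSampler yields one index at a time, in random order.
index_sampler = SubsetRandomSampler(indices, generator=generator)
single = list(index_sampler)

# Wrapping it in a BatchSampler groups those indices into lists.
batch_sampler = BatchSampler(index_sampler, batch_size=4, drop_last=True)
batches = list(batch_sampler)
```

With ten indices and a batch size of four, `single` is a permutation of the ten indices, while `batches` holds two lists of four indices each; the remaining two indices are dropped.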
@pytest.mark.parametrize(
    "complete_dataset_iteration,expected_batches", [[None, 2], [0, 7], [1, 5], [2, 2]]
For None, the number of batches is 2 because 20 // 8 = 2?
Yes, it is identical to the [2, 2] case.
So a drop_last is implicit?
Yes. I improved the documentation.
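The arithmetic in this thread can be sketched as follows. Only the pair 20 and 8 comes from the discussion; `num_concat_batches` is a hypothetical helper, and the lengths and batch sizes of the first two datasets are invented so that the counts match the parametrized test (7, 5, 2).

```python
from typing import List, Optional


def num_concat_batches(
    dataset_lengths: List[int],
    batch_sizes: List[int],
    complete_dataset_iteration: Optional[int] = None,
) -> int:
    """Number of merged batches per epoch (hypothetical helper)."""
    # Floor division drops the last partial batch of every dataset,
    # which is the implicit drop_last mentioned in the review.
    num_batches = [n // b for n, b in zip(dataset_lengths, batch_sizes)]
    if complete_dataset_iteration is None:
        # Stop as soon as the dataset with the fewest batches runs out.
        return min(num_batches)
    return num_batches[complete_dataset_iteration]


# Invented values for datasets 0 and 1; 20 and 8 are from the thread.
lengths = [15, 11, 20]
sizes = [2, 2, 8]
counts = [num_concat_batches(lengths, sizes, i) for i in (None, 0, 1, 2)]
```

Here 20 // 8 = 2 leaves four samples of the third dataset unused in each epoch, which is why the `None` case and the `[2, 2]` case agree.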
@@ -11,6 +11,7 @@
from renate.memory.buffer import ReservoirBuffer
from renate.utils import pytorch
from renate.utils.pytorch import (
Is it possible to add a DistributedSampler to a test?
I've added a unit test for the distributed case instead
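As a rough illustration of what such a test could check, the sketch below reuses the start and end index formula from the diff and confirms that the per-rank shards tile a dataset's index range without gaps or overlap. `rank_shard` is a hypothetical helper and the sizes are arbitrary; this is not the unit test added in the PR.

```python
def rank_shard(
    data_start_idx: int, dataset_length: int, num_replicas: int, rank: int
) -> range:
    """Index range one replica samples from (hypothetical helper)."""
    start_idx = data_start_idx + round(dataset_length / num_replicas * rank)
    end_idx = data_start_idx + round(dataset_length / num_replicas * (rank + 1))
    return range(start_idx, end_idx)


# Arbitrary example: a dataset of 10 samples that starts at offset 5
# inside the concatenated dataset, split across 3 replicas.
shards = [rank_shard(5, 10, 3, rank) for rank in range(3)]
covered = [i for shard in shards for i in shard]
```

Because the end index of one rank is computed with the same expression as the start index of the next, consecutive shards always meet exactly, even when the length is not divisible by the number of replicas.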
Offline-ER currently applies collate_fn separately to new data and to memory data. This change applies the collate function to the entire batch instead.
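One situation where the two approaches differ is a collate function whose output depends on the whole batch, such as padding to the longest sequence. The sketch below is an assumed illustration in plain Python; `pad_collate` and the sample data are hypothetical and do not come from the PR, which does not state its motivating case.

```python
from typing import List


def pad_collate(batch: List[List[int]], pad_value: int = 0) -> List[List[int]]:
    """Pad every sequence to the longest one in the batch (hypothetical)."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_value] * (max_len - len(seq)) for seq in batch]


new_data = [[1, 2, 3, 4], [5, 6]]
memory_data = [[7], [8, 9]]

# Collating each part separately pads to two different lengths.
separate = pad_collate(new_data) + pad_collate(memory_data)

# Collating the entire batch at once yields one consistent length.
joint = pad_collate(new_data + memory_data)
```

With separate collation the rows have lengths 4 and 2 and cannot be stacked into a single tensor without extra handling, whereas joint collation gives four rows of length 4.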
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.