Length of BatchSplittingSampler with Poisson sampling #516
Comments
Hey @s-zanella, thanks for your interest and for the well-documented issue. Based on my understanding and your notebook, I deduce that this does not influence the accounting (which is correct) but only components that wish to access the reported `len` of the data loader. Labelling this as an enhancement. We can take a 2-step plan here:

What would be your thoughts?

Pierre
Thanks @pierrestock for looking into this. I agree with you that this issue doesn't affect the privacy accounting or the correctness of the training process. I'm not aware of any use case that would rely critically on the approximation currently returned by the `__len__()` method.

The 2-step plan you proposed sounds good, except that for the first point I believe that for Poisson sampling, any fixed length can only ever be an approximation, since the actual number of physical batches varies between runs.
I think PyTorch Lightning will actually stop once `len` is reached. (That's why I opened and fixed #640: if the skip signal is set and the last batch is never executed, there's no real optimizer step occurring at all!) Isn't it possible to look at the PRNG and determine the actual length beforehand?
The fix for #640 in #641 is incorrect. The new calculation using `ceil` is still wrong under Poisson sampling, for the reasons given in this issue.

Computing the actual length beforehand is technically possible by pre-sampling all masks.
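A minimal sketch of that pre-sampling idea, assuming the sampler draws one Bernoulli mask per logical batch from a seeded `torch.Generator` (as Opacus' `UniformWithReplacementSampler` does); the function and parameter names here are hypothetical, not Opacus API:

```python
import math

import torch


def presampled_num_physical_batches(
    num_samples: int,
    sample_rate: float,
    max_physical_batch_size: int,
    num_logical_batches: int,
    generator: torch.Generator,
) -> int:
    """Count the physical batches one epoch will produce by replaying
    the same Bernoulli draws a Poisson-subsampling sampler would make."""
    total = 0
    for _ in range(num_logical_batches):
        mask = torch.rand(num_samples, generator=generator) < sample_rate
        logical_batch_size = int(mask.sum())
        # One physical batch per max_physical_batch_size chunk; how an
        # empty logical batch is counted is implementation-dependent.
        total += math.ceil(logical_batch_size / max_physical_batch_size)
    return total
```

This gives an exact count for a given seed, but the real sampler must then replay the generator from the same state for the count to match, which is exactly the fragility discussed here.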
I did full-batch training, and #641 fixes the length if you just use BatchMemoryManager and don't use subsampling, so I don't appreciate calling it flat-out incorrect, since it doesn't claim to fix this issue. I just commented to mention it.
The title of #640 is "BatchSplittingSampler return wrong length" and you claimed that #641 "[f]ixes #640 by ceiling the number of batches", with no qualifiers. Now you are rolling back that claim and saying that it only applies when not using Poisson sampling. The fix in #641 modifies the length computation when using a generic sampler, which includes the Poisson sampling case.

I already agreed with you that it's possible to compute the length precisely (but not deterministically) by pre-sampling all batches.
🐛 Bug
The `__len__()` method of a `BatchSplittingSampler` that wraps a `DPDataLoader` is meant to return the number of physical (as opposed to logical) batches in its iterator. Because Poisson sampling produces variable-length logical batches, this length is necessarily approximate and will vary between runs. However, the approximation implemented in `__len__()` is inaccurate.

The actual expected number of physical batches per logical batch is:
$$\mathbb{E}\left[\left\lceil \frac{B}{m} \right\rceil\right] = \sum_{i \geq 0} \left(1 - F(i \cdot m)\right),$$

where $B$ is the (random) logical batch size, $m$ is the maximum physical batch size `self.max_batch_size`, and $F$ is the CDF of the binomial distribution with `self.sampler.num_samples` trials and `self.sampler.sample_rate` success probability. This can be computed numerically, e.g. as in the sketch below.
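For illustration, a sketch of how the expectation above could be evaluated (assuming SciPy is available; the helper name is made up, not Opacus API):

```python
from scipy.stats import binom


def expected_physical_batches_per_logical_batch(
    num_samples: int, sample_rate: float, max_batch_size: int
) -> float:
    """E[ceil(B / m)] for B ~ Binomial(num_samples, sample_rate),
    using E[ceil(B / m)] = sum_{i >= 0} (1 - F(i * m))."""
    m = max_batch_size
    # B <= num_samples, so all terms with i * m >= num_samples vanish.
    return sum(
        1.0 - binom.cdf(i * m, num_samples, sample_rate)
        for i in range(num_samples // m + 1)
    )
```

Multiplying this by the number of logical batches, `len(self.sampler)`, gives the expected total length of the `BatchSplittingSampler`.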
Please reproduce using our template Colab and post here the link
Here's a notebook built from the Colab template showing the discrepancy between computed and actual lengths:
https://gist.github.com/s-zanella/b70308db3d6d1b1bf15a5a2c8a1cc525
Expected behavior
It's unclear what the desired behavior is. The length approximation currently implemented is clearly incorrect, but a better approximation doesn't help much because the length of a `BatchSplittingSampler` with Poisson sampling is not fixed. It would be nice to at least warn that the returned length is approximate.

From a user's point of view, if `BatchMemoryManager` is to be a transparent abstraction, I do not care so much about the number of physical batches processed, but about the number of logical batches. The current abstraction does not signal the beginning/end of logical batches, which makes it hard (impossible without code introspection?) to keep track of the number of logical batches processed so far. Having a mechanism to signal the beginning/end of a logical batch would solve this issue.
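To make the request concrete, here is one possible shape such a mechanism could take; this is a hypothetical design sketch, not existing Opacus API:

```python
from dataclasses import dataclass
from typing import Any, Iterator


@dataclass
class TaggedBatch:
    data: Any
    is_logical_batch_end: bool  # True on the last physical batch of a logical batch


def train_epoch(tagged_batches: Iterator[TaggedBatch], model, optimizer, criterion):
    logical_batches_done = 0
    for batch in tagged_batches:
        inputs, targets = batch.data
        loss = criterion(model(inputs), targets)
        loss.backward()
        # With BatchMemoryManager, DPOptimizer only takes a real step on
        # the last physical batch of a logical batch; otherwise it
        # accumulates gradients and skips the step.
        optimizer.step()
        optimizer.zero_grad()
        if batch.is_logical_batch_end:
            logical_batches_done += 1  # exact logical-batch counter
```

With such a flag exposed, counting logical batches no longer requires introspecting the sampler or the optimizer's internal skip state.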