Improve multi-device data loading strategy #2890
base: main
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2890      +/-   ##
==========================================
+ Coverage   82.11%   82.15%   +0.03%
==========================================
  Files         871      872       +1
  Lines      120693   120968     +275
==========================================
+ Hits        99110    99377     +267
- Misses      21583    21591       +8

View full report in Codecov by Sentry.
/// Force initialization if needed.
fn initialize(&self) -> &[Box<dyn DynDataLoader<B, O>>] {
Could we use Arc<dyn DataLoader<B, O>> here instead?
Yeah probably, but in the implementation we hand each data loader off to a thread. With a Box and clone_dyn() this didn't add any further Sync requirement. So if we switch to Arc<dyn DataLoader<B, O>>, we would probably need another bound (and have to propagate it).
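For context, a minimal self-contained sketch of the clone_dyn pattern being discussed (the Loader and MyLoader names are illustrative only, not the actual traits in the codebase):

// Simplified illustration of the clone_dyn pattern: each worker thread gets
// its own boxed clone, so the trait object only needs Send. Sharing a single
// Arc<dyn Loader> across threads would additionally require Sync.
trait Loader: Send {
    fn clone_dyn(&self) -> Box<dyn Loader>;
}

#[derive(Clone)]
struct MyLoader;

impl Loader for MyLoader {
    fn clone_dyn(&self) -> Box<dyn Loader> {
        Box::new(self.clone())
    }
}

fn main() {
    let loader: Box<dyn Loader> = Box::new(MyLoader);
    let handles: Vec<_> = (0..2)
        .map(|_| {
            let local = loader.clone_dyn(); // one owned copy per thread
            std::thread::spawn(move || {
                let _ = local; // the thread owns its loader; no Sync required
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}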
Updated: MultiThreadDataLoader simply keeps a Vec of BatchDataLoader. The Box<dyn DataLoader> was superfluous, left over from the previous impl; BatchDataLoader is always used to create an instance.
pub struct MultiThreadDataLoader<B: Backend, I, O> {
// Configuration parameters needed for initialization
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Box<dyn DynBatcher<B, I, O>>,
device: B::Device,
rng: Option<rand::rngs::StdRng>,
num_threads: usize,
// The lazily initialized data loaders
dataloaders: OnceCell<Vec<BatchDataLoader<B, I, O>>>,
}
So initialize() now returns a slice of BatchDataLoaders, and we can remove the DynDataLoader trait with clone_dyn().
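For illustration, a minimal sketch of what the lazy initialization could look like with the OnceCell field above (the BatchDataLoader constructor arguments and per-thread setup shown here are assumptions, not the actual implementation):

impl<B: Backend, I, O> MultiThreadDataLoader<B, I, O> {
    /// Build one BatchDataLoader per worker on first access, then reuse them.
    fn initialize(&self) -> &[BatchDataLoader<B, I, O>] {
        self.dataloaders
            .get_or_init(|| {
                (0..self.num_threads)
                    .map(|_| {
                        // Hypothetical constructor call; the real code may split
                        // the dataset and seed the RNG differently per worker.
                        BatchDataLoader::new(
                            self.strategy.clone_dyn(),
                            self.dataset.clone(),
                            self.batcher.clone_dyn(),
                            self.device.clone(),
                            self.rng.clone(),
                        )
                    })
                    .collect()
            })
            .as_slice()
    }
}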
@@ -175,6 +201,7 @@ impl<LC: LearnerComponents> Learner<LC> {
break;
}

// TODO: multi-device validation?
Still relevant?
Forgot that I added this comment 😅 it was mostly for my own sake.
The comment is still relevant, I guess, because the validation loop still runs on a single device (same behavior as before); we just split the training across GPUs.
For most data-parallel use cases this is fine I think: unless the validation set is huge, splitting it across multiple GPUs probably doesn't bring much benefit. But we could do it in the future.
Checklist
The run-checks all script has been executed.
Related Issues/PRs
#2277
Changes
The data loader strategy currently requires the batcher implementation to load the data on a specific device. For a single GPU this is fine, but for multi-GPU the same device (in our examples, the first) is always used, and the data then has to be moved to the same device as the model before the inference step.

This is not ideal, since a single GPU might get overloaded when loading the data. The to_device implementation for our cubecl backends goes through a CPU buffer, which makes things even worse.

This PR changes the Batcher trait to provide the correct device for the batch. To enable device selection for data loading, the DataLoader trait is now generic over a backend for which a device can be set. Data loaders can also be sliced to return a subset of the data. For multi-device, this allows us to split a single data loader into multiple data loaders, one per worker (i.e., one per device), so the data is loaded on the correct device before being passed to the model.
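As a rough sketch of the direction described above (the trait names follow the description, but the exact signatures are assumptions and may differ from the PR), the batcher receives the target device, and the backend-generic data loader supports device selection and slicing:

use burn::prelude::Backend; // assuming the user-facing prelude re-export

/// Sketch: the batcher is given the device the batch should be created on.
pub trait Batcher<B: Backend, I, O>: Send {
    fn batch(&self, items: Vec<I>, device: &B::Device) -> O;
}

/// Sketch: the data loader is generic over the backend so a device can be
/// assigned, and it can be sliced to yield a subset of the data.
pub trait DataLoader<B: Backend, O>: Send {
    fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = O> + 'a>;
    fn num_items(&self) -> usize;
    /// Return a new data loader over the items in [start, end).
    fn slice(&self, start: usize, end: usize) -> Box<dyn DataLoader<B, O>>;
    /// Set the device on which batches should be loaded.
    fn set_device(&mut self, device: B::Device);
}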
To allow the MultiThreadDataLoader to be split, it now holds the fields required to build its data loaders; the data loaders themselves are lazily initialized when .iter() is called.
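For illustration, a hedged sketch of how a multi-device setup could use slicing to give each device its own loader; the helper function is hypothetical and builds on the sketched DataLoader trait above:

// Hypothetical helper: split one data loader into one loader per device, so
// each worker loads its batches directly on its own device.
fn split_per_device<B: Backend, O>(
    dataloader: &dyn DataLoader<B, O>,
    devices: &[B::Device],
) -> Vec<Box<dyn DataLoader<B, O>>> {
    let total = dataloader.num_items();
    let chunk = total.div_ceil(devices.len());

    devices
        .iter()
        .enumerate()
        .map(|(i, device)| {
            let start = i * chunk;
            let end = usize::min(start + chunk, total);
            // Each slice is assigned the device it should load data onto.
            let mut part = dataloader.slice(start, end);
            part.set_device(device.clone());
            part
        })
        .collect()
}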
Testing
Unit tests. I don't have a multi-GPU machine at hand to check the improvements 😅 though the changes are fairly straightforward.