[FEATURE] Profiling Improvements #67
Conversation
This looks really useful! Possible design tweak - personally I'd like to get away from coupling too much of the code to CLI args and make it more extensible from python scripts. The reason I separated fsdp_qlora() from main() was to be able to call into fsdp_qlora() from other python scripts without the arg parsing (i.e., gradually librarify-ing the core functionality). One problem with the giant flat arg configuration is that there are cross-dependencies between arguments that eventually become inscrutable, and end users end up only being able to use the codebase by copy-pasting existing sets of arguments (otherwise it's too easy to accidentally combine incompatible arguments when starting from scratch). A straw man proposal (cc @KeremTurgutlu @warner-benjamin):
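Something along these lines -- a rough sketch only; the config groups and field names are illustrative, not actual code in this repo:

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:  # illustrative grouping of model-related args
    model_name: str = "meta-llama/Llama-2-7b-hf"
    precision: str = "bf16"

@dataclass
class TrainConfig:  # illustrative grouping of training-related args
    train_type: str = "qlora"
    batch_size: int = 2
    num_epochs: int = 1

def fsdp_qlora(model_cfg: ModelConfig, train_cfg: TrainConfig):
    """Core training logic, callable from other python scripts."""
    ...

def main():
    # CLI parsing would build the dataclasses and validate cross-dependent
    # arguments in one place before handing off to fsdp_qlora().
    fsdp_qlora(ModelConfig(), TrainConfig())
```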
Possible alternatives:
Agree that lumping everything into one giant flat arg list isn't ideal. At minimum, subcommands for the training types would be helpful. More generally, I'm in favor of refactoring the training script into separable stages (rough sketch below). This will make the code more extensible and also make it easier to instrument (and optimize) each stage of training. Would also make it easier to create a separate profiling script.
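Roughly what I have in mind -- a sketch; the stage names are hypothetical, not the current API:

```python
# Hypothetical stage decomposition -- names are illustrative.
def load_model(cfg): ...
def load_data(cfg): ...
def train_loop(model, data, cfg): ...

def fsdp_qlora(cfg):
    model = load_model(cfg)  # each stage can be timed / profiled independently
    data = load_data(cfg)
    return train_loop(model, data, cfg)
```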
Yes this sounds like a good direction to me @jeromeku. Feel free to make the change here, or if it's too unwieldy for one PR we can split it up.
I have a slight inclination towards merging this, and then doing the refactoring as a 2nd step -- but happy to go with whatever you both prefer.
Agree - the refactoring per my previous comment is a larger effort and will take some time to structure cleanly. I think users / developers can benefit from the profiling functionality in the meantime? Will write some documentation and run a few more tests first.
LGTM, we can follow-up with the refactoring changes in a separate PR.
Feel free to merge when you're ready w/ documentation/test changes. Thanks very much.
```diff
-profiling_output: str = "", # Output file prefix for profiling
+profile: bool_arg = False, # Whether to profile with torch.profiler
+profiling_output: str = "profiles", # Output file prefix for profiling
+with_stack: bool_arg = False, # Output stacks for profiling. Note that setting export_memory_timeline will automatically export traces since `with_stack` must be true to profile memory.
```
If we weren't refactoring, it would probably be good to make profiler-specific args easily identifiable (e.g. with a prof_ prefix). For now, we can hold off until the refactoring to address this question of how to organize the arg list.
Ok if we leave as is? All the profiling args are demarcated from the rest of the CLI args and are off by default. Also, the documentation clearly lays out the meaning of each of these args.
Yes fine to leave as is.
train.py (Outdated)
```python
if args["profiling_output"]:
    prof.export_stacks(path=f"{args['profiling_output']}_{local_rank}.txt",
                       metric="self_cuda_time_total")
# if args["profiling_output"]:
```
Can just delete this if it's not needed anymore.
Ready for review:
Next steps
Thanks very much @jeromeku. Nice work - tagging @KeremTurgutlu @warner-benjamin @jph00 - if no one has objections in the next day, let's go ahead and merge. Re: long-term objective - my sense is the original goal was to publish findings + provide a reference implementation. However, it seems like it's also useful as a minimal-dependency implementation for training experiments. If enough people find it useful in that regard, we could lean into that direction in the implementation. I'll be interested to learn from your profiling / comparison findings.
[FEATURE] Profiling Improvements
Motivation
Currently, profiling is enabled only for the entire training run, leading to long profiling times, OOM errors, and problems loading / viewing the resulting traces. As is, profiling is prohibitively expensive: profiling just 5 steps (batches) of `Llama-2-7B` on a single GPU results in a `4GB` trace.

Contributions
Key contributions:

- Enables exporting traces that can be viewed with `perfetto` or `chrome::tracing` to analyze training steps at various scales: `forward / backwards`, operator runtimes, and correlation between CPU / GPU events (`aten` operator dispatch -> `cuda` runtime -> kernel).
- Configures `torch.profiler` such that profiling is selectively recorded (see `CLI` options below).

NOTE: There are additional knobs for `torch.profiler`. In this initial PR, I've included a minimal set that should be sufficient for most use cases.

Usage
IMPORTANT

There are issues with recording stack traces and exporting traces simultaneously (see this issue), depending on the `python` version. The only combination with which I was able to get both working at the same time was `python=3.11.9` with `torch=2.3.0`.

Running the following:
will result in a directory `{model_name}_{train_type}-{local_rank}` with the following artifacts:

- `{model_name}-{train_type}-chrome-trace.json.gz` - interactive trace that can be viewed using `chrome::tracing` or `perfetto`
- `{model_name}-{train_type}-key_averages.txt` - sorted table of events (see the sketch below for how it is generated)
- `{model_name}-{train_type}-memory-timeline.html` - stacked time series plot of memory use, broken down by `Parameter`, `Gradients`, `Activations`, etc.
- `{model_name}-{train_type}-stacks.txt` - stack trace. See docs.
- Stack trace. See docs.Detailed
CLI
options:profile
- whether to profileprofiling_outputs
- output directory fortorch.profiler
artifactsexport_trace
- enables exporting of interactive trace that can be viewed and analyzed usingchrome::tracing
export_memory_timeline
- exports an HTML memory timeline which shows memory use by category (parameters
,activations
,gradients
, etc.)with_stack
- exports stack tracewith_shapes
- adds shapes of operators to the trace{wait, warmup, active}_steps
- controls how many profiling steps are recorded:wait_steps
- number of steps for the profiler to wait before starting to profilewarmup_steps
- number of steps for profiler to profile without recordingactive_steps
- number of steps to recordrepeat
- number of times to repeat the above cycle ofwait, warmup, active
.See docs for further details.
max_steps
- maximum number of batches per epoch. E.g., withnum_epochs=1
, stops training aftermax_steps
of batches. Note that this is automatically adjusted to accommodate the profiler schedule; for example, ifmax_steps < wait_steps + warmup_steps + active_steps
, it will automatically be set towait_steps + warmup_steps + active_steps
such that the profiler can run for at least 1 cycle.The default schedule for the profiler is set such that only 2 steps of the first epoch is recorded (not counting
wait
andwarmup
steps which are not recorded). To record 2 steps from each epoch, adjustrepeat
accordingly.Note that
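For reference, a minimal sketch of how these flags map onto `torch.profiler` (the toy matmul stands in for a training step; the export path and device are illustrative):

```python
import torch
from torch.profiler import profile, schedule, ProfilerActivity

# wait=0, warmup=1, active=2, repeat=1 records 2 steps after 1 unrecorded
# warmup step -- the {wait, warmup, active}_steps / repeat flags map
# directly onto this torch.profiler.schedule.
prof_schedule = schedule(wait=0, warmup=1, active=2, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=prof_schedule,
    # The memory timeline requires all three of these flags, hence the
    # override described in the note below.
    profile_memory=True, with_stack=True, record_shapes=True,
) as prof:
    x = torch.randn(1024, 1024, device="cuda")  # assumes a CUDA device
    for _ in range(3):   # warmup + active steps
        y = x @ x        # stand-in for one training step
        prof.step()      # advances the profiler schedule

prof.export_memory_timeline("memory-timeline.html", device="cuda:0")
```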
Note that `with_stack` and `with_shapes` are overridden by `export_memory_timeline`, since the memory profile requires these options to be `True`.

TODO:

- additional `torch.profiler` features