
[FEATURE] Profiling Improvements #67

Merged (17 commits) on Jun 11, 2024

Conversation

jeromeku
Contributor

@jeromeku jeromeku commented May 23, 2024

[FEATURE] Profiling Improvements

Motivation

Currently, profiling is enabled only for the entire training run, leading to long profiling times, OOM errors, and problems loading / viewing the resulting traces.

As is, profiling is prohibitively expensive: profiling just 5 steps (batches) of Llama-2-7B on a single GPU results in a 4GB trace.

Contributions

Key contributions:

  • exporting interactive traces that can be viewed using perfetto or chrome://tracing to analyze training steps at various scales: forward / backward, operator runtimes, and correlation between CPU / GPU events (aten operator dispatch -> cuda runtime -> kernel)
  • exporting a summary table of events sorted by kernel runtime
  • adding a scheduler to torch.profiler so that profiling is selectively recorded (see CLI options below)

NOTE: There are additional knobs for torch.profiler. In this initial PR, I've included a minimal set that should be sufficient for most use cases.

Usage

IMPORTANT
There are issues with recording stack traces and exporting traces simultaneously (see this issue), depending on the Python version. The only combination where I was able to get both working at the same time was python=3.11.9 with torch=2.3.0.

Running the following:

python train.py \
--model_name "meta-llama/Llama-2-7b-hf" \
--train_type qlora \
--profile true \
--export_trace true \
--export_memory_timeline true \
--max_steps 10

will result in a directory {model_name}_{train_type}-{local_rank} with the following artifacts:

  • {model_name}-{train_type}-chrome-trace.json.gz - interactive trace that can be viewed using chrome://tracing or perfetto
  • {model_name}-{train_type}-key_averages.txt - sorted table of events, e.g.:
Name                              Self CPU %  Self CPU   CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  CPU Mem  Self CPU Mem  CUDA Mem  Self CUDA Mem  # of Calls
ProfilerStep*                     0.00%       0.000us    0.00%        0.000us    0.000us       4.816s     44.60%       4.816s      963.233ms      0 b      0 b           0 b       0 b            5
  Source: <built-in method to of Tensor object at 0x7f20bf709310>
          train.py(962): fsdp_main
          torch/multiprocessing/spawn.py(75): _wrap
          multiprocessing/process.py(108): run
          multiprocessing/process.py(314): _bootstrap
FullyShardedDataParallel.forward  0.00%       0.000us    0.00%        0.000us    0.000us       2.208s     20.45%       2.208s      441.555ms      0 b      0 b           0 b       0 b            5
  Source: <built-in method embedding of type object at 0x7f21e21797c0>
          torch/nn/functional.py(2154): embedding
          torch/nn/modules/sparse.py(162): forward
          torch/nn/modules/module.py(1534): _call_impl
          nn.Module: Embedding_0
aten::mm                          0.44%       31.314ms   0.69%        48.739ms   43.517us      332.421ms  3.08%        337.208ms   301.079us      0 b      0 b           3.26 Gb   3.26 Gb        1120
  Source: bitsandbytes/autograd/_functions.py(492): forward
          <built-in method apply of FunctionMeta object at 0x827a410>
          torch/autograd/function.py(582): apply
          bitsandbytes/autograd/_functions.py(559): matmul_4bit
MatMul4Bit                        2.81%       198.511ms  4.93%        347.437ms  310.212us     284.169ms  2.63%        630.417ms   562.872us      0 b      0 b           3.26 Gb   -62.31 Gb      1120
  Source: <built-in method apply of FunctionMeta object at 0x827a410>
          torch/autograd/function.py(582): apply
          bitsandbytes/autograd/_functions.py(559): matmul_4bit
          bitsandbytes/nn/modules.py(442): forward
          torch/nn/modules/module.py(1534): _call_impl
  • {model_name}-{train_type}-memory-timeline.html - Stacked time series plot of memory use broken down by Parameter, Gradients, Activations, etc.
  • {model_name}-{train_type}-stacks.txt - Stack trace. See docs.

Detailed CLI options:

  • profile - whether to profile
  • profiling_outputs - output directory for torch.profiler artifacts
  • export_trace - enables exporting of interactive trace that can be viewed and analyzed using chrome::tracing
  • export_memory_timeline - exports an HTML memory timeline which shows memory use by category (parameters, activations, gradients, etc.)
  • with_stack - exports stack trace
  • with_shapes - adds shapes of operators to the trace
  • {wait, warmup, active}_steps - controls how many profiling steps are recorded:
    • wait_steps - number of steps for the profiler to wait before starting to profile
    • warmup_steps - number of steps for the profiler to run without recording
    • active_steps - number of steps to record
    • repeat - number of times to repeat the above cycle of wait, warmup, active.
      See docs for further details.
  • max_steps - maximum number of batches per epoch. E.g., with num_epochs=1, stops training after max_steps of batches. Note that this is automatically adjusted to accommodate the profiler schedule; for example, if max_steps < wait_steps + warmup_steps + active_steps, it will automatically be set to wait_steps + warmup_steps + active_steps such that the profiler can run for at least 1 cycle.
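The interaction between the schedule parameters and max_steps can be illustrated with a small standalone sketch (plain Python, mirroring the semantics of torch.profiler.schedule; the helper names here are illustrative, not functions from this PR):

```python
def profiler_phase(step, wait, warmup, active, repeat=1):
    """Return the phase ('wait', 'warmup', 'active', or 'done') for a
    zero-indexed training step, mimicking torch.profiler.schedule."""
    cycle = wait + warmup + active
    if repeat and step >= cycle * repeat:
        return "done"    # all requested cycles completed
    pos = step % cycle
    if pos < wait:
        return "wait"    # profiler idle
    if pos < wait + warmup:
        return "warmup"  # profiling runs, but events are discarded
    return "active"      # events are recorded

def adjusted_max_steps(max_steps, wait, warmup, active):
    """max_steps is bumped up so at least one full profiler cycle runs."""
    return max(max_steps, wait + warmup + active)

print([profiler_phase(s, wait=1, warmup=1, active=2) for s in range(5)])
# ['wait', 'warmup', 'active', 'active', 'done']
print(adjusted_max_steps(3, 1, 1, 2))  # 4
```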

The default schedule for the profiler is set such that only 2 steps of the first epoch are recorded (not counting wait and warmup steps, which are not recorded). To record 2 steps from each epoch, adjust repeat accordingly.

Note that with_stack and with_shapes are overridden by export_memory_timeline since the memory profile requires these options to be True.
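The override described above boils down to a few lines; this is a sketch of the idea, not the PR's exact code:

```python
def resolve_profiler_flags(export_memory_timeline, with_stack, with_shapes):
    """Memory timeline export needs stack and shape recording, so it
    forces both flags on regardless of what the user passed."""
    if export_memory_timeline:
        with_stack = True
        with_shapes = True
    return with_stack, with_shapes

print(resolve_profiler_flags(True, False, False))   # (True, True)
print(resolve_profiler_flags(False, False, True))   # (False, True)
```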

TODO:

  • Test in multi-gpu environment
  • Additional documented and undocumented torch.profiler features

@austinvhuang
Contributor

This looks really useful!

Possible design tweak - personally I'd like to get away from coupling too much of the code to CLI args and make it more extensible from Python scripts. I separated fsdp_qlora() from main() so that fsdp_qlora() could be called from other Python scripts without the arg parsing (i.e., gradually librarify-ing the core functionality).

One problem with the giant flat arg configuration is that there are cross-dependencies between arguments that eventually become inscrutable, and end users end up only being able to use the codebase by copy-pasting existing sets of arguments (otherwise it's too easy to accidentally combine incompatible arguments when starting from scratch).

A straw man proposal (cc @KeremTurgutlu @warner-benjamin):

  • profile.sh becomes a python script (could be a main() in profiling_tools.py or its own source file) which accepts its own arguments (eg batch_size). Conversely it might keep some things fixed (eg only works on llama 8b) that are configurable in the train.py script for the sake of simplification.
  • The main() entrypoint for the profiling script either invokes fsdp_qlora or does its own mp spawn of fsdp_main()
  • keep fsdp_main() decoupled from CLI args and instead of an args: dict parameter, it takes a namedtuple struct defining a ProfilingConfig type with profiling-specific arguments.

Possible alternatives:

  • Leave this as is
  • Achieve some modular functionality with subcommands instead of multiple scripts with their own main() functions.
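The ProfilingConfig idea from the straw-man proposal could look something like this (field names and defaults are illustrative, not a final API):

```python
from typing import NamedTuple

class ProfilingConfig(NamedTuple):
    """Typed, profiling-specific arguments, decoupled from CLI parsing."""
    profile: bool = False
    profiling_output: str = "profiles"
    export_trace: bool = True
    export_memory_timeline: bool = False
    with_stack: bool = False
    with_shapes: bool = False
    wait_steps: int = 0
    warmup_steps: int = 1
    active_steps: int = 2
    repeat: int = 1

# Callers construct the config directly, no arg parsing required.
cfg = ProfilingConfig(profile=True, export_memory_timeline=True)
print(cfg.active_steps)  # 2
```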

@austinvhuang austinvhuang self-requested a review May 23, 2024 22:23
@jeromeku
Contributor Author

@austinvhuang

Agree that lumping everything into train.py is already getting unwieldy and will ultimately result in a poor user experience.

At minimum, subcommands for the training types would be helpful -- full finetune, lora, qlora -- which would share common args but also have type-specific options.

More generally, I'm in favor of refactoring fsdp_main into component stages: setup (model + data loading, world state, etc.), train step (forwards + backwards), optimizer step, and teardown.

This will make the code more extensible and also make it easier to instrument (and optimize) each stage of training. Would also make it easier to create a separate profiling script.

@austinvhuang
Contributor

Yes this sounds like a good direction to me @jeromeku . Feel free to make the change here or if it's too unwieldy for one PR we can split it up.

@jph00
Contributor

jph00 commented May 24, 2024

I have a slight inclination towards merging this, and then doing the refactoring as a 2nd step -- but happy to go with whatever you both prefer.

@jeromeku
Contributor Author

jeromeku commented May 25, 2024

@jph00 @austinvhuang

Agree - the refactoring per my previous comment is a larger effort and will take some time to structure cleanly. I think users / developers can benefit from the profiling functionality in the meantime?

Will write some documentation and run a few more tests first.

Contributor

@austinvhuang austinvhuang left a comment


LGTM, we can follow-up with the refactoring changes in a separate PR.

Feel free to merge when you're ready w/ documentation/test changes. Thanks very much.

profiling_output: str = "", # Output file prefix for profiling
profile: bool_arg = False, # Whether to profile with torch.profiler
profiling_output: str = "profiles", # Output file prefix for profiling
with_stack: bool_arg = False, # Output stacks for profiling. Note that setting export_memory_timeline will automatically export traces since `with_stack` must be true to profile memory.
Contributor


If we weren't refactoring, it would probably be good to make profiler-specific args easily identifiable (e.g. with a prof_ prefix). For now we can hold off until the refactoring to address this question of how to organize the arg list, though.

Contributor Author


Ok if we leave as is? All the profiling args are demarcated from the rest of the CLI args and are off by default. Also, the documentation clearly lays out the meaning of each of these args.

Contributor


Yes fine to leave as is.

train.py Outdated
if args["profiling_output"]:
    prof.export_stacks(path=f"{args['profiling_output']}_{local_rank}.txt",
                       metric="self_cuda_time_total")
# if args["profiling_output"]:
Contributor


Can just delete this if it's not needed anymore.

@jeromeku
Contributor Author

jeromeku commented Jun 4, 2024

@austinvhuang

Ready for review:

  • Refactored profiling tools to expose additional functionality and increase robustness
  • Added documentation (see PROFILING.md).
  • Ran on 2-GPU setup for various profiling configs to test that artifacts are correctly exported. Note that torch.profiler.profile can be finicky when exporting certain artifacts (see the warning in the documentation).
  • Note that I accidentally ran a formatter (ruff) when saving train.py which is why there appears to be a large diff for this file...

Next steps

  • Refactoring - is it necessary to use fastcore as the command line parser? It would be easier to use something like argparse to create subcommands and cleanly expose the different training options.
  • What is the long-term objective of fsdp-qlora? Knowing this would help guide how best to structure the code.
  • Going to do some profiling against other distributed frameworks to benchmark / optimize performance.

@austinvhuang
Contributor

Thanks very much @jeromeku. Nice work - tagging @KeremTurgutlu @warner-benjamin @jph00 - if no one has objections in the next day let's go ahead and merge.

Re: long-term objective - my sense is the original goal was to publish findings + provide a reference implementation. However, it seems like it's also useful as a minimal-dependency implementation for training experiments. If enough people find it useful in that regard, we could lean further in that direction in the implementation. I'll be interested to learn from your profiling / comparison findings.

@austinvhuang austinvhuang merged commit cec4386 into AnswerDotAI:main Jun 11, 2024