Merge pull request #66 from AnswerDotAI/avh_dev
Add profiling to train.py
austinvhuang authored May 20, 2024
2 parents 1a9fddf + fdc3c7e commit ed43127
Showing 4 changed files with 625 additions and 487 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -19,7 +19,7 @@ The following steps should work (tested on Cuda 11.7, 11.8 and 12.1):
- Install bitsandbytes `pip install bitsandbytes>=0.43.0`
- Run `huggingface-cli login` (to access Llama 2)
- Optional Libraries:
-  - HQQ quantization: follow the HQQ installation [instructions](https://github.com/mobiusml/hqq?tab=readme-ov-file#installation). Our training script uses `HQQBackend.ATEN_BACKPROP`, so also make sure to build the custom kernels `cd hqq/kernels && python setup_cuda.py install`. Pin commit to `72b2b641aadc44a7ded6b243915f90df3b3be385` for FSDP compatibility, until `to_empty()` method is fixed.
+  - HQQ quantization: follow the HQQ installation [instructions](https://github.com/mobiusml/hqq?tab=readme-ov-file#installation). Our training script uses `HQQBackend.ATEN_BACKPROP`, so also make sure to build the custom kernels `cd hqq/kernels && python setup_cuda.py install`.
- Weights and Biases logging: `pip install wandb`
- [Pytorch >= 2.2](https://pytorch.org/blog/pytorch2-2/) is recommended to make use of the native flash-attention 2 kernel.
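
As a rough illustration of what the `HQQBackend.ATEN_BACKPROP` setting above refers to, backend selection in HQQ generally looks like the sketch below. This is based on HQQ's documented API, not on this repo's train.py, and import paths may differ between HQQ versions:

```python
# Minimal sketch (assumed, not this repo's train.py): selecting the ATEN
# backend with backprop support after building HQQ's custom CUDA kernels.
from hqq.core.quantize import HQQBackend, HQQLinear

# ATEN_BACKPROP requires the kernels built via:
#   cd hqq/kernels && python setup_cuda.py install
HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
```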

@@ -319,4 +319,4 @@ Finally, add gradient checkpointing support by adding the transformer layer to `check_fn`:
```python
if args["use_gradient_checkpointing"]:
check_fn = lambda submodule: isinstance(submodule, (LlamaDecoderLayer, MistralDecoderLayer))
-```
+```
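
For reference, a `check_fn` like the one above is typically handed to PyTorch's activation-checkpointing wrapper. The sketch below shows the usual surrounding call based on PyTorch's documented API; it is an assumption about the wiring, not the repo's exact train.py code, and `model` is assumed to be the transformer being wrapped:

```python
# Hedged sketch: how a check_fn predicate is commonly consumed.
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

# Same predicate as in the README snippet above.
check_fn = lambda submodule: isinstance(
    submodule, (LlamaDecoderLayer, MistralDecoderLayer)
)

non_reentrant_wrapper = partial(
    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
)
# Wraps every submodule for which check_fn returns True; `model` is the
# (assumed) transformer module.
apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)
```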
28 changes: 28 additions & 0 deletions profile.sh
@@ -0,0 +1,28 @@
# This script uses export_stacks to generate profiling output:
# /tmp/profile_0.txt, /tmp/profile_1.txt, etc. (one file per process).
#
# The output files are written by export_stacks(). Note that there are some
# outstanding issues to be aware of:
# https://github.com/pytorch/pytorch/issues/100253
#
# Profiling output files can be used with speedscope or other tools.
#
# For additional information, see:
# https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html

python train.py \
--model_name meta-llama/Meta-Llama-3-8B \
--train_type hqq_dora \
--n_bits 4 \
--precision bf16 \
--dataset orca_math \
--dataset_samples 8 \
--batch_size 2 \
--context_length 512 \
--gradient_accumulation_steps 2 \
--use_gradient_checkpointing False \
--use_cpu_offload False \
--use_activation_cpu_offload False \
--save_model False \
--profiling_output /tmp/profile
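
The train.py changes themselves are not shown in this view. As a hedged sketch of what a `--profiling_output` prefix plus `export_stacks()` implies (the function name and wiring below are assumptions, not the repo's code):

```python
# Hedged sketch (assumed wiring, not the actual train.py change): mapping a
# --profiling_output prefix to one export_stacks file per process/rank.
from torch.profiler import ProfilerActivity, profile

def profiled_run(train_fn, rank, output_prefix="/tmp/profile"):
    # with_stack=True is required so export_stacks has call stacks to write.
    # Per the linked recipe, some PyTorch versions additionally need
    # experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True).
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        with_stack=True,
    ) as prof:
        train_fn()
    # Writes collapsed stacks, e.g. /tmp/profile_0.txt for rank 0.
    prof.export_stacks(f"{output_prefix}_{rank}.txt", "self_cuda_time_total")
```

The resulting collapsed-stack files can then be opened in speedscope (for example, `npx speedscope /tmp/profile_0.txt`) or rendered with FlameGraph's `flamegraph.pl`.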
