feat: add gemma2b variants #1835
base: main
Changes from 1 commit
@@ -0,0 +1,95 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full
Review discussion on this launch command:

- Did some quick math: I guess this will take at least 216GB total memory (54GB params + 54GB gradients + 108GB optimizer states for AdamW), which means that to run on 4 devices we'd need people to be using A100s. I wonder whether we can use an 8-bit optimizer plus optimizer-in-backward to get us down to a more reasonable peak VRAM here.
- Does 8-bit work with distributed?
- Oh yeah, duh... there may be some issues with bitsandbytes optimizers on that front. I just tried out the ao low-precision optimizers and they seem to work (though I haven't resumed from an intermediate checkpoint). There may also be a compile dependency there. Anyway, if it's too much hassle we can consider it separately; I don't want to increase the scope of this already substantial PR more than necessary.
- What should I do here? Change something, or expect users to change parameters according to their hardware?
- Sorry, I missed this comment before now. I think it's fine to leave this as you have it and revisit these details in a later PR.
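For concreteness, here is a minimal sketch of what the memory optimization discussed above could look like as config overrides. It assumes torchao's low-bit optimizers are importable at the path shown and that the distributed recipe exposes an optimizer-in-backward flag; neither is verified in this PR.

```yaml
# Hypothetical overrides, not part of this diff (see assumptions above).
# An 8-bit optimizer substantially shrinks the AdamW state memory, and stepping
# the optimizer inside the backward pass avoids holding all gradients at once.
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit  # assumed torchao path
  lr: 2e-5
optimizer_in_bwd: True  # assumed to be exposed by full_finetune_distributed
```

Whether 8-bit optimizer states play nicely with FSDP sharding is exactly the open question in the thread above, so treat this as a starting point rather than a validated recipe.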
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only when the model is being fine-tuned on 2+ GPUs.

# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma2-27b/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

Review discussion on the dataset section:

- Sorry to potentially be a pain in the ass here. We have a parallel PR (#1872) which is helping standardize our configs and better expose the features we have. This means we always have [...]. Would it be annoying to ask if we could update these in the same way while we're here, please?
- Done, I have updated all the configs to match the other PR!
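As a rough illustration of what that standardization covers, the knobs below mirror fields that already appear in the single-device config later in this diff; whether each of them applies to the distributed recipe here is an assumption, not something stated in the thread.

```yaml
# Illustrative only; per-recipe availability and defaults are assumptions.
compile: False                        # torch.compile model + loss for speed
enable_activation_checkpointing: True
enable_activation_offloading: False   # True trades throughput for lower VRAM
log_peak_memory_stats: False
```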
# Model Arguments
model:
  _component_: torchtune.models.gemma2.gemma_27b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma2-27b/
  checkpoint_files: [
    model-00001-of-00024.safetensors,
    model-00002-of-00024.safetensors,
    model-00003-of-00024.safetensors,
    model-00004-of-00024.safetensors,
    model-00005-of-00024.safetensors,
    model-00006-of-00024.safetensors,
    model-00007-of-00024.safetensors,
    model-00008-of-00024.safetensors,
    model-00009-of-00024.safetensors,
    model-00010-of-00024.safetensors,
    model-00011-of-00024.safetensors,
    model-00012-of-00024.safetensors,
    model-00013-of-00024.safetensors,
    model-00014-of-00024.safetensors,
    model-00015-of-00024.safetensors,
    model-00016-of-00024.safetensors,
    model-00017-of-00024.safetensors,
    model-00018-of-00024.safetensors,
    model-00019-of-00024.safetensors,
    model-00020-of-00024.safetensors,
    model-00021-of-00024.safetensors,
    model-00022-of-00024.safetensors,
    model-00023-of-00024.safetensors,
    model-00024-of-00024.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/gemma2-27b
  model_type: GEMMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  lr: 2e-5
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
@@ -0,0 +1,107 @@
# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only when the model is being fine-tuned on 2+ GPUs.

# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma2-27b/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.gemma2.lora_gemma2_27b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: True
  lora_rank: 64
  lora_alpha: 128
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma2-27b/
  checkpoint_files: [
    model-00001-of-00024.safetensors,
    model-00002-of-00024.safetensors,
    model-00003-of-00024.safetensors,
    model-00004-of-00024.safetensors,
    model-00005-of-00024.safetensors,
    model-00006-of-00024.safetensors,
    model-00007-of-00024.safetensors,
    model-00008-of-00024.safetensors,
    model-00009-of-00024.safetensors,
    model-00010-of-00024.safetensors,
    model-00011-of-00024.safetensors,
    model-00012-of-00024.safetensors,
    model-00013-of-00024.safetensors,
    model-00014-of-00024.safetensors,
    model-00015-of-00024.safetensors,
    model-00016-of-00024.safetensors,
    model-00017-of-00024.safetensors,
    model-00018-of-00024.safetensors,
    model-00019-of-00024.safetensors,
    model-00020-of-00024.safetensors,
    model-00021-of-00024.safetensors,
    model-00022-of-00024.safetensors,
    model-00023-of-00024.safetensors,
    model-00024-of-00024.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/gemma2-27b/
  model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
  _component_: torch.optim.AdamW
  fused: True
  lr: 2e-5

lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 10

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
epochs: 3
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-lora
log_every_n_steps: 1
log_peak_memory_stats: False
@@ -0,0 +1,134 @@
# Config for single-device LoRA finetuning in lora_finetune_single_device.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run (torchtune does not use gguf so you can ignore it to save time and space):
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma2-27b/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.gemma2.lora_gemma2_27b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: True
  lora_rank: 8
  lora_alpha: 16
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma2-27b/
  checkpoint_files: [
    model-00001-of-00024.safetensors,
    model-00002-of-00024.safetensors,
    model-00003-of-00024.safetensors,
    model-00004-of-00024.safetensors,
    model-00005-of-00024.safetensors,
    model-00006-of-00024.safetensors,
    model-00007-of-00024.safetensors,
    model-00008-of-00024.safetensors,
    model-00009-of-00024.safetensors,
    model-00010-of-00024.safetensors,
    model-00011-of-00024.safetensors,
    model-00012-of-00024.safetensors,
    model-00013-of-00024.safetensors,
    model-00014-of-00024.safetensors,
    model-00015-of-00024.safetensors,
    model-00016-of-00024.safetensors,
    model-00017-of-00024.safetensors,
    model-00018-of-00024.safetensors,
    model-00019-of-00024.safetensors,
    model-00020-of-00024.safetensors,
    model-00021-of-00024.safetensors,
    model-00022-of-00024.safetensors,
    model-00023-of-00024.safetensors,
    model-00024-of-00024.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/gemma2-27b/
  model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
  _component_: torch.optim.AdamW
  fused: True
  lr: 5e-5

lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 10

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 8

Review discussion on the batch size:

- Are we confident this'll fit on a single device?
- Changed batch size to 2 and accumulation to 8. What is the expected GPU? Is there a CI running everything? Otherwise I guess each user should be responsible for playing with the batch size to get something suitable for their GPU, no?
- Generally we try to ship configs which we know will work on some common hardware configuration (see the examples here: https://github.com/pytorch/torchtune?tab=readme-ov-file#memory-and-training-speed), so users can maintain the expectation that they can get started without any painful OOMs. Then they are free to play with the configs. We should make sure this config works with e.g. 1xA100 - let me know if you need a hand here.
- @SalmanMohammadi I do not have easy access to an A100; I would appreciate it if someone could run the code for the 27B-param model and let me know what batch size I should set.
- I'll have a quick look when we're ready to land. We can also reasonably mirror the batch size from the config of another similarly sized model already in the codebase.
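To make the batch-size discussion concrete, here is a sketch of the memory-reducing overrides a user could apply if the defaults OOM on their GPU. The specific values follow the thread above (batch size 2 with accumulation 8) but are illustrative rather than validated on any particular card.

```yaml
# Illustrative memory knobs for a single-device run (values are assumptions):
# keeping batch_size * gradient_accumulation_steps constant preserves the
# effective batch size while lowering peak activation memory.
batch_size: 2
gradient_accumulation_steps: 8
enable_activation_offloading: True  # trade training speed for lower VRAM
# Equivalent command-line form, using the same override syntax as the
# checkpointer example in the header:
#   tune run lora_finetune_single_device --config gemma2/27B_lora_single_device \
#     batch_size=2 gradient_accumulation_steps=8 enable_activation_offloading=True
```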
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 2
compile: False

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-lora
log_every_n_steps: 1
log_peak_memory_stats: False

# Showcase the usage of the PyTorch profiler.
# Set enabled to False as it's only needed for debugging training.
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
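If a trace is needed for a short debugging run, the profiler keys should be overridable from the command line in the same dot-notation style as the checkpointer.checkpoint_dir example in the header; the snippet below is a sketch under that assumption rather than a documented invocation.

```yaml
# Hypothetical debugging overrides (assumption: nested profiler keys can be
# set via dot notation, like checkpointer.checkpoint_dir above):
#   tune run lora_finetune_single_device --config gemma2/27B_lora_single_device \
#     profiler.enabled=True profiler.profile_memory=True
profiler:
  enabled: True         # turn tracing on
  profile_memory: True  # also record memory events (produces larger traces)
```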