feat: add gemma2b variants #1835
base: main
@@ -0,0 +1,74 @@

```yaml
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full
```
Review comments on this line:

Reviewer: Did some quick math: I guess this will take at least 216 GB of total memory (54 GB params + 54 GB gradients + 108 GB optimizer states for AdamW), which means that to run on 4 devices we'd need people to be using A100s. I wonder whether we can use an 8-bit optimizer plus optimizer-in-backward to get us down to a more reasonable peak VRAM here.

Second reviewer: Does 8-bit work with distributed?

Reviewer: Oh yeah, duh. There may be some issues with bitsandbytes optimizers on that front. I just tried out the torchao low-precision optimizers and they seem to work (though I haven't resumed from an intermediate checkpoint). There may also be a compile dependency there. Anyway, if it's too much hassle we can consider it separately; I don't want to increase the scope of this already substantial PR more than necessary.

Author: What should I do here? Change something, or expect users to change parameters according to their hardware?

Reviewer: Sorry, I missed this comment before now. I think it's fine to leave this as you have it and revisit these details in a later PR.
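For context, a minimal sketch of the kind of override discussed above, not part of this PR: swapping the AdamW component for an 8-bit optimizer. This assumes torchao is installed and exposes an `AdamW8bit` under `torchao.prototype.low_bit_optim`, and that it composes with the distributed recipe. With 8-bit state, the two AdamW buffers shrink to roughly 1 byte per parameter each, so the ~108 GB optimizer-state term in the estimate above would drop to roughly 54 GB.

```yaml
# Hypothetical sketch only (not part of this PR): an 8-bit AdamW from torchao
# in place of torch.optim.AdamW. Assumes torchao is installed and that this
# import path is available; checkpoint resume was noted as untested above.
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2e-5
  # The `fused: True` flag from the original config is dropped here, since it
  # is an argument of torch.optim.AdamW rather than of the torchao optimizer.
```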
The config continues:

```yaml
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only when the model is being fine-tuned on 2+ GPUs.


# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
```
Review comments on this line:

Reviewer: Sorry to potentially be a pain in the ass here. We have a parallel PR (#1872) which is helping standardize our configs and better expose the features we have. This means we always have [...]. Would it be annoying to ask if we could update these in the same way while we're here, please?

Author: Done, I have updated all the configs to match the other PR!
The rest of the file:

```yaml
  packed: False # Set to true for great speed ups
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.gemma2.gemma_27b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma-2-27b/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: 00024
  recipe_checkpoint: null
  output_dir: /tmp/gemma-2-27b
  model_type: GEMMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  lr: 2e-5
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False # pytorch compile, set to true for perf/memory improvement

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-finetune
log_every_n_steps: 1
log_peak_memory_stats: True
```
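One note on the checkpointer block above, in case the formatted form is unfamiliar: if I am reading torchtune's formatted checkpoint-files convention correctly, the `filename_format`/`max_filename` pair is shorthand for enumerating all 24 safetensors shards, roughly equivalent to:

```yaml
# Illustration only (not part of this PR): what the formatted checkpoint_files
# block is assumed to expand to, given filename_format and max_filename above.
checkpoint_files:
  - model-00001-of-00024.safetensors
  - model-00002-of-00024.safetensors
  # ... continuing through ...
  - model-00024-of-00024.safetensors
```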
@@ -0,0 +1,86 @@

```yaml
# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only when the model is being fine-tuned on 2+ GPUs.


# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
  packed: False # Set to true for great speed ups
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.gemma2.lora_gemma2_27b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: True
  lora_rank: 64
  lora_alpha: 128
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma-2-27b/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: 00024
  recipe_checkpoint: null
  output_dir: /tmp/gemma-2-27b/
  model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
  _component_: torch.optim.AdamW
  fused: True
  lr: 2e-5

lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 10

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
epochs: 3
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False # pytorch compile, set to true for perf/memory improvement

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-lora
log_every_n_steps: 1
log_peak_memory_stats: True
```
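The LoRA settings above (rank 64, alpha 128, adapters on q/k/v plus the MLP) are one point in a fairly wide design space. As a hedged sketch, and assuming the gemma2 LoRA builder accepts `output_proj` in `lora_attn_modules` the way other torchtune LoRA builders do, a lighter adapter that also covers the attention output projection might look like:

```yaml
# Hypothetical variation (not in this PR): a smaller-rank adapter with LoRA
# also applied to the attention output projection. lora_alpha stays at
# 2x lora_rank, matching the convention used in this config.
model:
  _component_: torchtune.models.gemma2.lora_gemma2_27b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  lora_rank: 16
  lora_alpha: 32
  lora_dropout: 0.0
```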
@@ -0,0 +1,112 @@

```yaml
# Config for single-device LoRA finetuning in lora_finetune_single_device.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run (torchtune does not use gguf so you can ignore it to save time and space):
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
  packed: False # Set to true for great speed ups
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.gemma2.lora_gemma2_27b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: True
  lora_rank: 8
  lora_alpha: 16
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/gemma-2-27b/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: 00024
  recipe_checkpoint: null
  output_dir: /tmp/gemma-2-27b/
  model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
  _component_: torch.optim.AdamW
  fused: True
  lr: 5e-5

lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 10

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 2
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8
compile: False # pytorch compile, set to true for perf/memory improvement

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-lora
log_every_n_steps: 1
log_peak_memory_stats: True

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
```
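The profiler block above ships disabled. Assuming `setup_torch_profiler` forwards these options to `torch.profiler.profile` as the inline comments suggest, the overrides one might flip on for a short debugging run look roughly like the sketch below; with this schedule (wait 5, warmup 5, active 2, repeat 1) the profiler records two active steps after five wait and five warmup steps.

```yaml
# Hypothetical tweak (not part of this PR): enable profiling with memory
# tracking for a short debugging run. Stack and shape recording are kept on
# since memory traces are far more useful with that context attached.
# All other profiler keys stay as in the config above.
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: True
  profile_memory: True
  with_stack: True
  record_shapes: True
```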