Skip to content

Commit

Permalink
WIP: non-working flex attention
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Oct 22, 2024
1 parent e87f878 commit 6f89920
Show file tree
Hide file tree
Showing 8 changed files with 438 additions and 20 deletions.
6 changes: 3 additions & 3 deletions recipes/configs/gemma2/2B_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma2-2b/tokenizer.model
path: /tmp/gemma-2-2b/tokenizer.model

# Dataset
dataset:
Expand All @@ -33,14 +33,14 @@ model:

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma2-2b/
checkpoint_dir: /tmp/gemma-2-2b/
checkpoint_files: [
model-00001-of-00003.safetensors,
model-00002-of-00003.safetensors,
model-00003-of-00003.safetensors,
]
recipe_checkpoint: null
output_dir: /tmp/gemma2-2b
output_dir: /tmp/gemma-2-2b
model_type: GEMMA2
resume_from_checkpoint: False

Expand Down
6 changes: 3 additions & 3 deletions recipes/configs/gemma2/2B_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma2-2b/tokenizer.model
path: /tmp/gemma-2-2b/tokenizer.model

# Dataset
dataset:
Expand All @@ -37,14 +37,14 @@ model:

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma2-2b/
checkpoint_dir: /tmp/gemma-2-2b/
checkpoint_files: [
model-00001-of-00003.safetensors,
model-00002-of-00003.safetensors,
model-00003-of-00003.safetensors,
]
recipe_checkpoint: null
output_dir: /tmp/gemma2-2b
output_dir: /tmp/gemma-2-2b
model_type: GEMMA2
resume_from_checkpoint: False

Expand Down
8 changes: 4 additions & 4 deletions recipes/configs/gemma2/2B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma2-2b/tokenizer.model
path: /tmp/gemma-2-2b/tokenizer.model

# Dataset
dataset:
Expand All @@ -44,7 +44,7 @@ checkpointer:
model-00003-of-00003.safetensors,
]
recipe_checkpoint: null
output_dir: /tmp/gemma2-2b
output_dir: /tmp/gemma-2-2b
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
Expand All @@ -62,10 +62,10 @@ loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
batch_size: 8
epochs: 3
max_steps_per_epoch: null
gradient_accumulation_steps: 4
gradient_accumulation_steps: 2
compile: False

# Training env
Expand Down
6 changes: 3 additions & 3 deletions recipes/configs/gemma2/2B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma2-2b/tokenizer.model
path: /tmp/gemma-2-2b/tokenizer.model

# Dataset
dataset:
Expand All @@ -37,14 +37,14 @@ model:

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma2-2b/
checkpoint_dir: /tmp/gemma-2-2b/
checkpoint_files: [
model-00001-of-00003.safetensors,
model-00002-of-00003.safetensors,
model-00003-of-00003.safetensors,
]
recipe_checkpoint: null
output_dir: /tmp/gemma2-2b
output_dir: /tmp/gemma-2-2b
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
Expand Down
1 change: 0 additions & 1 deletion recipes/lora_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,6 @@ def save_checkpoint(self, epoch: int) -> None:
def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

# run model
with self.activations_handling_ctx:
logits = self._model(**batch)
Expand Down
Loading

0 comments on commit 6f89920

Please sign in to comment.