Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions examples/qwen-3-32b/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from truss_train import definitions
from truss.base import truss_config

# Container image the training job runs in; passed to definitions.Image below.
BASE_IMAGE: str = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"

# Runtime settings for the job: entrypoint command, secrets exposed as
# environment variables, checkpointing, and the read/write cache.
training_runtime = definitions.Runtime(
    # Make run.sh executable and launch it.
    start_commands=["/bin/sh -c 'chmod +x ./run.sh && ./run.sh'"],
    environment_variables=dict(
        # Name of the HF access token secret in your B10 account.
        HF_TOKEN=definitions.SecretReference(name="hf_access_token"),
        # Name of the WandB API key secret in your B10 account.
        WANDB_API_KEY=definitions.SecretReference(name="wandb_api_key"),
    ),
    checkpointing_config=definitions.CheckpointingConfig(enabled=False),
    # Enabling the cache is what defines BT_RW_CACHE_DIR.
    cache_config=definitions.CacheConfig(enabled=True),
)

# Compute request: 2 nodes with 8 H100 accelerators
# (presumably the count is per node -- confirm against truss_train docs).
training_compute = definitions.Compute(
    node_count=2,
    accelerator=truss_config.AcceleratorSpec(
        count=8,
        accelerator=truss_config.Accelerator.H100,
    ),
)

# The training job ties together the container image, compute, and runtime.
training_job = definitions.TrainingJob(
    runtime=training_runtime,
    compute=training_compute,
    image=definitions.Image(base_image=BASE_IMAGE),
)

# Project wrapper that names the project and carries the job defined above.
training_project = definitions.TrainingProject(
    job=training_job,
    name="Qwen 3 32B Test",
)
83 changes: 83 additions & 0 deletions examples/qwen-3-32b/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Axolotl fine-tuning configuration for Qwen3-32B.
# Nested keys (list items, dataset fields, fsdp_config children) must be
# indented under their parent; at column 0 they become unrelated top-level keys.

# Base Model Configuration
base_model: Qwen/Qwen3-32B  # HF repo of the base model
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

plugins:
  - axolotl.integrations.liger.LigerPlugin
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: false
liger_cross_entropy: false

val_set_size: 0.05

dataset_prepared_path: ./outputs/last_run_prepared
sample_packing: true
pad_to_sequence_len: true

chat_template: qwen3
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:10%]
    field_messages: conversations
    # Map the dataset's "from"/"value" fields onto role/content.
    message_property_mappings:
      role: from
      content: value

logging_steps: 1
num_epochs: 10
micro_batch_size: 1
gradient_accumulation_steps: 4
# evals_per_epoch: 1
max_grad_norm: 1.0

optimizer: adamw_torch_fused
sequence_len: 20000
learning_rate: 1e-6
adam_beta1: 0.9  # Standard, well-tested value
adam_beta2: 0.999  # Standard AdamW default
adam_epsilon: 1e-8  # Standard epsilon works well with regular AdamW
warmup_ratio: 0.1
weight_decay: 0.01

bf16: true
fp16: false
tf32: false

# FSDP Configuration
fsdp:
  - full_shard
  - auto_wrap

fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_offload_params: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_sync_module_states: true
  fsdp_cpu_ram_efficient_loading: true
  fsdp_use_orig_params: true

gradient_checkpointing: offload

# save_strategy: epoch
# hub_model_id: baseten-admin/test  # Where to store the model; create the repo in your account first if your API key has fine-grained permissions
# hub_strategy: all_checkpoints  # How often to store to HF

# use_wandb: true
# wandb_project: Sample Test  # Name of the wandb project
# wandb_entity: philipkiely-baseten

flash_attention: true
# eval_sample_packing: false
# 16 matches the 2 nodes x 8 accelerators requested in config.py.
sequence_parallel_degree: 16

save_safetensors: false

# max_steps: 5
17 changes: 17 additions & 0 deletions examples/qwen-3-32b/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
# Entry point for the training job: installs extra dependencies, patches
# cut_cross_entropy, configures NCCL, logs in to Hugging Face, preprocesses
# the dataset, and launches multi-node training with torchrun.
#
# Required environment variables: HF_TOKEN, BT_GROUP_SIZE, BT_NUM_GPUS,
# BT_NODE_RANK, BT_TRAINING_JOB_ID, BT_LEADER_ADDR (set -u aborts if any is unset).
set -eux

# The requirement must be quoted: unquoted, the shell parses ">=0.1.4" as an
# output redirection to a file named "=0.1.4" and pip sees no version bound.
pip3 install "ring-flash-attn>=0.1.4"
pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"

# Patch cut_cross_entropy so grad_scale falls back to 1.0 when lse is empty,
# instead of dividing by zero.
# NOTE(review): the path is tied to the py3.11 conda env of the base image.
sed -i 's/grad_scale = 1 \/ lse\.numel()/grad_scale = 1 \/ lse\.numel() if lse\.numel() else 1\.0/' /root/miniconda3/envs/py3.11/lib/python3.11/site-packages/cut_cross_entropy/cce.py

# Exclude the docker0 and loopback interfaces from NCCL socket selection.
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000

# Turn off command tracing while the token is on the command line so the
# secret is not echoed into the job logs.
set +x
huggingface-cli login --token="$HF_TOKEN"
set -x

axolotl preprocess config.yaml

torchrun \
  --nnodes="$BT_GROUP_SIZE" \
  --nproc-per-node="$BT_NUM_GPUS" \
  --node-rank="$BT_NODE_RANK" \
  --rdzv-backend=c10d \
  --rdzv-id="$BT_TRAINING_JOB_ID" \
  --rdzv-endpoint="$BT_LEADER_ADDR:29400" \
  -m axolotl.cli.train config.yaml