Qwen-Coder-MCQ: Knowledge Distillation from GPT-4o to the Compact LLM Qwen2.5-Coder-1.5B-Instruct for Multiple-Choice Coding Questions
This project implements knowledge distillation of YAML-based structured multi-step reasoning from GPT-4o to the small Qwen2.5-Coder-1.5B-Instruct model for multiple-choice coding questions. It uses LoRA (Low-Rank Adaptation) for efficient finetuning and includes a comprehensive pipeline for data processing, training, and evaluation. The live demo showcases the model with 4-bit quantization, making it highly efficient while maintaining reasoning capabilities.
🌐 Live Demo: Try the model in action with our Hugging Face Space: tuandunghcmut/Qwen2.5_Coder_1.5B_Instruct_MCQs_v4_2
📢 Dataset Available: The high-quality synthesis dataset containing 3,549 coding multiple-choice questions with detailed teacher explanations is now available on HuggingFace: tuandunghcmut/coding-mcq-reasoning. Each question includes structured YAML reasoning steps generated by GPT-4o, making it ideal for training and evaluating coding question-answering models.
🔥 Model Checkpoint: The best model checkpoint from this project is now available on HuggingFace: tuandunghcmut/Qwen25_Coder_MultipleChoice_v4. The WandB run for experiment tracking and reproducing results: wandb link.
📝 Interactive Notebook: Try out the model with our interactive Google Colab notebook. Please refer to the `notebooks` directory for the latest version of the notebooks.
- Quick Start
- Prompt Formats
- Command-Line Interface
- Highlighted Features
- Dataset
- Advanced Features
- Examples and Showcase
- System Architecture
- Clone the repository:
git clone https://github.com/tuandung222/Small-Qwen-Coding-Multiple-Choice.git
cd Small-Qwen-Coding-Multiple-Choice
- Install dependencies:
pip install -r requirements.txt
Set up environment variables in either of two ways:
- Using a `.env` file (recommended):
# Copy the example .env file
cp .env.example .env
# Edit the .env file with your API keys
nano .env
Required variables in `.env`:
HF_TOKEN=your_huggingface_token_here
WANDB_API_KEY=your_wandb_api_key_here
OPENAI_API_KEY=your_openai_api_key_here # Required for teacher synthesis
- Using environment variables directly:
export HF_TOKEN=your_huggingface_token_here
export WANDB_API_KEY=your_wandb_api_key_here
export OPENAI_API_KEY=your_openai_api_key_here
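If you prefer to load these variables from Python rather than the shell, a minimal sketch using `python-dotenv` (an assumption here, not a stated project dependency) looks like this:

```python
# Minimal sketch: read API keys from a .env file (assumes python-dotenv is installed).
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory

hf_token = os.environ["HF_TOKEN"]
wandb_api_key = os.environ["WANDB_API_KEY"]
openai_api_key = os.environ.get("OPENAI_API_KEY")  # only required for teacher synthesis
```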
- Training the Model:
# Basic training
python src/run.py
# Advanced training with custom parameters
python src/run.py \
--experiment-name "my_experiment" \
--source-model "unsloth/Qwen2.5-Coder-1.5B-Instruct" \
--epochs 5 \
--batch-size 16 \
--learning-rate 1e-4
The `train.sh` script provides a convenient way to start training with optimized default settings and proper environment configuration. Here's how to use it:
- Make the script executable:
chmod +x train.sh
- Configure training parameters (optional):
Edit the following variables in `train.sh` to customize your training:
# Model configuration
SOURCE_MODEL="unsloth/Qwen2.5-Coder-1.5B-Instruct"
DESTINATION_REPO="your-username/your-model-name"
# Training hyperparameters
BATCH_SIZE=16
GRAD_ACCUM=2
LEARNING_RATE=5e-5
EPOCHS=3
WARMUP_STEPS=20
VALIDATION_STEPS=30
DEBUG_SAMPLES=3
# Validation settings
MINIMAL_VALIDATING=true
MAX_VALIDATION_SAMPLES=90
SAVE_STEPS=60
SAVE_TOTAL_LIMIT=5
- Start training:
./train.sh
The script includes several optimizations and features:
- Proper environment variable setup for stability
- Automatic PYTHONPATH configuration
- Process cleanup before starting
- Logging to `train.log`
- Lion 8-bit optimizer settings
- Flash Attention 2 integration
- Gradient checkpointing
- WandB logging integration
- Automatic model pushing to Hub
- Monitor training:
# View live training logs
tail -f train.log
# Monitor GPU usage
watch -n 1 nvidia-smi
- Generating Synthetic Explanations:
# Basic synthesis
python src/data_synthesis/gpt4o_generated.py \
--model gpt-4o \
--data-path /path/to/dataset
# Advanced synthesis with options
python src/data_synthesis/gpt4o_generated.py \
--model gpt-4o \
--data-path /path/to/dataset \
--sample-size 100 \
--temperature 0.2 \
--max-tokens 2048 \
--concurrent-requests 5
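The `--concurrent-requests` flag bounds how many explanations are generated in parallel. The sketch below illustrates that pattern with the `openai` Python client (v1+); it is a simplified illustration, not the project's synthesis code, and `build_teacher_prompt` is a hypothetical stand-in for the teacher prompt template shown later:

```python
# Simplified sketch of concurrent teacher synthesis with a bounded number of in-flight requests.
import asyncio

from openai import AsyncOpenAI

client = AsyncOpenAI()            # reads OPENAI_API_KEY from the environment
semaphore = asyncio.Semaphore(5)  # mirrors --concurrent-requests 5


def build_teacher_prompt(example: dict) -> str:
    """Hypothetical helper: fill the teacher prompt template with question, choices, and answer."""
    choices = "\n".join(f"{chr(65 + i)}. {c}" for i, c in enumerate(example["choices"]))
    return (
        "TASK: You are a teacher creating a concise, precise explanation for a multiple-choice question.\n\n"
        f"QUESTION:\n{example['question']}\n\nCHOICES:\n{choices}\n\n"
        f"CORRECT ANSWER: {example['answer']}\n\n"
        "INSTRUCTIONS:\nExplain why the correct answer is correct and why the other options are not.\n"
        "Your response MUST be in YAML format."
    )


async def synthesize(example: dict) -> str:
    async with semaphore:  # cap the number of simultaneous API calls
        response = await client.chat.completions.create(
            model="gpt-4o",
            temperature=0.2,
            max_tokens=2048,
            messages=[{"role": "user", "content": build_teacher_prompt(example)}],
        )
        return response.choices[0].message.content


async def synthesize_all(examples: list[dict]) -> list[str]:
    return await asyncio.gather(*(synthesize(ex) for ex in examples))
```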
The project uses two distinct YAML-formatted prompts: one for student reasoning during inference and another for teacher synthesis during training data generation.
Used during model inference to encourage structured thinking:
Question: [question text]
Choices:
A. [choice 1]
B. [choice 2]
C. [choice 3]
D. [choice 4]
Think through this step-by-step:
- Understand what the question is asking
- Analyze each option carefully
- Reason about why each option might be correct or incorrect
- Select the most appropriate answer
Your response MUST be in YAML format:
understanding: |
<your understanding of the question>
analysis: |
<your analysis of each option>
reasoning: |
<your reasoning about the correct answer>
conclusion: |
<your final conclusion>
answer: <single letter A through D>
Used for generating high-quality training data:
TASK: You are a teacher creating a concise, precise explanation for a multiple-choice question.
QUESTION:
[question text]
CHOICES:
A. [choice 1]
B. [choice 2]
C. [choice 3]
D. [choice 4]
CORRECT ANSWER: [correct_answer]
INSTRUCTIONS:
Create a focused explanation that demonstrates why [correct_answer] is correct
and why other options are incorrect. Be thorough but concise.
Your response MUST be in YAML format:
understanding: |
<brief explanation of key concepts>
analysis: |
<concise analysis of each option>
reasoning: |
<focused reasoning for correct answer>
conclusion: |
<brief summary>
answer: [correct_answer]
Key differences between formats:
- Knowledge of Answer: Student format encourages exploration, teacher format explains known answer
- Focus: Student format emphasizes step-by-step thinking, teacher format prioritizes conciseness
- Purpose: Student format for inference, teacher format for generating training data
- Style: Student format is exploratory, teacher format is authoritative
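For illustration, here is a minimal sketch of assembling the student prompt from a raw question and its choices; the helper name is hypothetical and not necessarily the project's `PromptCreator` API:

```python
# Minimal sketch: build the student (inference) prompt for a four-choice question.
def build_student_prompt(question: str, choices: list[str]) -> str:
    letters = "ABCD"  # sketch assumes up to four choices
    lines = [f"Question: {question}", "Choices:"]
    lines += [f"{letters[i]}. {choice}" for i, choice in enumerate(choices)]
    lines += [
        "",
        "Think through this step-by-step:",
        "- Understand what the question is asking",
        "- Analyze each option carefully",
        "- Reason about why each option might be correct or incorrect",
        "- Select the most appropriate answer",
        "",
        "Your response MUST be in YAML format:",
        "understanding: |",
        "  <your understanding of the question>",
        "analysis: |",
        "  <your analysis of each option>",
        "reasoning: |",
        "  <your reasoning about the correct answer>",
        "conclusion: |",
        "  <your final conclusion>",
        "answer: <single letter A through D>",
    ]
    return "\n".join(lines)
```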
The project provides comprehensive command-line interfaces for both training and synthesis tasks. Below are the detailed arguments and their usage:
python src/run.py [arguments]
| Argument | Description | Default | Example |
|---|---|---|---|
| `--source-model` | Base model to fine-tune | unsloth/Qwen2.5-Coder-1.5B-Instruct | `--source-model "your-model/name"` |
| `--destination-repo` | HF Hub repo for saving | tuandunghcmut/Qwen25_Coder_MultipleChoice_v3 | `--destination-repo "your-username/repo-name"` |
| `--max-seq-length` | Maximum sequence length | 2048 | `--max-seq-length 4096` |
| `--quantization` | Model quantization level | 4bit | `--quantization "8bit"` |
| Argument | Description | Default | Example |
|---|---|---|---|
| `--epochs` | Number of training epochs | 3 | `--epochs 5` |
| `--batch-size` | Per-device batch size | 24 | `--batch-size 32` |
| `--grad-accum` | Gradient accumulation steps | 4 | `--grad-accum 8` |
| `--learning-rate` | Learning rate | 2e-4 | `--learning-rate 1e-4` |
| `--warmup-ratio` | Warmup steps ratio | 0.1 | `--warmup-ratio 0.2` |
| `--weight-decay` | Weight decay for optimizer | 0.01 | `--weight-decay 0.1` |
| Argument | Description | Default | Example |
|---|---|---|---|
| `--lora-r` | LoRA attention dimension | 8 | `--lora-r 16` |
| `--lora-alpha` | LoRA alpha parameter | 32 | `--lora-alpha 64` |
| `--lora-dropout` | LoRA dropout rate | 0.05 | `--lora-dropout 0.1` |
| `--target-modules` | Modules to apply LoRA to | q_proj,k_proj,v_proj,o_proj,... | `--target-modules "q_proj,v_proj"` |
| Argument | Description | Default | Example |
|---|---|---|---|
| `--output-dir` | Directory for outputs | ./model_output | `--output-dir "./my_experiment"` |
| `--experiment-name` | Name for the experiment | timestamp | `--experiment-name "lora_test_1"` |
| `--save-steps` | Steps between saves | 500 | `--save-steps 1000` |
| `--logging-steps` | Steps between logs | 100 | `--logging-steps 50` |
| Argument | Description | Default | Example |
|---|---|---|---|
| `--validation-steps` | Steps between validations | 50 | `--validation-steps 100` |
| `--metric-for-best` | Metric to track for best model | eval_loss | `--metric-for-best "eval_accuracy"` |
| `--greater-is-better` | Whether a higher metric is better | false | `--greater-is-better` |
| `--validate-at-start` | Run validation before training | false | `--validate-at-start` |
| `--early-stopping-patience` | Epochs without improvement before stopping | 3 | `--early-stopping-patience 5` |
| `--early-stopping-delta` | Minimum change to count as improvement | 0.0 | `--early-stopping-delta 0.01` |
| `--val-split` | Fraction of data for validation | 0.04 | `--val-split 0.1` |
| `--push-to-hub` | Push models to HuggingFace Hub | false | `--push-to-hub` |
- Basic Training:
python src/run.py \
--source-model "unsloth/Qwen2.5-Coder-1.5B-Instruct" \
--epochs 3 \
--batch-size 24 \
--learning-rate 2e-4
- Advanced Training with LoRA:
python src/run.py \
--experiment-name "lora_experiment" \
--source-model "unsloth/Qwen2.5-Coder-1.5B-Instruct" \
--epochs 5 \
--batch-size 32 \
--learning-rate 1e-4 \
--lora-r 16 \
--lora-alpha 64 \
--warmup-ratio 0.2 \
--weight-decay 0.01 \
--max-seq-length 2048 \
--quantization "4bit"
python src/data_synthesis/gpt4o_generated.py [arguments]
| Argument | Description | Default | Example |
|---|---|---|---|
| `--model` | OpenAI model to use | gpt-4o | `--model "gpt-3.5-turbo"` |
| `--temperature` | Generation temperature | 0.2 | `--temperature 0.7` |
| `--max-tokens` | Maximum tokens per response | 2048 | `--max-tokens 4096` |
| `--api-key` | OpenAI API key | None | `--api-key "sk-..."` |
| Argument | Description | Default | Example |
|---|---|---|---|
| `--data-path` | Path to dataset | ./data/train | `--data-path "./my_data"` |
| `--sample-size` | Number of examples | None (all) | `--sample-size 100` |
| `--random-seed` | Random seed | 42 | `--random-seed 123` |
| `--concurrent-requests` | Parallel API requests | 5 | `--concurrent-requests 10` |
| Argument | Description | Default | Example |
|---|---|---|---|
| `--output-dir` | Directory for outputs | ./synthesis_results | `--output-dir "./results"` |
| `--quiet` | Suppress verbose output | False | `--quiet` |
- Basic Synthesis:
python src/data_synthesis/gpt4o_generated.py \
--model "gpt-4o" \
--data-path "/path/to/dataset" \
--api-key "your-api-key"
- Advanced Synthesis:
python src/data_synthesis/gpt4o_generated.py \
--model "gpt-4o" \
--data-path "/path/to/dataset" \
--sample-size 100 \
--temperature 0.2 \
--max-tokens 2048 \
--concurrent-requests 5 \
--output-dir "./synthesis_results" \
--random-seed 42
- Test Modes:
  - `--test-mode`: Use only 2 examples for quick testing
  - `--test-training-mode`: Use one batch for minimal training testing
- Hub Integration:
  - `--push-strategy`: Choose when to push to the Hub (best/end/all/no)
  - `--private`: Make the repository private
- Advanced Training:
  - `--train-on-responses-only`: Focus training on responses
  - `--use-flash-attention`: Enable Flash Attention 2
  - `--attention-implementation`: Choose the attention implementation
Here's a comprehensive example showcasing all available features:
#!/bin/bash
# Set environment variables for performance and stability
export PYTHONHASHSEED=42
export CUDA_LAUNCH_BLOCKING=0
export PYTHONPATH=$PYTHONPATH:$(pwd)
export TOKENIZERS_PARALLELISM=true
export CUDA_VISIBLE_DEVICES=0
export OMP_NUM_THREADS=4
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_DEBUG=INFO
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
# Make sure we're in the correct directory and Python can find the modules
cd "$(dirname "$0")"
export PYTHONPATH="$PYTHONPATH:$(pwd)"
echo "Setting PYTHONPATH to include: $(pwd)"
# LoRA-optimized default values with Lion 8-bit settings
SOURCE_MODEL="unsloth/Qwen2.5-Coder-1.5B-Instruct"
DESTINATION_REPO="tuandunghcmut/Qwen25_Coder_MultipleChoice_v4"
BATCH_SIZE=12
GRAD_ACCUM=2
LEARNING_RATE=5e-5
EPOCHS=7
WARMUP_STEPS=30
LOGGING_STEPS=30
VALIDATION_STEPS=30
DEBUG_SAMPLES=3
MINIMAL_VALIDATING=true
MAX_VALIDATION_SAMPLES=90
SAVE_STEPS=60
SAVE_TOTAL_LIMIT=5
# Data loading configuration
DATALOADER_NUM_WORKERS=4
DATALOADER_PIN_MEMORY=true
FULL_DETERMINISM=false
TORCH_COMPILE=false
USE_CPU=false
# Evaluation configuration
EVAL_STRATEGY="steps"
REPORT_TO="wandb"
REMOVE_UNUSED_COLUMNS=false
# Add timestamp to experiment name for uniqueness
TIMESTAMP=$(date +"%m%d_%H%M")
EXPERIMENT_NAME="Qwen25_Coder_MCQ_LoRA_${TIMESTAMP}"
# Run the training script with comprehensive features
nohup python3 src/run.py \
--experiment-name "${EXPERIMENT_NAME}" \
--source-model "$SOURCE_MODEL" \
--destination-repo "$DESTINATION_REPO" \
--epochs "$EPOCHS" \
--batch-size "$BATCH_SIZE" \
--grad-accum "$GRAD_ACCUM" \
--optimizer "lion_8bit" \
--learning-rate "$LEARNING_RATE" \
--weight-decay 0.1 \
--lion-beta1 0.95 \
--lion-beta2 0.98 \
--adam-epsilon 1e-8 \
--max-grad-norm 0.3 \
--warmup-steps "$WARMUP_STEPS" \
--lr-scheduler "cosine" \
--lr-scheduler-num-cycles 1 \
--validation-steps "$VALIDATION_STEPS" \
--minimal-validating \
--max-validation-samples "$MAX_VALIDATION_SAMPLES" \
--validate-at-start \
--metric-for-best "eval_loss" \
--early-stopping-patience 7 \
--early-stopping-delta 0.01 \
--save-steps "$SAVE_STEPS" \
--save-total-limit "$SAVE_TOTAL_LIMIT" \
--save-strategy "steps" \
--no-load-best-model-at-end \
--lora-r 16 \
--lora-alpha 16 \
--lora-dropout 0.05 \
--target-modules "q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj" \
--debug-samples "$DEBUG_SAMPLES" \
--dataloader-num-workers "$DATALOADER_NUM_WORKERS" \
--dataloader-pin-memory \
--evaluation-strategy "$EVAL_STRATEGY" \
--report-to "$REPORT_TO" \
--remove-unused-columns \
--push-to-hub \
--logging-steps "$LOGGING_STEPS" \
--max-seq-length 2048 \
--prompt-template "teacher_reasoned" \
--push-strategy "best" \
--dataset "tuandunghcmut/coding-mcq-reasoning" \
--val-split 0.05 \
--random-seed 42 \
--output-dir "model_output" \
--use-gradient-checkpointing \
--use-flash-attention \
--attention-implementation "flash_attention_2" \
--force-attn-implementation \
--train-on-responses-only \
--instruction-token "<|im_start|>user\n" \
--response-token "<|im_start|>assistant\n" | tee -a train.log
This comprehensive example includes:
- Environment Setup
  - Fixed random seeds for reproducibility
  - CUDA launch blocking for better error tracking
- Training Configuration
  - 7 epochs with batch size 12
  - 4-bit quantization
  - Gradient accumulation in 2 steps
  - Cosine learning rate schedule
- LoRA Settings
  - Rank 16 with alpha 16
  - Comprehensive module targeting
  - Optimized dropout
- Optimization
  - 8-bit Lion optimizer
  - Gradient clipping
  - Early stopping
  - Regular validation
- Advanced Features
  - Flash Attention 2
  - Response-only training
  - Prompt monitoring and analysis
  - Automatic model pushing
- Monitoring
  - Regular logging
  - Checkpoint management
  - Debug samples
  - Comprehensive logging
Save this as `train.sh`, make it executable with `chmod +x train.sh`, and run it with `./train.sh`.
- LoRA (Low-Rank Adaptation) with configurable parameters
- Gradient checkpointing for memory efficiency
- Unsloth integration for faster training and reduced memory usage
- Multiple attention implementations (Flash Attention 2, SDPA, xFormers)
- Mixed precision training (FP16/BF16)
- Gradient accumulation for effective batch size control
- Multiple optimizer options (adamw_torch, adam8bit, pagedadam, lion, adafactor)
- Configurable learning rate schedulers (cosine, linear, polynomial, etc.)
- Warmup strategies with customizable ratios
- Gradient clipping and weight decay
- YAML-format outputs for clear reasoning steps
- Multiple prompt templates for different approaches
- Teacher-reasoned training methodology
- Response-only training option for focused learning
- Multiple evaluation metrics
- Validation strategies with configurable frequency
- Best model checkpointing
- Early stopping with customizable patience
- Real-time display of random training prompts
- Token distribution analysis and visualization
- Prompt diversity tracking with similarity metrics
- Quality metrics (length, complexity, readability)
- Automatic prompt categorization
- Interactive prompt selection and comparison
- WandB integration for prompt analytics
- Configurable logging frequency
- Learning rate tracking
- Model loading alerts
- GPU memory and gradient monitoring
- WandB integration for experiment tracking
- Automatic repository creation
- Configurable push strategies
- Support for private repositories
- Multiple save formats (LoRA, merged 16bit, merged 4bit, GGUF); see the merge sketch after these lists
- Test modes for rapid iteration
- Debug sampling for data inspection
- Comprehensive logging
- Flexible configuration via CLI
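The merged save formats listed under Hub integration can be reproduced by folding a trained LoRA adapter back into the base model. A minimal sketch assuming the PEFT library, with placeholder paths:

```python
# Minimal sketch: merge a trained LoRA adapter into the base model and save full 16-bit weights.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained(
    "unsloth/Qwen2.5-Coder-1.5B-Instruct", torch_dtype=torch.bfloat16
)
model = PeftModel.from_pretrained(base, "model_output/lora_adapter")  # placeholder adapter path
merged = model.merge_and_unload()  # fold the LoRA weights into the base model

tokenizer = AutoTokenizer.from_pretrained("unsloth/Qwen2.5-Coder-1.5B-Instruct")
merged.save_pretrained("model_output/merged_16bit")
tokenizer.save_pretrained("model_output/merged_16bit")
```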
The project uses a curated dataset of multiple-choice coding questions with structured reasoning, published at tuandunghcmut/coding-mcq-reasoning.
The dataset contains 3,549 selected coding multiple-choice questions derived from the CodeMMLU benchmark, enriched with detailed reasoning steps provided by a GPT-4o teacher model. Each example includes:
- Task ID: Unique identifier for each question
- Question: The coding problem or concept being tested
- Choices: Multiple choice answers (A, B, C, D, etc.)
- Answer: The correct option
- Teacher Understanding: Detailed breakdown of the problem statement
- Teacher Analysis: Systematic evaluation of each option
- Teacher Reasoning: Step-by-step logical process
- Teacher Conclusion: Final explanation of the correct answer
- YAML String: Structured format of the reasoning process
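To inspect the dataset locally, a minimal sketch using the `datasets` library; the split and column names are assumptions based on the field list above and may differ slightly in the published dataset:

```python
# Minimal sketch: load the published MCQ dataset and inspect one example.
from datasets import load_dataset

dataset = load_dataset("tuandunghcmut/coding-mcq-reasoning", split="train")  # split name assumed
print(len(dataset))           # expected: 3,549 questions

example = dataset[0]
print(example["question"])    # column names assumed from the field list above
print(example["choices"])
print(example["answer"])
```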
The project uses two distinct YAML-formatted prompts: one for student reasoning during inference and another for teacher synthesis during training data generation.
This format is used during model inference, encouraging structured thinking without knowledge of the correct answer:
Question: [question text]
Choices:
A. [choice 1]
B. [choice 2]
C. [choice 3]
D. [choice 4]
Think through this step-by-step:
- Understand what the question is asking
- Analyze each option carefully
- Reason about why each option might be correct or incorrect
- Select the most appropriate answer
Your response MUST be in YAML format:
understanding: |
<your understanding of the question>
analysis: |
<your analysis of each option>
reasoning: |
<your reasoning about the correct answer>
conclusion: |
<your final conclusion>
answer: <single letter A through D>
This format is used to generate high-quality training data, where the model acts as a teacher with knowledge of the correct answer:
TASK: You are a teacher creating a concise, precise explanation for a multiple-choice question.
QUESTION:
[question text]
CHOICES:
A. [choice 1]
B. [choice 2]
C. [choice 3]
D. [choice 4]
CORRECT ANSWER: [correct_answer]
INSTRUCTIONS:
Create a focused explanation that demonstrates why [correct_answer] is correct
and why other options are incorrect. Be thorough but concise.
Your response MUST be in YAML format:
understanding: |
<brief explanation of key concepts>
analysis: |
<concise analysis of each option>
reasoning: |
<focused reasoning for correct answer>
conclusion: |
<brief summary>
answer: [correct_answer]
Key differences between the formats:
- Knowledge of Answer: Student format encourages exploration, while teacher format focuses on explaining the known correct answer
- Focus: Student format emphasizes step-by-step thinking, teacher format prioritizes concise, precise explanations
- Purpose: Student format for inference, teacher format for generating training data
- Style: Student format is exploratory, teacher format is authoritative and educational
The structured YAML format ensures:
- Consistent formatting across responses
- Easy parsing and validation
- Clear separation of reasoning components
- Systematic approach to problem-solving and explanation
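As an illustration of how simple that parsing can be, here is a minimal sketch using PyYAML; the completion string would come from whatever generation call you use:

```python
# Minimal sketch: parse a YAML-formatted model response and validate the expected fields.
import yaml

REQUIRED_KEYS = {"understanding", "analysis", "reasoning", "conclusion", "answer"}


def parse_response(completion: str) -> dict:
    parsed = yaml.safe_load(completion)
    missing = REQUIRED_KEYS - set(parsed)
    if missing:
        raise ValueError(f"Response is missing fields: {sorted(missing)}")
    answer = str(parsed["answer"]).strip().upper()
    if answer not in {"A", "B", "C", "D"}:
        raise ValueError(f"Unexpected answer value: {parsed['answer']!r}")
    return parsed
```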
Here are some examples from the dataset:
Task ID: k08183
Question: What does the SQL function "ROUND()" do?
Choices: ['Rounds a number to the nearest integer', 'Concatenates two or more strings', 'Converts a string to lowercase', 'Returns the length of a string']
Answer: A
Teacher Understanding: The question is asking about the purpose of the SQL function "ROUND()". This function is used in SQL to manipulate numerical data, specifically to adjust the precision of numbers.
Teacher Analysis:
- A. Rounds a number to the nearest integer: This is correct. The "ROUND()" function is used to round a numeric value to the nearest integer or to a specified number of decimal places.
- B. Concatenates two or more strings: This is incorrect. Concatenation of strings is typically done using the "CONCAT()" function in SQL.
- C. Converts a string to lowercase: This is incorrect. Converting a string to lowercase is done using the "LOWER()" function in SQL.
- D. Returns the length of a string: This is incorrect. The length of a string is determined using the "LENGTH()" function in SQL.
Teacher Reasoning: The "ROUND()" function is specifically designed to handle numerical values by rounding them to the nearest integer or specified decimal places, which aligns with option A. The other options describe functions that manipulate strings, not numbers.
Teacher Conclusion: Answer A is correct because the "ROUND()" function's primary purpose is to round numbers, which is distinct from the string operations described in the other options.
Task ID: k08183
Question: Given a sequence of rolls of a k-sided dice, what is the length of the shortest sequence that cannot be formed?
Choices: ['ans += k - len(seen) + 1', 'ans += 1', 'ans = min(ans + 1, k)', 'ans = ans + 1']
Answer: B
Teacher Understanding: The problem asks for the length of the shortest sequence that cannot be formed from the given rolls of a k-sided dice. The solution involves tracking unique rolls and incrementing a counter when all k numbers have been seen.
Teacher Analysis:
- A. This option incorrectly adjusts the answer based on the difference between k and the size of the set, which is unnecessary since the goal is to increment when all k numbers are seen.
- B. This option correctly increments the answer by 1 when all k numbers have been seen, indicating a complete sequence.
- C. This option uses the min function, which is unnecessary and incorrect because the answer should simply increment by 1 when all k numbers are seen.
- D. This option is similar to B but is redundant because it doesn't add any new logic beyond incrementing by 1.
Teacher Reasoning: The solution needs to increment the sequence count (ans) each time a complete set of k unique numbers is seen. Option B correctly increments the count by 1 when the set size equals k, which signifies that a complete sequence of k numbers has been formed and another sequence can start.
Teacher Conclusion: Answer B is correct because it directly and correctly increments the sequence count by 1 when all k numbers have been seen, aligning with the problem's requirement to find the shortest sequence that cannot be formed.
The framework includes a comprehensive prompt monitoring system that logs and analyzes prompts during training, providing valuable insights into your training data:
python src/run.py \
--logging-steps 100 \
--prompt-track-diversity \
--prompt-track-quality \
--prompt-categorize \
--prompt-comparison \
--max-prompts-to-save 200
- Token Analysis: Analyzes token distributions, unique tokens, and token entropy
- Quality Metrics: Tracks prompt quality over time, including complexity and coherence
- Diversity Tracking: Monitors the diversity of prompts to ensure varied training
- Category Distribution: Automatically categorizes prompts for better insights
- WandB Integration: Rich visualizations in WandB including tables, charts, and trends
- Prompt Comparison: Ability to compare different prompts during training
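As a rough illustration of the token-level statistics above, here is a simplified sketch (not the project's monitoring code) that computes token counts and entropy for a batch of prompts:

```python
# Simplified sketch: token distribution statistics for a batch of prompts.
import math
from collections import Counter

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("unsloth/Qwen2.5-Coder-1.5B-Instruct")


def token_stats(prompts: list[str]) -> dict:
    counts = Counter(tok for p in prompts for tok in tokenizer.encode(p))
    total = sum(counts.values())
    entropy = -sum((c / total) * math.log2(c / total) for c in counts.values())
    return {
        "total_tokens": total,
        "unique_tokens": len(counts),
        "token_entropy": entropy,
    }
```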
After running training with prompt monitoring enabled, you can view detailed prompt metrics in your WandB dashboard:
- Navigate to your WandB project
- Select your training run
- Check the "prompts" section in the dashboard
- View various charts including:
- Token distribution
- Prompt length trends
- Category distribution
- Quality metrics over time
- Diversity scores
This helps you better understand your training data and identify potential issues or biases during training.
The framework includes a comprehensive model card generation system that creates detailed, informative model cards when pushing to Hugging Face Hub:
python src/run.py \
--push-to-hub \
--destination-repo "your-username/model-name"
- Automatic Validation Metrics: Includes detailed validation metrics (eval_loss, runtime, samples/second)
- WandB Integration: Automatically embeds a direct link to the WandB experiment dashboard
- Example Completions: Shows sample outputs generated during validation
- Training Details: Lists comprehensive training hyperparameters and configuration
- Framework Versions: Documents versions of key libraries (Transformers, PyTorch, PEFT)
- Dataset Information: Includes details about the training and validation datasets
- Usage Examples: Provides code snippets for easy model usage
The automatically generated model card includes:
- Model Performance: Key validation metrics with precise formatting
- Model Description: Details about model capabilities and architecture
- Training Data: Information about dataset size and characteristics
- Training Procedure: Complete hyperparameters and configuration
- Experiment Tracking: Direct link to WandB dashboard
- Example Completions: Sample model outputs during validation
- Usage Guide: Ready-to-use code snippets for inference
- Limitations: Documentation of potential model constraints
This feature ensures that models pushed to Hugging Face Hub are well-documented, making them more accessible and easier to use for the community.
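As an example of the kind of usage snippet included in a generated card, here is a minimal inference sketch; it assumes the published checkpoint loads directly with `transformers`, and the generation settings are illustrative:

```python
# Minimal sketch: load the published checkpoint and answer one multiple-choice question.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tuandunghcmut/Qwen25_Coder_MultipleChoice_v4"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

prompt = (
    'Question: What does the SQL function "ROUND()" do?\n'
    "Choices:\n"
    "A. Rounds a number to the nearest integer\n"
    "B. Concatenates two or more strings\n"
    "C. Converts a string to lowercase\n"
    "D. Returns the length of a string\n"
)
messages = [{"role": "user", "content": prompt}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output = model.generate(inputs, max_new_tokens=512, temperature=0.2, do_sample=True)
print(tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True))
```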
The training process includes progress bars using tqdm for better visibility:
# Training with visible progress bars
python src/run.py --validation-steps 50
This provides real-time feedback on both training and validation progress, making it easier to monitor long-running training jobs.
The training pipeline now includes comprehensive memory profiling through the `MemoryProfilingCallback`:
from transformers import Trainer

from src.training.callbacks import MemoryProfilingCallback
memory_callback = MemoryProfilingCallback(
log_every_n_steps=100,
detailed_profiling=True,
warning_threshold=0.90,
track_fragmentation=True,
log_to_file=True,
output_dir="your_output_dir"
)
trainer = Trainer(
model=model,
callbacks=[memory_callback],
# ... other trainer arguments
)
- CUDA Memory Tracking:

  {
    "cuda_allocated": 3.45,       # GB
    "cuda_reserved": 4.12,        # GB
    "cuda_max_allocated": 5.67,
    "cuda_max_reserved": 6.01
  }

- Memory Fragmentation Detection:

  {
    "fragmentation_ratio": 0.15,  # 15% fragmentation
    "warning_threshold": 0.90     # 90% usage warning
  }

- GPU Utilization Monitoring:

  {
    "gpu_utilization": 85,            # Percentage
    "gpu_memory_utilization": 78      # Percentage
  }

- Memory Leak Detection:
  - Monitors garbage collector activity
  - Alerts on suspicious memory patterns
  - Tracks memory growth over time
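The CUDA figures above map directly onto PyTorch's built-in memory counters; a minimal sketch of collecting them outside the callback:

```python
# Minimal sketch: collect the CUDA memory figures reported above with PyTorch's built-in counters.
import torch


def cuda_memory_snapshot() -> dict:
    gib = 1024 ** 3
    return {
        "cuda_allocated": torch.cuda.memory_allocated() / gib,
        "cuda_reserved": torch.cuda.memory_reserved() / gib,
        "cuda_max_allocated": torch.cuda.max_memory_allocated() / gib,
        "cuda_max_reserved": torch.cuda.max_memory_reserved() / gib,
    }


if torch.cuda.is_available():
    print(cuda_memory_snapshot())
```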
Enhanced WandB integration with structured logging and visualization:
from src.training.callbacks import WandBConfig, WandBLogger, WandBCallback
# Configure WandB logging
config = WandBConfig(
project_name="my_project",
run_name="experiment_1",
log_memory=True,
log_gradients=True,
log_training=True,
log_validation=True,
log_examples=True,
log_interval=100,
example_batch_size=5
)
# Initialize logger and callback
logger = WandBLogger(config)
wandb_callback = WandBCallback(logger)
- Training Metrics:

  {
    "training/loss": loss_value,
    "training/learning_rate": current_lr,
    "training/gradient_norm": grad_norm,
    "training/parameter_norm": param_norm
  }

- Memory Metrics:

  {
    "memory/cuda/allocated_gb": allocated,
    "memory/cuda/reserved_gb": reserved,
    "memory/fragmentation_ratio": frag_ratio
  }

- Model Information:

  {
    "model/total_parameters": total_params,
    "model/trainable_parameters": trainable_params,
    "model/frozen_parameters": frozen_params
  }

- Example Logging:

  wandb.log({
    "examples/val_step": example_table,
    "examples/correct_count": correct_count,
    "examples/accuracy": accuracy
  })
- Gradient Checkpointing:

  training_args = TrainingArguments(
      gradient_checkpointing=True,
      gradient_checkpointing_kwargs={"use_reentrant": False},
  )

- Memory-Efficient Training:

  # Use 8-bit optimizers
  optimizer_config = {
      "optimizer_type": "lion_8bit",
      "weight_decay": 0.1,
      "optim_bits": 8,
  }

  # Enable gradient accumulation
  training_args = TrainingArguments(
      gradient_accumulation_steps=8,
      per_device_train_batch_size=4,
  )

- Automatic Memory Management:
  - Regular cache clearing
  - Fragmentation monitoring
  - OOM prevention warnings
  - Automatic batch size adjustment
The WandB dashboard includes:
- Memory Usage Panels:
  - CUDA memory allocation
  - Memory fragmentation
  - GPU utilization
  - Memory leak detection
- Training Progress:
  - Loss curves
  - Learning rate schedules
  - Gradient statistics
  - Parameter norms
- Example Visualization:
  - Training samples
  - Model predictions
  - Accuracy metrics
  - Quality analysis
- System Metrics:
  - CPU usage
  - Disk I/O
  - Network traffic
  - Process memory
- Memory Management:

  # Clear cache at strategic points
  torch.cuda.empty_cache()

  # Monitor fragmentation
  if fragmentation_ratio > 0.3:
      logger.warning("High memory fragmentation detected")

- Gradient Handling:

  # Clip gradients for stability
  training_args = TrainingArguments(
      max_grad_norm=0.3,
  )

- Batch Size Optimization:

  # Start with a small batch size and rely on gradient accumulation
  training_args = TrainingArguments(
      per_device_train_batch_size=4,
      gradient_accumulation_steps=8,
  )

- Regular Checkpointing:

  training_args = TrainingArguments(
      save_steps=30,
      save_total_limit=5,
  )
For more detailed information about memory profiling and monitoring, refer to the Memory Profiling Guide and WandB Integration Guide.
graph TB
classDef pipeline fill:#e1f5fe,stroke:#01579b,stroke-width:2px
classDef component fill:#fff3e0,stroke:#ff6f00,stroke-width:2px
classDef output fill:#f1f8e9,stroke:#33691e,stroke-width:2px
subgraph Data["Data Pipeline"]
D1[("CodeMMLU<br/>Dataset")]:::component --> D2["Data Processing<br/>(YAML Format)"]:::component
D2 --> D3["Teacher Synthesis<br/>(GPT-4o/3.5)"]:::component
D3 --> D4[("Processed MCQ<br/>Dataset")]:::output
end
subgraph Training["Training Pipeline"]
T1[("Qwen2.5<br/>Base Model")]:::component --> T2["LoRA Fine-tuning<br/>(8-bit Quantization)"]:::component
D4 --> T2
T2 --> T3["Fine-tuned Model<br/>(LoRA Weights)"]:::output
T3 --> T4["Model Evaluation<br/>(Accuracy/Loss)"]:::component
end
subgraph Monitoring["Monitoring & Callbacks"]
M1["WandB Logger<br/>(Real-time)"]:::component --> M2["Metrics Tracking<br/>(Loss/Accuracy)"]:::component
M2 --> M3["Early Stopping<br/>(Patience: 3)"]:::component
M2 --> M4["Validation<br/>(10% Split)"]:::component
M2 --> M5["Prompt Monitor<br/>(Quality/Diversity)"]:::component
end
Data --> Training
Training --> Monitoring
class Data,Training,Monitoring pipeline
graph LR
classDef process fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
classDef data fill:#fff3e0,stroke:#ef6c00,stroke-width:2px
classDef output fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
classDef monitor fill:#fce4ec,stroke:#c2185b,stroke-width:2px
subgraph Input["Input Processing"]
I1[("Raw MCQ<br/>Data")]:::data --> I2["Task Queue<br/>(Batch Size: 5)"]:::process
I2 --> I3["Concurrent<br/>Processing"]:::process
end
subgraph Synthesis["Synthesis Process"]
S1["GPT-4o/3.5<br/>API Calls"]:::process --> S2["YAML<br/>Generation"]:::process
S2 --> S3["Answer<br/>Verification"]:::process
S3 --> S4["Quality<br/>Check"]:::process
end
subgraph Output["Output & Monitoring"]
O1["Save<br/>Explanations"]:::output --> O2["Calculate<br/>Metrics"]:::monitor
O2 --> O3["Track<br/>Progress"]:::monitor
O3 --> O4["WandB<br/>Logging"]:::monitor
end
Input --> Synthesis
Synthesis --> Output
style Input fill:#f8f9fa,stroke:#343a40,stroke-width:2px
style Synthesis fill:#f8f9fa,stroke:#343a40,stroke-width:2px
style Output fill:#f8f9fa,stroke:#343a40,stroke-width:2px
graph TB
classDef input fill:#e8eaf6,stroke:#283593,stroke-width:2px
classDef process fill:#fff3e0,stroke:#e65100,stroke-width:2px
classDef validation fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px
classDef output fill:#e8f5e9,stroke:#1b5e20,stroke-width:2px
subgraph Input["Input Data Processing"]
D1[("CodeMMLU<br/>Dataset")]:::input --> D2["Question<br/>Extraction"]:::process
D2 --> D3["Choice<br/>Formatting"]:::process
D3 --> D4["Token<br/>Encoding"]:::process
end
subgraph Processing["Data Enhancement"]
P1["YAML<br/>Formatting"]:::process --> P2["Token<br/>Analysis"]:::process
P2 --> P3["Quality<br/>Metrics"]:::process
P3 --> P4["Diversity<br/>Tracking"]:::process
end
subgraph Validation["Quality Control"]
V1["Answer<br/>Preservation"]:::validation --> V2["Format<br/>Verification"]:::validation
V2 --> V3["Metrics<br/>Logging"]:::validation
V3 --> V4["Error<br/>Handling"]:::validation
end
subgraph Output["Data Output"]
O1["Save<br/>Dataset"]:::output --> O2["Generate<br/>Statistics"]:::output
O2 --> O3["Create<br/>Visualizations"]:::output
end
Input --> Processing
Processing --> Validation
Validation --> Output
style Input fill:#f8f9fa,stroke:#343a40,stroke-width:2px
style Processing fill:#f8f9fa,stroke:#343a40,stroke-width:2px
style Validation fill:#f8f9fa,stroke:#343a40,stroke-width:2px
style Output fill:#f8f9fa,stroke:#343a40,stroke-width:2px
classDiagram
class QwenTrainer {
+model: PreTrainedModel
+tokenizer: PreTrainedTokenizer
+prompt_creator: PromptCreator
+train(dataset, args)
+evaluate(dataset)
+save_checkpoint(path)
+push_to_hub(repo_id)
-setup_optimizer()
-setup_scheduler()
}
class PromptCreator {
+YAML_REASONING: str
+TEACHER_REASONED: str
+BASIC: str
+create_inference_prompt(question, choices)
+create_training_prompt(question, choices)
-format_choices(choices)
-validate_format(prompt)
}
class TeacherSynthesisFramework {
+model_config: ModelConfig
+output_dir: str
+concurrent_requests: int
+generate_synthetic_explanation()
+process_dataset(dataset)
+_calculate_metrics()
-_save_results()
-_handle_errors()
}
class Callbacks {
+ValidationCallback
+EarlyStoppingCallback
+PromptMonitorCallback
+LRMonitorCallback
+ModelLoadingCallback
-track_metrics()
-log_to_wandb()
}
class ModelConfig {
+name: str
+temperature: float
+max_tokens: int
+api_key: str
+validate()
+to_dict()
}
QwenTrainer --> PromptCreator: uses
QwenTrainer --> Callbacks: manages
TeacherSynthesisFramework --> PromptCreator: uses
TeacherSynthesisFramework --> ModelConfig: configures
Callbacks --> QwenTrainer: monitors
note for QwenTrainer "Main training orchestrator"
note for PromptCreator "Handles prompt generation"
note for TeacherSynthesisFramework "Manages synthetic data"
note for Callbacks "Monitors training process"
graph LR
classDef metrics fill:#e1f5fe,stroke:#0277bd,stroke-width:2px
classDef callbacks fill:#fff3e0,stroke:#ef6c00,stroke-width:2px
classDef viz fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
subgraph Metrics["Metrics Collection"]
M1["Training Loss<br/>(Per Step)"]:::metrics --> M4["WandB<br/>Logger"]:::metrics
M2["Validation<br/>Metrics"]:::metrics --> M4
M3["Prompt<br/>Quality"]:::metrics --> M4
end
subgraph Callbacks["Training Control"]
C1["Early<br/>Stopping"]:::callbacks --> C4["Training<br/>Control"]:::callbacks
C2["Learning Rate<br/>Monitor"]:::callbacks --> C4
C3["Prompt<br/>Monitor"]:::callbacks --> C4
end
subgraph Visualization["Analytics Dashboard"]
V1["Loss<br/>Curves"]:::viz --> V4["WandB<br/>Dashboard"]:::viz
V2["Prompt<br/>Statistics"]:::viz --> V4
V3["Model<br/>Performance"]:::viz --> V4
end
Metrics --> Callbacks
Callbacks --> Visualization
style Metrics fill:#f8f9fa,stroke:#343a40,stroke-width:2px
style Callbacks fill:#f8f9fa,stroke:#343a40,stroke-width:2px
style Visualization fill:#f8f9fa,stroke:#343a40,stroke-width:2px