Qwen-Coder-MCQ: Knowledge Distillation of GPT-4o to the compact LLM Qwen2.5-Coder-1.5B-Instruct for Multiple-Choice Coding Questions


This project implements knowledge distillation of YAML-based structured multi-step reasoning from GPT-4o into the small Qwen2.5-Coder-1.5B-Instruct model for multiple-choice coding questions. It uses LoRA (Low-Rank Adaptation) for efficient fine-tuning and includes a comprehensive pipeline for data processing, training, and evaluation. The live demo showcases the model with 4-bit quantization, making it highly efficient while preserving its reasoning capabilities.

🌐 Live Demo: Try the model in action with our Hugging Face Space: tuandunghcmut/Qwen2.5_Coder_1.5B_Instruct_MCQs_v4_2

📢 Dataset Available: The high-quality synthetic dataset of 3,549 coding multiple-choice questions with detailed teacher explanations is now available on HuggingFace: tuandunghcmut/coding-mcq-reasoning. Each question includes structured YAML reasoning steps generated by GPT-4o, making it well suited for training and evaluating coding question-answering models.

🔥 Model Checkpoint: The best model checkpoint from this project is now available on HuggingFace: tuandunghcmut/Qwen25_Coder_MultipleChoice_v4. The corresponding Weights & Biases run for experiment tracking and reproducing results: wandb link.
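
For a quick local test of the published checkpoint, a minimal inference sketch is shown below. It assumes the checkpoint is published as a merged causal-LM model; if it is published as a LoRA adapter instead, load the base model first and attach the adapter with peft.

# Minimal inference sketch (assumes a merged checkpoint on the Hub).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tuandunghcmut/Qwen25_Coder_MultipleChoice_v4"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

prompt = (
    'Question: What does the SQL function "ROUND()" do?\n\n'
    "Choices:\n"
    "A. Rounds a number to the nearest integer\n"
    "B. Concatenates two or more strings\n"
    "C. Converts a string to lowercase\n"
    "D. Returns the length of a string\n"
)
messages = [{"role": "user", "content": prompt}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=512)
# Print only the newly generated tokens (the YAML-formatted reasoning).
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))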

📝 Interactive Notebooks: Try out the model with our interactive Google Colab notebooks (Open In Colab). Please refer to the notebooks directory for the latest versions of the notebooks.

🚀 Quick Start

Installation

  1. Clone the repository:
git clone https://github.com/tuandung222/Small-Qwen-Coding-Multiple-Choice.git
cd Small-Qwen-Coding-Multiple-Choice
  2. Install dependencies:
pip install -r requirements.txt

Environment Setup

Set up environment variables in either of two ways:

  1. Using a .env file (recommended):
# Copy the example .env file
cp .env.example .env

# Edit the .env file with your API keys
nano .env

Required variables in .env:

HF_TOKEN=your_huggingface_token_here
WANDB_API_KEY=your_wandb_api_key_here
OPENAI_API_KEY=your_openai_api_key_here  # Required for teacher synthesis
  2. Using environment variables directly:
export HF_TOKEN=your_huggingface_token_here
export WANDB_API_KEY=your_wandb_api_key_here
export OPENAI_API_KEY=your_openai_api_key_here
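
Before launching a run, it can help to confirm the keys are actually visible to Python. The snippet below is a small illustrative check, not a script shipped with the project.

# Quick sanity check that the required keys are visible to Python.
import os

for name in ("HF_TOKEN", "WANDB_API_KEY", "OPENAI_API_KEY"):
    status = "set" if os.environ.get(name) else "MISSING"
    print(f"{name}: {status}")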

Basic Usage

  1. Training the Model:
# Basic training
python src/run.py

# Advanced training with custom parameters
python src/run.py \
  --experiment-name "my_experiment" \
  --source-model "unsloth/Qwen2.5-Coder-1.5B-Instruct" \
  --epochs 5 \
  --batch-size 16 \
  --learning-rate 1e-4

Using train.sh Script

The train.sh script provides a convenient way to start training with optimized default settings and proper environment configuration. Here's how to use it:

  1. Make the script executable:
chmod +x train.sh
  2. Configure training parameters (optional): Edit the following variables in train.sh to customize your training:
# Model configuration
SOURCE_MODEL="unsloth/Qwen2.5-Coder-1.5B-Instruct"
DESTINATION_REPO="your-username/your-model-name"

# Training hyperparameters
BATCH_SIZE=16
GRAD_ACCUM=2
LEARNING_RATE=5e-5
EPOCHS=3
WARMUP_STEPS=20
VALIDATION_STEPS=30
DEBUG_SAMPLES=3

# Validation settings
MINIMAL_VALIDATING=true
MAX_VALIDATION_SAMPLES=90
SAVE_STEPS=60
SAVE_TOTAL_LIMIT=5
  3. Start training:
./train.sh

The script includes several optimizations and features:

  • Proper environment variable setup for stability
  • Automatic PYTHONPATH configuration
  • Process cleanup before starting
  • Logging to train.log
  • Lion 8-bit optimizer settings
  • Flash Attention 2 integration
  • Gradient checkpointing
  • WandB logging integration
  • Automatic model pushing to Hub
  4. Monitor training:
# View live training logs
tail -f train.log

# Monitor GPU usage
watch -n 1 nvidia-smi
  2. Generating Synthetic Explanations:
# Basic synthesis
python src/data_synthesis/gpt4o_generated.py \
  --model gpt-4o \
  --data-path /path/to/dataset

# Advanced synthesis with options
python src/data_synthesis/gpt4o_generated.py \
  --model gpt-4o \
  --data-path /path/to/dataset \
  --sample-size 100 \
  --temperature 0.2 \
  --max-tokens 2048 \
  --concurrent-requests 5

⚙️ Command-Line Interface

The project provides comprehensive command-line interfaces for both training and synthesis tasks. Below are the detailed arguments and their usage:

Training Script (src/run.py)

python src/run.py [arguments]

Model Configuration

| Argument | Description | Default | Example |
|----------|-------------|---------|---------|
| --source-model | Base model to fine-tune | unsloth/Qwen2.5-Coder-1.5B-Instruct | --source-model "your-model/name" |
| --destination-repo | HF Hub repo for saving | tuandunghcmut/Qwen25_Coder_MultipleChoice_v3 | --destination-repo "your-username/repo-name" |
| --max-seq-length | Maximum sequence length | 2048 | --max-seq-length 4096 |
| --quantization | Model quantization level | 4bit | --quantization "8bit" |
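
For reference, a 4-bit load of the base model with bitsandbytes typically looks like the sketch below; the exact quantization wiring inside src/run.py (Unsloth-based) may differ.

# Illustrative 4-bit load via transformers + bitsandbytes (not the project's
# own loader, which uses Unsloth).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "unsloth/Qwen2.5-Coder-1.5B-Instruct",
    quantization_config=bnb_config,
    device_map="auto",
)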

Training Parameters

| Argument | Description | Default | Example |
|----------|-------------|---------|---------|
| --epochs | Number of training epochs | 3 | --epochs 5 |
| --batch-size | Per-device batch size | 24 | --batch-size 32 |
| --grad-accum | Gradient accumulation steps | 4 | --grad-accum 8 |
| --learning-rate | Learning rate | 2e-4 | --learning-rate 1e-4 |
| --warmup-ratio | Warmup steps ratio | 0.1 | --warmup-ratio 0.2 |
| --weight-decay | Weight decay for optimizer | 0.01 | --weight-decay 0.1 |

LoRA Configuration

| Argument | Description | Default | Example |
|----------|-------------|---------|---------|
| --lora-r | LoRA attention dimension | 8 | --lora-r 16 |
| --lora-alpha | LoRA alpha parameter | 32 | --lora-alpha 64 |
| --lora-dropout | LoRA dropout rate | 0.05 | --lora-dropout 0.1 |
| --target-modules | Modules to apply LoRA to | q_proj,k_proj,v_proj,o_proj,... | --target-modules "q_proj,v_proj" |
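
These flags correspond to the usual PEFT LoRA hyperparameters. A rough peft equivalent of the defaults above is sketched below for orientation; it is illustrative only and not the project's internal setup.

# Rough peft equivalent of the LoRA flags above (illustrative only).
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                      # --lora-r
    lora_alpha=32,            # --lora-alpha
    lora_dropout=0.05,        # --lora-dropout
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # --target-modules
    task_type="CAUSAL_LM",
)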

Output & Monitoring

| Argument | Description | Default | Example |
|----------|-------------|---------|---------|
| --output-dir | Directory for outputs | ./model_output | --output-dir "./my_experiment" |
| --experiment-name | Name for the experiment | timestamp | --experiment-name "lora_test_1" |
| --save-steps | Steps between saves | 500 | --save-steps 1000 |
| --logging-steps | Steps between logs | 100 | --logging-steps 50 |

Validation Parameters

| Argument | Description | Default | Example |
|----------|-------------|---------|---------|
| --validation-steps | Steps between validations | 50 | --validation-steps 100 |
| --metric-for-best | Metric to track for best model | eval_loss | --metric-for-best "eval_accuracy" |
| --greater-is-better | Whether a higher metric is better | false | --greater-is-better |
| --validate-at-start | Run validation before training | false | --validate-at-start |
| --early-stopping-patience | Epochs without improvement before stopping | 3 | --early-stopping-patience 5 |
| --early-stopping-delta | Minimum change to count as improvement | 0.0 | --early-stopping-delta 0.01 |
| --val-split | Fraction of data for validation | 0.04 | --val-split 0.1 |
| --push-to-hub | Push models to HuggingFace Hub | false | --push-to-hub |
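
The early-stopping flags map naturally onto transformers' EarlyStoppingCallback; the sketch below shows that mapping. The project wraps this in its own validation callbacks, so treat it as an illustration rather than the exact wiring.

# How --early-stopping-patience / --early-stopping-delta typically map onto
# transformers' EarlyStoppingCallback (illustrative).
from transformers import EarlyStoppingCallback

early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,     # --early-stopping-patience
    early_stopping_threshold=0.0,  # --early-stopping-delta
)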

Example Commands

  1. Basic Training:
python src/run.py \
    --source-model "unsloth/Qwen2.5-Coder-1.5B-Instruct" \
    --epochs 3 \
    --batch-size 24 \
    --learning-rate 2e-4
  2. Advanced Training with LoRA:
python src/run.py \
    --experiment-name "lora_experiment" \
    --source-model "unsloth/Qwen2.5-Coder-1.5B-Instruct" \
    --epochs 5 \
    --batch-size 32 \
    --learning-rate 1e-4 \
    --lora-r 16 \
    --lora-alpha 64 \
    --warmup-ratio 0.2 \
    --weight-decay 0.01 \
    --max-seq-length 2048 \
    --quantization "4bit"

Synthesis Script (src/data_synthesis/gpt4o_generated.py)

python src/data_synthesis/gpt4o_generated.py [arguments]

Model Configuration

| Argument | Description | Default | Example |
|----------|-------------|---------|---------|
| --model | OpenAI model to use | gpt-4o | --model "gpt-3.5-turbo" |
| --temperature | Generation temperature | 0.2 | --temperature 0.7 |
| --max-tokens | Maximum tokens per response | 2048 | --max-tokens 4096 |
| --api-key | OpenAI API key | None | --api-key "sk-..." |

Data Processing

| Argument | Description | Default | Example |
|----------|-------------|---------|---------|
| --data-path | Path to dataset | ./data/train | --data-path "./my_data" |
| --sample-size | Number of examples | None (all) | --sample-size 100 |
| --random-seed | Random seed | 42 | --random-seed 123 |
| --concurrent-requests | Parallel API requests | 5 | --concurrent-requests 10 |

Output Configuration

| Argument | Description | Default | Example |
|----------|-------------|---------|---------|
| --output-dir | Directory for outputs | ./synthesis_results | --output-dir "./results" |
| --quiet | Suppress verbose output | False | --quiet |

Example Commands

  1. Basic Synthesis:
python src/data_synthesis/gpt4o_generated.py \
    --model "gpt-4o" \
    --data-path "/path/to/dataset" \
    --api-key "your-api-key"
  2. Advanced Synthesis:
python src/data_synthesis/gpt4o_generated.py \
    --model "gpt-4o" \
    --data-path "/path/to/dataset" \
    --sample-size 100 \
    --temperature 0.2 \
    --max-tokens 2048 \
    --concurrent-requests 5 \
    --output-dir "./synthesis_results" \
    --random-seed 42
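
Conceptually, each synthesis request boils down to one chat-completion call with the teacher prompt described in the Prompt Formats section. The sketch below shows a single-question call with the openai client; the actual gpt4o_generated.py adds concurrency, retries, and answer verification.

# Minimal sketch of one teacher-synthesis call (assumed usage of the openai
# client; not the project's gpt4o_generated.py implementation).
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

teacher_prompt = """TASK: You are a teacher creating a concise, precise explanation for a multiple-choice question.

QUESTION:
What does the SQL function "ROUND()" do?

CHOICES:
A. Rounds a number to the nearest integer
B. Concatenates two or more strings
C. Converts a string to lowercase
D. Returns the length of a string

CORRECT ANSWER: A

INSTRUCTIONS:
Create a focused explanation that demonstrates why A is correct
and why other options are incorrect. Be thorough but concise.

Your response MUST be in the YAML format described in the Prompt Formats section."""

response = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": teacher_prompt}],
    temperature=0.2,
    max_tokens=2048,
)
print(response.choices[0].message.content)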

Additional Features

  1. Test Modes:

    • --test-mode: Use only 2 examples for quick testing
    • --test-training-mode: Use one batch for minimal training testing
  2. Hub Integration:

    • --push-strategy: Choose when to push to hub (best/end/all/no)
    • --private: Make the repository private
  3. Advanced Training:

    • --train-on-responses-only: Focus training on responses
    • --use-flash-attention: Enable Flash Attention 2
    • --attention-implementation: Choose attention implementation

Advanced Training with All Features

Here's a comprehensive example showcasing all available features:

#!/bin/bash

# Set environment variables for performance and stability
export PYTHONHASHSEED=42
export CUDA_LAUNCH_BLOCKING=0
export PYTHONPATH=$PYTHONPATH:$(pwd)
export TOKENIZERS_PARALLELISM=true
export CUDA_VISIBLE_DEVICES=0
export OMP_NUM_THREADS=4
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_DEBUG=INFO
export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512

# Make sure we're in the correct directory and Python can find the modules
cd "$(dirname "$0")"
export PYTHONPATH="$PYTHONPATH:$(pwd)"
echo "Setting PYTHONPATH to include: $(pwd)"

# LoRA-optimized default values with Lion 8-bit settings
SOURCE_MODEL="unsloth/Qwen2.5-Coder-1.5B-Instruct"
DESTINATION_REPO="tuandunghcmut/Qwen25_Coder_MultipleChoice_v4"
BATCH_SIZE=12
GRAD_ACCUM=2
LEARNING_RATE=5e-5
EPOCHS=7
WARMUP_STEPS=30
LOGGING_STEPS=30
VALIDATION_STEPS=30
DEBUG_SAMPLES=3
MINIMAL_VALIDATING=true
MAX_VALIDATION_SAMPLES=90
SAVE_STEPS=60
SAVE_TOTAL_LIMIT=5

# Data loading configuration
DATALOADER_NUM_WORKERS=4
DATALOADER_PIN_MEMORY=true
FULL_DETERMINISM=false
TORCH_COMPILE=false
USE_CPU=false

# Evaluation configuration
EVAL_STRATEGY="steps"
REPORT_TO="wandb"
REMOVE_UNUSED_COLUMNS=false

# Add timestamp to experiment name for uniqueness
TIMESTAMP=$(date +"%m%d_%H%M")
EXPERIMENT_NAME="Qwen25_Coder_MCQ_LoRA_${TIMESTAMP}"

# Run the training script with comprehensive features
nohup python3 src/run.py \
    --experiment-name "${EXPERIMENT_NAME}" \
    --source-model "$SOURCE_MODEL" \
    --destination-repo "$DESTINATION_REPO" \
    --epochs "$EPOCHS" \
    --batch-size "$BATCH_SIZE" \
    --grad-accum "$GRAD_ACCUM" \
    --optimizer "lion_8bit" \
    --learning-rate "$LEARNING_RATE" \
    --weight-decay 0.1 \
    --lion-beta1 0.95 \
    --lion-beta2 0.98 \
    --adam-epsilon 1e-8 \
    --max-grad-norm 0.3 \
    --warmup-steps "$WARMUP_STEPS" \
    --lr-scheduler "cosine" \
    --lr-scheduler-num-cycles 1 \
    --validation-steps "$VALIDATION_STEPS" \
    --minimal-validating \
    --max-validation-samples "$MAX_VALIDATION_SAMPLES" \
    --validate-at-start \
    --metric-for-best "eval_loss" \
    --early-stopping-patience 7 \
    --early-stopping-delta 0.01 \
    --save-steps "$SAVE_STEPS" \
    --save-total-limit "$SAVE_TOTAL_LIMIT" \
    --save-strategy "steps" \
    --no-load-best-model-at-end \
    --lora-r 16 \
    --lora-alpha 16 \
    --lora-dropout 0.05 \
    --target-modules "q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj" \
    --debug-samples "$DEBUG_SAMPLES" \
    --dataloader-num-workers "$DATALOADER_NUM_WORKERS" \
    --dataloader-pin-memory \
    --full-determinism \
    --torch-compile \
    --use-cpu \
    --evaluation-strategy "$EVAL_STRATEGY" \
    --report-to "$REPORT_TO" \
    --remove-unused-columns \
    --push-to-hub \
    --logging-steps "$LOGGING_STEPS" \
    --max-seq-length 2048 \
    --prompt-template "teacher_reasoned" \
    --push-strategy "best" \
    --dataset "tuandunghcmut/coding-mcq-reasoning" \
    --val-split 0.05 \
    --random-seed 42 \
    --output-dir "model_output" \
    --use-gradient-checkpointing \
    --use-flash-attention \
    --attention-implementation "flash_attention_2" \
    --force-attn-implementation \
    --train-on-responses-only \
    --instruction-token "<|im_start|>user\n" \
    --response-token "<|im_start|>assistant\n" | tee -a train.log

This comprehensive example includes:

  1. Environment Setup

    • Fixed random seeds for reproducibility
    • CUDA launch blocking for better error tracking
  2. Training Configuration

    • 7 epochs with batch size 12
    • 4-bit quantization
    • Gradient accumulation over 2 steps
    • Cosine learning rate schedule
  3. LoRA Settings

    • Rank 16 with alpha 16
    • Comprehensive module targeting
    • Optimized dropout
  4. Optimization

    • 8-bit Lion optimizer
    • Gradient clipping
    • Early stopping
    • Regular validation
  5. Advanced Features

    • Flash Attention 2
    • Response-only training
    • Prompt monitoring and analysis
    • Automatic model pushing
  6. Monitoring

    • Regular logging
    • Checkpoint management
    • Debug samples
    • Comprehensive logging

Save this as train.sh, make it executable with chmod +x train.sh, and run with ./train.sh.

🎯 Features

Parameter-Efficient Fine-Tuning

  • LoRA (Low-Rank Adaptation) with configurable parameters
  • Gradient checkpointing for memory efficiency

Optimized Training

  • Unsloth integration for faster training and reduced memory usage
  • Multiple attention implementations (Flash Attention 2, SDPA, xFormers)
  • Mixed precision training (FP16/BF16)
  • Gradient accumulation for effective batch size control

Advanced Optimizers and Schedulers

  • Multiple optimizer options (adamw_torch, adam8bit, pagedadam, lion, adafactor)
  • Configurable learning rate schedulers (cosine, linear, polynomial, etc.)
  • Warmup strategies with customizable ratios
  • Gradient clipping and weight decay
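
For instance, a cosine schedule with warmup can be built with transformers' scheduler helper, as sketched below. This shows the general pattern only; src/run.py configures its optimizer and scheduler through the CLI flags above.

# Generic cosine-with-warmup schedule via transformers (illustrative; the
# parameters below are placeholders, not the project's defaults).
import torch
from transformers import get_cosine_schedule_with_warmup

params = [torch.nn.Parameter(torch.zeros(1))]  # hypothetical parameters
optimizer = torch.optim.AdamW(params, lr=2e-4, weight_decay=0.01)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=30,
    num_training_steps=1000,
)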

Structured Reasoning

  • YAML-format outputs for clear reasoning steps
  • Multiple prompt templates for different approaches
  • Teacher-reasoned training methodology
  • Response-only training option for focused learning

Comprehensive Evaluation

  • Multiple evaluation metrics
  • Validation strategies with configurable frequency
  • Best model checkpointing
  • Early stopping with customizable patience

Advanced Monitoring

Prompt Monitoring

  • Real-time display of random training prompts
  • Token distribution analysis and visualization
  • Prompt diversity tracking with similarity metrics
  • Quality metrics (length, complexity, readability)
  • Automatic prompt categorization
  • Interactive prompt selection and comparison
  • WandB integration for prompt analytics
  • Configurable logging frequency

Training Metrics

  • Learning rate tracking
  • Model loading alerts
  • GPU memory and gradient monitoring
  • WandB integration for experiment tracking

HuggingFace Hub Integration

  • Automatic repository creation
  • Configurable push strategies
  • Support for private repositories
  • Multiple save formats (LoRA, merged 16bit, merged 4bit, GGUF)
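
Producing the merged save formats typically follows the standard peft merge-and-push pattern sketched below. This is an assumed flow for orientation; the project automates it through its push settings, and the adapter path and repo name here are placeholders.

# Sketch of merging LoRA weights into the base model and pushing to the Hub
# (assumed peft flow; paths and repo names are placeholders).
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("unsloth/Qwen2.5-Coder-1.5B-Instruct")
model = PeftModel.from_pretrained(base, "path/to/lora_checkpoint")  # hypothetical path
merged = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained("unsloth/Qwen2.5-Coder-1.5B-Instruct")
merged.push_to_hub("your-username/your-model-name")
tokenizer.push_to_hub("your-username/your-model-name")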

Development and Testing

  • Test modes for rapid iteration
  • Debug sampling for data inspection
  • Comprehensive logging
  • Flexible configuration via CLI

📊 Dataset

The project uses a curated dataset of multiple-choice coding questions with structured reasoning, published at tuandunghcmut/coding-mcq-reasoning.

Dataset Structure

The dataset contains 3,549 selected coding multiple-choice questions derived from the CodeMMLU benchmark, enriched with detailed reasoning steps provided by a GPT-4o teacher model. Each example includes:

  • Task ID: Unique identifier for each question
  • Question: The coding problem or concept being tested
  • Choices: Multiple choice answers (A, B, C, D, etc.)
  • Answer: The correct option
  • Teacher Understanding: Detailed breakdown of the problem statement
  • Teacher Analysis: Systematic evaluation of each option
  • Teacher Reasoning: Step-by-step logical process
  • Teacher Conclusion: Final explanation of the correct answer
  • YAML String: Structured format of the reasoning process
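
The dataset can be pulled directly from the Hub with the datasets library; the snippet below loads it and inspects one example. Exact column names may differ slightly from the field names listed above, so check the printed schema.

# Load the published dataset and inspect one example.
from datasets import load_dataset

ds = load_dataset("tuandunghcmut/coding-mcq-reasoning")
print(ds)  # shows the available splits and column names

first_split = next(iter(ds.values()))
print(first_split[0])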

Prompt Formats

The project uses two distinct YAML-formatted prompts: one for student reasoning during inference and another for teacher synthesis during training data generation.

Student Reasoning Format

This format is used during model inference, encouraging structured thinking without knowledge of the correct answer:

Question: [question text]

Choices:
A. [choice 1]
B. [choice 2]
C. [choice 3]
D. [choice 4]

Think through this step-by-step:
- Understand what the question is asking
- Analyze each option carefully
- Reason about why each option might be correct or incorrect
- Select the most appropriate answer

Your response MUST be in YAML format:
understanding: |
  <your understanding of the question>
analysis: |
  <your analysis of each option>
reasoning: |
  <your reasoning about the correct answer>
conclusion: |
  <your final conclusion>
answer: <single letter A through D>

Teacher Synthesis Format

This format is used to generate high-quality training data, where the model acts as a teacher with knowledge of the correct answer:

TASK: You are a teacher creating a concise, precise explanation for a multiple-choice question.

QUESTION:
[question text]

CHOICES:
A. [choice 1]
B. [choice 2]
C. [choice 3]
D. [choice 4]

CORRECT ANSWER: [correct_answer]

INSTRUCTIONS:
Create a focused explanation that demonstrates why [correct_answer] is correct
and why other options are incorrect. Be thorough but concise.

Your response MUST be in YAML format:
understanding: |
  <brief explanation of key concepts>
analysis: |
  <concise analysis of each option>
reasoning: |
  <focused reasoning for correct answer>
conclusion: |
  <brief summary>
answer: [correct_answer]

Key differences between the formats:

  1. Knowledge of Answer: Student format encourages exploration, while teacher format focuses on explaining the known correct answer
  2. Focus: Student format emphasizes step-by-step thinking, teacher format prioritizes concise, precise explanations
  3. Purpose: Student format for inference, teacher format for generating training data
  4. Style: Student format is exploratory, teacher format is authoritative and educational

The structured YAML format ensures:

  • Consistent formatting across responses
  • Easy parsing and validation
  • Clear separation of reasoning components
  • Systematic approach to problem-solving and explanation
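
Because responses follow this fixed YAML schema, they can be parsed directly. The sketch below extracts the reasoning fields and the final answer letter from a model response; it is illustrative and not the project's own parser.

# Parse a YAML-formatted model response and pull out the answer letter
# (illustrative; the project has its own parsing and validation code).
import yaml

raw_response = """understanding: |
  The question asks what the SQL ROUND() function does.
analysis: |
  Only option A describes a numeric operation; B, C, and D are string functions.
reasoning: |
  ROUND() adjusts numeric precision, so it rounds a number.
conclusion: |
  Option A matches the function's purpose.
answer: A"""

parsed = yaml.safe_load(raw_response)
print(parsed["answer"])             # -> "A"
print(parsed["reasoning"].strip())  # the free-text reasoning section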

Data Examples

Here are some examples from the dataset:

Example 1: SQL Function Question

Task ID: k08183
Question: What does the SQL function "ROUND()" do?
Choices: ['Rounds a number to the nearest integer', 'Concatenates two or more strings', 'Converts a string to lowercase', 'Returns the length of a string']
Answer: A
Teacher Understanding: The question is asking about the purpose of the SQL function "ROUND()". This function is used in SQL to manipulate numerical data, specifically to adjust the precision of numbers.
Teacher Analysis:
- A. Rounds a number to the nearest integer: This is correct. The "ROUND()" function is used to round a numeric value to the nearest integer or to a specified number of decimal places.
- B. Concatenates two or more strings: This is incorrect. Concatenation of strings is typically done using the "CONCAT()" function in SQL.
- C. Converts a string to lowercase: This is incorrect. Converting a string to lowercase is done using the "LOWER()" function in SQL.
- D. Returns the length of a string: This is incorrect. The length of a string is determined using the "LENGTH()" function in SQL.
Teacher Reasoning: The "ROUND()" function is specifically designed to handle numerical values by rounding them to the nearest integer or specified decimal places, which aligns with option A. The other options describe functions that manipulate strings, not numbers.
Teacher Conclusion: Answer A is correct because the "ROUND()" function's primary purpose is to round numbers, which is distinct from the string operations described in the other options.

Example 2: Algorithm Problem

Task ID: k08183
Question: Given a sequence of rolls of a k-sided dice, what is the length of the shortest sequence that cannot be formed?
Choices: ['ans += k - len(seen) + 1', 'ans += 1', 'ans = min(ans + 1, k)', 'ans = ans + 1']
Answer: B
Teacher Understanding: The problem asks for the length of the shortest sequence that cannot be formed from the given rolls of a k-sided dice. The solution involves tracking unique rolls and incrementing a counter when all k numbers have been seen.
Teacher Analysis:
- A. This option incorrectly adjusts the answer based on the difference between k and the size of the set, which is unnecessary since the goal is to increment when all k numbers are seen.
- B. This option correctly increments the answer by 1 when all k numbers have been seen, indicating a complete sequence.
- C. This option uses the min function, which is unnecessary and incorrect because the answer should simply increment by 1 when all k numbers are seen.
- D. This option is similar to B but is redundant because it doesn't add any new logic beyond incrementing by 1.
Teacher Reasoning: The solution needs to increment the sequence count (ans) each time a complete set of k unique numbers is seen. Option B correctly increments the count by 1 when the set size equals k, which signifies that a complete sequence of k numbers has been formed and another sequence can start.
Teacher Conclusion: Answer B is correct because it directly and correctly increments the sequence count by 1 when all k numbers have been seen, aligning with the problem's requirement to find the shortest sequence that cannot be formed.

🔍 Advanced Features

Prompt Monitoring System

The framework includes a comprehensive prompt monitoring system that logs and analyzes prompts during training, providing valuable insights into your training data:

python src/run.py \
    --logging-steps 100 \
    --prompt-track-diversity \
    --prompt-track-quality \
    --prompt-categorize \
    --prompt-comparison \
    --max-prompts-to-save 200

Prompt Monitoring Features

  • Token Analysis: Analyzes token distributions, unique tokens, and token entropy
  • Quality Metrics: Tracks prompt quality over time, including complexity and coherence
  • Diversity Tracking: Monitors the diversity of prompts to ensure varied training
  • Category Distribution: Automatically categorizes prompts for better insights
  • WandB Integration: Rich visualizations in WandB including tables, charts, and trends
  • Prompt Comparison: Ability to compare different prompts during training

Viewing Prompt Metrics in WandB

After running training with prompt monitoring enabled, you can view detailed prompt metrics in your WandB dashboard:

  1. Navigate to your WandB project
  2. Select your training run
  3. Check the "prompts" section in the dashboard
  4. View various charts including:
    • Token distribution
    • Prompt length trends
    • Category distribution
    • Quality metrics over time
    • Diversity scores

This helps you better understand your training data and identify potential issues or biases during training.

Automatic Model Card Generation

The framework includes a comprehensive model card generation system that creates detailed, informative model cards when pushing to Hugging Face Hub:

python src/run.py \
    --push-to-hub \
    --destination-repo "your-username/model-name"

Model Card Features

  • Automatic Validation Metrics: Includes detailed validation metrics (eval_loss, runtime, samples/second)
  • WandB Integration: Automatically embeds a direct link to the WandB experiment dashboard
  • Example Completions: Shows sample outputs generated during validation
  • Training Details: Lists comprehensive training hyperparameters and configuration
  • Framework Versions: Documents versions of key libraries (Transformers, PyTorch, PEFT)
  • Dataset Information: Includes details about the training and validation datasets
  • Usage Examples: Provides code snippets for easy model usage

Model Card Sections

The automatically generated model card includes:

  1. Model Performance: Key validation metrics with precise formatting
  2. Model Description: Details about model capabilities and architecture
  3. Training Data: Information about dataset size and characteristics
  4. Training Procedure: Complete hyperparameters and configuration
  5. Experiment Tracking: Direct link to WandB dashboard
  6. Example Completions: Sample model outputs during validation
  7. Usage Guide: Ready-to-use code snippets for inference
  8. Limitations: Documentation of potential model constraints

This feature ensures that models pushed to Hugging Face Hub are well-documented, making them more accessible and easier to use for the community.

Visualization with tqdm

The training process includes progress bars using tqdm for better visibility:

# Training with visible progress bars
python src/run.py --validation-steps 50

This provides real-time feedback on both training and validation progress, making it easier to monitor long-running training jobs.

🔍 Memory Profiling and Monitoring

Memory Profiling Callback

The training pipeline now includes comprehensive memory profiling through the MemoryProfilingCallback:

from transformers import Trainer

from src.training.callbacks import MemoryProfilingCallback

memory_callback = MemoryProfilingCallback(
    log_every_n_steps=100,
    detailed_profiling=True,
    warning_threshold=0.90,
    track_fragmentation=True,
    log_to_file=True,
    output_dir="your_output_dir"
)

trainer = Trainer(
    model=model,
    callbacks=[memory_callback],
    # ... other trainer arguments
)

Features

  1. CUDA Memory Tracking:

    {
        "cuda_allocated": 3.45,  # GB
        "cuda_reserved": 4.12,   # GB
        "cuda_max_allocated": 5.67,
        "cuda_max_reserved": 6.01
    }
  2. Memory Fragmentation Detection:

    {
        "fragmentation_ratio": 0.15,  # 15% fragmentation
        "warning_threshold": 0.90,    # 90% usage warning
    }
  3. GPU Utilization Monitoring:

    {
        "gpu_utilization": 85,       # Percentage
        "gpu_memory_utilization": 78  # Percentage
    }
  4. Memory Leak Detection:

    • Monitors garbage collector activity
    • Alerts on suspicious memory patterns
    • Tracks memory growth over time

Weights & Biases Integration

Enhanced WandB integration with structured logging and visualization:

from src.training.callbacks import WandBConfig, WandBLogger, WandBCallback

# Configure WandB logging
config = WandBConfig(
    project_name="my_project",
    run_name="experiment_1",
    log_memory=True,
    log_gradients=True,
    log_training=True,
    log_validation=True,
    log_examples=True,
    log_interval=100,
    example_batch_size=5
)

# Initialize logger and callback
logger = WandBLogger(config)
wandb_callback = WandBCallback(logger)

Logging Features

  1. Training Metrics:

    {
        "training/loss": loss_value,
        "training/learning_rate": current_lr,
        "training/gradient_norm": grad_norm,
        "training/parameter_norm": param_norm
    }
  2. Memory Metrics:

    {
        "memory/cuda/allocated_gb": allocated,
        "memory/cuda/reserved_gb": reserved,
        "memory/fragmentation_ratio": frag_ratio
    }
  3. Model Information:

    {
        "model/total_parameters": total_params,
        "model/trainable_parameters": trainable_params,
        "model/frozen_parameters": frozen_params
    }
  4. Example Logging:

    wandb.log({
        "examples/val_step": example_table,
        "examples/correct_count": correct_count,
        "examples/accuracy": accuracy
    })

Memory Optimization Tips

  1. Gradient Checkpointing:

    training_args = TrainingArguments(
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
  2. Memory-Efficient Training:

    # Use 8-bit optimizers
    optimizer_config = {
        "optimizer_type": "lion_8bit",
        "weight_decay": 0.1,
        "optim_bits": 8
    }
    
    # Enable gradient accumulation
    training_args = TrainingArguments(
        gradient_accumulation_steps=8,
        per_device_train_batch_size=4
    )
  3. Automatic Memory Management:

    • Regular cache clearing
    • Fragmentation monitoring
    • OOM prevention warnings
    • Automatic batch size adjustment

Monitoring Dashboard

The WandB dashboard includes:

  1. Memory Usage Panels:

    • CUDA memory allocation
    • Memory fragmentation
    • GPU utilization
    • Memory leak detection
  2. Training Progress:

    • Loss curves
    • Learning rate schedules
    • Gradient statistics
    • Parameter norms
  3. Example Visualization:

    • Training samples
    • Model predictions
    • Accuracy metrics
    • Quality analysis
  4. System Metrics:

    • CPU usage
    • Disk I/O
    • Network traffic
    • Process memory

Best Practices

  1. Memory Management:

    # Clear cache at strategic points
    torch.cuda.empty_cache()
    
    # Monitor fragmentation
    if fragmentation_ratio > 0.3:
        logger.warning("High memory fragmentation detected")
  2. Gradient Handling:

    # Clip gradients for stability (max_grad_norm is the TrainingArguments knob)
    training_args = TrainingArguments(
        max_grad_norm=0.3
    )
  3. Batch Size Optimization:

    # Start with small batch size
    training_args = TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8
    )
  4. Regular Checkpointing:

    training_args = TrainingArguments(
        save_steps=30,
        save_total_limit=5
    )

For more detailed information about memory profiling and monitoring, refer to the Memory Profiling Guide and WandB Integration Guide.

🏗️ Architecture

Overall System Architecture

graph TB
    classDef pipeline fill:#e1f5fe,stroke:#01579b,stroke-width:2px
    classDef component fill:#fff3e0,stroke:#ff6f00,stroke-width:2px
    classDef output fill:#f1f8e9,stroke:#33691e,stroke-width:2px

    subgraph Data["Data Pipeline"]
        D1[("CodeMMLU<br/>Dataset")]:::component --> D2["Data Processing<br/>(YAML Format)"]:::component
        D2 --> D3["Teacher Synthesis<br/>(GPT-4o/3.5)"]:::component
        D3 --> D4[("Processed MCQ<br/>Dataset")]:::output
    end

    subgraph Training["Training Pipeline"]
        T1[("Qwen2.5<br/>Base Model")]:::component --> T2["LoRA Fine-tuning<br/>(8-bit Quantization)"]:::component
        D4 --> T2
        T2 --> T3["Fine-tuned Model<br/>(LoRA Weights)"]:::output
        T3 --> T4["Model Evaluation<br/>(Accuracy/Loss)"]:::component
    end

    subgraph Monitoring["Monitoring & Callbacks"]
        M1["WandB Logger<br/>(Real-time)"]:::component --> M2["Metrics Tracking<br/>(Loss/Accuracy)"]:::component
        M2 --> M3["Early Stopping<br/>(Patience: 3)"]:::component
        M2 --> M4["Validation<br/>(10% Split)"]:::component
        M2 --> M5["Prompt Monitor<br/>(Quality/Diversity)"]:::component
    end

    Data --> Training
    Training --> Monitoring

    class Data,Training,Monitoring pipeline

Teacher Synthesis Pipeline

graph LR
    classDef process fill:#e3f2fd,stroke:#1565c0,stroke-width:2px
    classDef data fill:#fff3e0,stroke:#ef6c00,stroke-width:2px
    classDef output fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px
    classDef monitor fill:#fce4ec,stroke:#c2185b,stroke-width:2px

    subgraph Input["Input Processing"]
        I1[("Raw MCQ<br/>Data")]:::data --> I2["Task Queue<br/>(Batch Size: 5)"]:::process
        I2 --> I3["Concurrent<br/>Processing"]:::process
    end

    subgraph Synthesis["Synthesis Process"]
        S1["GPT-4o/3.5<br/>API Calls"]:::process --> S2["YAML<br/>Generation"]:::process
        S2 --> S3["Answer<br/>Verification"]:::process
        S3 --> S4["Quality<br/>Check"]:::process
    end

    subgraph Output["Output & Monitoring"]
        O1["Save<br/>Explanations"]:::output --> O2["Calculate<br/>Metrics"]:::monitor
        O2 --> O3["Track<br/>Progress"]:::monitor
        O3 --> O4["WandB<br/>Logging"]:::monitor
    end

    Input --> Synthesis
    Synthesis --> Output

    style Input fill:#f8f9fa,stroke:#343a40,stroke-width:2px
    style Synthesis fill:#f8f9fa,stroke:#343a40,stroke-width:2px
    style Output fill:#f8f9fa,stroke:#343a40,stroke-width:2px

Data Processing Pipeline

graph TB
    classDef input fill:#e8eaf6,stroke:#283593,stroke-width:2px
    classDef process fill:#fff3e0,stroke:#e65100,stroke-width:2px
    classDef validation fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px
    classDef output fill:#e8f5e9,stroke:#1b5e20,stroke-width:2px

    subgraph Input["Input Data Processing"]
        D1[("CodeMMLU<br/>Dataset")]:::input --> D2["Question<br/>Extraction"]:::process
        D2 --> D3["Choice<br/>Formatting"]:::process
        D3 --> D4["Token<br/>Encoding"]:::process
    end

    subgraph Processing["Data Enhancement"]
        P1["YAML<br/>Formatting"]:::process --> P2["Token<br/>Analysis"]:::process
        P2 --> P3["Quality<br/>Metrics"]:::process
        P3 --> P4["Diversity<br/>Tracking"]:::process
    end

    subgraph Validation["Quality Control"]
        V1["Answer<br/>Preservation"]:::validation --> V2["Format<br/>Verification"]:::validation
        V2 --> V3["Metrics<br/>Logging"]:::validation
        V3 --> V4["Error<br/>Handling"]:::validation
    end

    subgraph Output["Data Output"]
        O1["Save<br/>Dataset"]:::output --> O2["Generate<br/>Statistics"]:::output
        O2 --> O3["Create<br/>Visualizations"]:::output
    end

    Input --> Processing
    Processing --> Validation
    Validation --> Output

    style Input fill:#f8f9fa,stroke:#343a40,stroke-width:2px
    style Processing fill:#f8f9fa,stroke:#343a40,stroke-width:2px
    style Validation fill:#f8f9fa,stroke:#343a40,stroke-width:2px
    style Output fill:#f8f9fa,stroke:#343a40,stroke-width:2px

Training Architecture

classDiagram
    class QwenTrainer {
        +model: PreTrainedModel
        +tokenizer: PreTrainedTokenizer
        +prompt_creator: PromptCreator
        +train(dataset, args)
        +evaluate(dataset)
        +save_checkpoint(path)
        +push_to_hub(repo_id)
        -setup_optimizer()
        -setup_scheduler()
    }

    class PromptCreator {
        +YAML_REASONING: str
        +TEACHER_REASONED: str
        +BASIC: str
        +create_inference_prompt(question, choices)
        +create_training_prompt(question, choices)
        -format_choices(choices)
        -validate_format(prompt)
    }

    class TeacherSynthesisFramework {
        +model_config: ModelConfig
        +output_dir: str
        +concurrent_requests: int
        +generate_synthetic_explanation()
        +process_dataset(dataset)
        +_calculate_metrics()
        -_save_results()
        -_handle_errors()
    }

    class Callbacks {
        +ValidationCallback
        +EarlyStoppingCallback
        +PromptMonitorCallback
        +LRMonitorCallback
        +ModelLoadingCallback
        -track_metrics()
        -log_to_wandb()
    }

    class ModelConfig {
        +name: str
        +temperature: float
        +max_tokens: int
        +api_key: str
        +validate()
        +to_dict()
    }

    QwenTrainer --> PromptCreator: uses
    QwenTrainer --> Callbacks: manages
    TeacherSynthesisFramework --> PromptCreator: uses
    TeacherSynthesisFramework --> ModelConfig: configures
    Callbacks --> QwenTrainer: monitors

    note for QwenTrainer "Main training orchestrator"
    note for PromptCreator "Handles prompt generation"
    note for TeacherSynthesisFramework "Manages synthetic data"
    note for Callbacks "Monitors training process"

Monitoring System

graph LR
    classDef metrics fill:#e1f5fe,stroke:#0277bd,stroke-width:2px
    classDef callbacks fill:#fff3e0,stroke:#ef6c00,stroke-width:2px
    classDef viz fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px

    subgraph Metrics["Metrics Collection"]
        M1["Training Loss<br/>(Per Step)"]:::metrics --> M4["WandB<br/>Logger"]:::metrics
        M2["Validation<br/>Metrics"]:::metrics --> M4
        M3["Prompt<br/>Quality"]:::metrics --> M4
    end

    subgraph Callbacks["Training Control"]
        C1["Early<br/>Stopping"]:::callbacks --> C4["Training<br/>Control"]:::callbacks
        C2["Learning Rate<br/>Monitor"]:::callbacks --> C4
        C3["Prompt<br/>Monitor"]:::callbacks --> C4
    end

    subgraph Visualization["Analytics Dashboard"]
        V1["Loss<br/>Curves"]:::viz --> V4["WandB<br/>Dashboard"]:::viz
        V2["Prompt<br/>Statistics"]:::viz --> V4
        V3["Model<br/>Performance"]:::viz --> V4
    end

    Metrics --> Callbacks
    Callbacks --> Visualization

    style Metrics fill:#f8f9fa,stroke:#343a40,stroke-width:2px
    style Callbacks fill:#f8f9fa,stroke:#343a40,stroke-width:2px
    style Visualization fill:#f8f9fa,stroke:#343a40,stroke-width:2px