
Clinical Chain-of-Thought Medical Assistant


A production-grade fine-tuning pipeline for medical reasoning using SFT, QAT, and LLM-as-a-Judge evaluation

Python 3.10+ · PyTorch 2.8 · ROCm 6.4


Demo

Clinical CoT Medical Assistant Demo

Overview

This project fine-tunes a 20B-parameter language model on curated medical reasoning datasets to create a clinical assistant capable of chain-of-thought (CoT) reasoning. The pipeline includes:

  1. Dataset Preparation — Merging and formatting multiple medical QA datasets
  2. Supervised Fine-Tuning (SFT) — LoRA-based fine-tuning with Unsloth
  3. Quantization-Aware Training (QAT) — INT4 weight quantization for efficient deployment
  4. GGUF Export — Quantized model export for inference engines (llama.cpp, Ollama)
  5. LLM-as-a-Judge Evaluation — Blind comparative evaluation against baseline and proprietary models

Architecture

┌─────────────────────────────────────────────────────────────────────────────┐
│                           Training Pipeline                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌──────────────┐    ┌──────────────┐    ┌──────────────┐    ┌───────────┐  │
│  │   Medical    │    │     SFT      │    │     QAT      │    │   GGUF    │  │
│  │   Datasets   │───▶│  Fine-Tune   │───▶│  Refinement  │───▶│  Export   │  │
│  │  (3 sources) │    │   (LoRA)     │    │   (INT4)     │    │           │  │
│  └──────────────┘    └──────────────┘    └──────────────┘    └───────────┘  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
                                    │
                                    ▼
┌─────────────────────────────────────────────────────────────────────────────┐
│                         Evaluation Pipeline                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌──────────────┐    ┌──────────────┐    ┌──────────────┐    ┌───────────┐  │
│  │  Test Set    │    │   Generate   │    │  LLM Judge   │    │   Win     │  │
│  │              │───▶│  Responses   │───▶│  (GPT-5.2)   │───▶│   Rates   │  │
│  │              │    │  (4 models)  │    │  Blind Eval  │    │           │  │
│  └──────────────┘    └──────────────┘    └──────────────┘    └───────────┘  │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘

Datasets

| Dataset | Source | Description |
|---|---|---|
| medical-o1-reasoning-SFT | FreedomIntelligence | Complex chain-of-thought medical reasoning |
| Medical-R1-Distill-Data | FreedomIntelligence | Distilled medical reasoning data |
| MedReason | UCSC-VLAA | Medical QA with structured reasoning |

All datasets are normalized to a unified schema:

instruction → reasoning → output

The reasoning component is wrapped in <think>...</think> tags following the model's native reasoning format.
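
As a concrete illustration, the normalization can be sketched as below. The raw field names (Question, Complex_CoT, Response) follow the medical-o1-reasoning-SFT layout and are assumptions here; the actual mapping lives in data/prepare_gpt_oss_sft_dataset.py.

# Sketch of the normalization step. Raw field names are assumed from the
# medical-o1-reasoning-SFT layout; the real mapping is in
# data/prepare_gpt_oss_sft_dataset.py.
def normalize_record(raw: dict) -> dict:
    """Map a raw QA record onto the unified instruction/reasoning/output schema."""
    return {
        "instruction": raw["Question"],
        "reasoning": raw["Complex_CoT"],
        "output": raw["Response"],
    }

def to_training_text(record: dict) -> str:
    """Wrap the reasoning in <think> tags, matching the model's native format."""
    return (
        f"{record['instruction']}\n"
        f"<think>{record['reasoning']}</think>\n"
        f"{record['output']}"
    )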

Project Structure

medical_assistant/
├── configs/
│   ├── gpt_oss_20b_sft.yaml      # SFT configuration
│   └── gpt_oss_20b_qat.yaml      # QAT configuration
├── data/
│   └── prepare_gpt_oss_sft_dataset.py
├── train/
│   └── sft_gpt_oss_20b_unsloth.py
├── scripts/
│   └── quantize.py
├── eval/
│   └── llm_judge_eval.py
├── utils/
│   └── seed.py
├── outputs/                       # Model checkpoints & artifacts
├── .env                          # API keys (gitignored)
└── README.md

Requirements

Hardware

  • GPU: AMD MI300X (ROCm 6.4.0) on a DigitalOcean instance
  • RAM: 128GB+ recommended
  • Storage: 200GB+ for model checkpoints

Software

  • Ubuntu 24.04 LTS
  • Python 3.10+
  • ROCm 6.4.0

Installation

1. System Dependencies

sudo apt update && sudo apt upgrade -y
sudo apt install -y build-essential git wget curl vim tmux htop nvtop
sudo apt-get install -y python3-venv

2. Python Environment

python3 -m venv .venv
source .venv/bin/activate

3. PyTorch & Unsloth (AMD ROCm)

pip install --upgrade torch==2.8.0 pytorch-triton-rocm torchvision torchaudio torchao==0.13.0 xformers \
  --index-url https://download.pytorch.org/whl/rocm6.4

pip install --no-deps unsloth unsloth-zoo
pip install --no-deps git+https://github.com/unslothai/unsloth-zoo.git
pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth"

4. Additional Dependencies

pip install -U transformers accelerate datasets trl peft wandb evaluate omegaconf python-dotenv rich safetensors sentencepiece

Note: If dependency conflicts occur:

pip install trl==0.24.0 datasets==4.3.11 msgspec cut_cross_entropy

5. Environment Variables

Create a .env file in the project root:

WANDB_PROJECT=clinical-cot
WANDB_ENTITY=your-wandb-username
WANDB_API_KEY=your-wandb-api-key
HF_TOKEN=your-huggingface-token
OPENAI_API_KEY=your-openai-api-key
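
The scripts can then load these keys at startup; a minimal sketch using python-dotenv (installed in step 4):

# Minimal sketch: load .env so the W&B, Hugging Face, and OpenAI clients
# can read their keys from the environment.
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
assert os.getenv("WANDB_API_KEY"), "WANDB_API_KEY missing from .env"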

6. Build llama.cpp (Required for GGUF Export)

The GGUF quantization step requires llama.cpp to be built locally. Unsloth uses it internally for model conversion.

# Install CMake if not present
sudo apt-get install -y cmake

# Clone llama.cpp
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp

# Build with CMake (LLAMA_CURL=OFF avoids the libcurl dependency)
mkdir -p build
cd build
cmake .. -DLLAMA_CURL=OFF
cmake --build . --config Release -j$(nproc)
cd ..  # return to the llama.cpp root; the relative paths below assume it

# Copy the quantizer binary to where Unsloth expects it
cp ~/workspace/llama.cpp/build/bin/llama-quantize ~/workspace/llama.cpp/

# Convert the Hugging Face model to unquantized GGUF
python3 convert_hf_to_gguf.py ../outputs/gpt_oss_20b/merged/gpt-oss-20b_clinical-cot_qat_refined --outfile ../outputs/gpt_oss_20b/gguf/model-f16.gguf --outtype f16

# Quantize to the target formats:
build/bin/llama-quantize ../outputs/gpt_oss_20b/gguf/model-f16.gguf ../outputs/gpt_oss_20b/gguf/model-q4_k_m.gguf q4_k_m
build/bin/llama-quantize ../outputs/gpt_oss_20b/gguf/model-f16.gguf ../outputs/gpt_oss_20b/gguf/model-q5_k_m.gguf q5_k_m
build/bin/llama-quantize ../outputs/gpt_oss_20b/gguf/model-f16.gguf ../outputs/gpt_oss_20b/gguf/model-q8_0.gguf q8_0

Note: The llama.cpp project no longer supports the legacy make build system. You must use CMake.

Troubleshooting: If you encounter "No working quantizer found" errors, ensure llama-quantize exists in the llama.cpp/ directory (not in llama.cpp/build/bin/).

Quantization size comparison

| Method | Original Size (MiB) | Quantized Size (MiB) | Quantized / Original | Space Saved (MiB) |
|---|---|---|---|---|
| q4_k_m | 39909.25 | 15060.55 | 37.7% | 24848.70 |
| q5_k_m | 39909.25 | 16098.07 | 40.4% | 23811.18 |
| q8_0 | 39909.25 | 21218.21 | 53.2% | 18691.04 |

The ratio column is the quantized size as a percentage of the original f16 size, e.g. 15060.55 / 39909.25 ≈ 37.7% for q4_k_m.

Usage

Step 1: Prepare Dataset

python data/prepare_gpt_oss_sft_dataset.py

This script:

  • Downloads and merges the three medical datasets
  • Normalizes schema to instruction, reasoning, output
  • Applies chat template formatting with <think> tags
  • Creates train/eval splits (95%/5%; see the sketch after this list)
  • Saves processed dataset to data/gpt_oss/sft_dataset/
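
The split-and-save step can be expressed with the Hugging Face datasets API roughly as follows (a sketch; the script's exact arguments may differ):

# Sketch of the 95%/5% split using the Hugging Face datasets API.
from datasets import Dataset

def split_and_save(ds: Dataset, out_dir: str = "data/gpt_oss/sft_dataset"):
    splits = ds.train_test_split(test_size=0.05, seed=42)  # seed: 42 in the configs
    splits.save_to_disk(out_dir)  # writes train/ and test/ subdirectories
    return splits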

Step 2: Supervised Fine-Tuning (SFT)

python train/sft_gpt_oss_20b_unsloth.py --config configs/gpt_oss_20b_sft.yaml

Key Configuration (configs/gpt_oss_20b_sft.yaml):

| Parameter | Value | Description |
|---|---|---|
| model.base_model_name | unsloth/gpt-oss-20b-BF16 | Base model |
| lora.r | 32 | LoRA rank |
| lora.lora_alpha | 64 | LoRA scaling factor |
| train.learning_rate | 1e-4 | Peak learning rate |
| train.num_train_epochs | 2 | Training epochs |
| train.packing | true | Sequence packing for efficiency |

Outputs:

  • LoRA adapter: outputs/gpt_oss_20b/runs/<run_name>/lora_adapter/
  • Merged model: outputs/gpt_oss_20b/merged/<run_name>/
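
In outline, the training script maps this configuration onto the Unsloth and TRL APIs roughly as follows. This is a condensed sketch, not the full train/sft_gpt_oss_20b_unsloth.py; the real script reads every value from the YAML config.

# Condensed sketch of the SFT setup; values are hard-coded here but come
# from configs/gpt_oss_20b_sft.yaml in the real script.
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer
from datasets import load_from_disk

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gpt-oss-20b-BF16",
    max_seq_length=4096,
)
model = FastLanguageModel.get_peft_model(
    model,
    r=32,              # lora.r
    lora_alpha=64,     # lora.lora_alpha
    lora_dropout=0.05,
)
train_ds = load_from_disk("data/gpt_oss/sft_dataset")["train"]
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_ds,
    args=SFTConfig(
        learning_rate=1e-4,
        num_train_epochs=2,
        packing=True,        # sequence packing for efficiency
        report_to="wandb",
    ),
)
trainer.train()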

Step 3: Quantization-Aware Training (QAT)

python train/sft_gpt_oss_20b_unsloth.py --config configs/gpt_oss_20b_qat.yaml

Key Configuration (configs/gpt_oss_20b_qat.yaml):

| Parameter | Value | Description |
|---|---|---|
| model.base_model_name | outputs/.../sft_base | SFT checkpoint |
| qat.enabled | true | Enable QAT mode |
| qat.qat_scheme | int4_weight_only | Quantization scheme |
| qat.learning_rate | 5e-5 | Lower LR for refinement |
| train.num_train_epochs | 1 | QAT converges quickly |
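
Conceptually, int4 weight-only QAT inserts "fake" quantization into the forward pass so the weights adapt to the rounding they will undergo at export time. A toy illustration of that idea in plain PyTorch (the pipeline itself uses torchao's int4_weight_only scheme, not this code):

# Toy straight-through-estimator (STE) fake quantization, for intuition only.
# The actual pipeline applies torchao's int4_weight_only QAT scheme.
import torch

def fake_quantize_int4(w: torch.Tensor) -> torch.Tensor:
    scale = w.abs().max() / 7.0  # symmetric scale onto the int4 range [-8, 7]
    w_q = torch.clamp(torch.round(w / scale), -8, 7) * scale
    # Forward pass sees the quantized weights; backward passes gradients
    # through unchanged, as if no rounding had happened.
    return w + (w_q - w).detach()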

Step 4: Export to GGUF

python scripts/quantize.py

Exports the QAT model to multiple GGUF quantization formats:

  • q4_k_m — 4-bit (recommended for deployment)
  • q5_k_m — 5-bit (balanced quality/size)
  • q8_0 — 8-bit (highest quality)

Output: outputs/gpt_oss_20b/gguf/
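
Under the hood this amounts to invoking the llama-quantize binary built earlier once per target format; a minimal sketch of that loop (scripts/quantize.py may differ in detail):

# Sketch of the GGUF quantization loop around the llama-quantize binary.
import subprocess
from pathlib import Path

GGUF_DIR = Path("outputs/gpt_oss_20b/gguf")
QUANTIZER = Path("llama.cpp/llama-quantize")  # copied here during the build step

for method in ("q4_k_m", "q5_k_m", "q8_0"):
    subprocess.run(
        [str(QUANTIZER),
         str(GGUF_DIR / "model-f16.gguf"),
         str(GGUF_DIR / f"model-{method}.gguf"),
         method],
        check=True,
    )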

Step 5: Evaluation (LLM-as-a-Judge)

python eval/llm_judge_eval.py

Evaluation Protocol:

  1. Sample 50 questions from the held-out eval set
  2. Generate responses from 4 models:
    • Fine-tuned BF16 model
    • QAT model
    • Base model (unsloth/gpt-oss-20b-BF16)
    • GPT-4.1 (proprietary baseline)
  3. Anonymize and shuffle responses (see the sketch after this list)
  4. GPT-5.2 ranks all 4 responses per question against the ground truth
  5. Aggregate win rates

Output: outputs/evaluation_results.json
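
The anonymize-and-rank step (items 3 and 4 above) might look like the following in outline; the actual prompt wording and response parsing in eval/llm_judge_eval.py are simplified away here.

# Sketch of the blind-ranking setup used by the LLM judge.
import random

def build_judge_prompt(question: str, ground_truth: str,
                       responses: dict[str, str]) -> tuple[str, dict[str, str]]:
    """Shuffle responses behind anonymous labels and build the judge prompt."""
    items = list(responses.items())   # [(model_name, response), ...]
    random.shuffle(items)             # hide which model produced which answer
    labels = {chr(ord("A") + i): name for i, (name, _) in enumerate(items)}
    body = "\n\n".join(
        f"Response {chr(ord('A') + i)}:\n{resp}"
        for i, (_, resp) in enumerate(items)
    )
    prompt = (
        f"Question: {question}\n\nGround truth: {ground_truth}\n\n{body}\n\n"
        "Rank the responses from best to worst against the ground truth."
    )
    return prompt, labels  # labels maps letters back to models for scoring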

Example Results (50-question LLM-as-a-Judge eval)

Win Rates (out of 50 prompts)

| Model | Wins | Win Rate |
|---|---|---|
| Base GPT-OSS (Unfinetuned) | 1 | 2.0% |
| GPT-4.1 (OpenAI API) | 0 | 0.0% |
| QAT Fine-tuned GPT-OSS | 14 | 28.0% |
| Fine-tuned GPT-OSS (BF16) | 35 | 70.0% |

Average Overall Scores (1–10, GPT-5.2 judge)

| Model | Score |
|---|---|
| Base GPT-OSS (Unfinetuned) | 7.12 |
| GPT-4.1 (OpenAI API) | 7.37 |
| QAT Fine-tuned GPT-OSS | 8.32 |
| Fine-tuned GPT-OSS (BF16) | 9.16 |

Latency (per question, approx.)

| Model | Time (s) |
|---|---|
| Base GPT-OSS (Unfinetuned) | 14.3 |
| Fine-tuned GPT-OSS (BF16) | 14.7 |
| QAT Fine-tuned GPT-OSS | 13.9 |
| GPT-4.1 (OpenAI API) | 5.8 |

These results are from a 50-question sample and a single GPT-5.2 judge; they should be treated as preliminary.

Configuration Reference

SFT Configuration Schema

run:
  seed: 42                    # Reproducibility seed
  output_dir: outputs/...     # Checkpoint directory
  run_name: ...               # W&B run name

data:
  input_disk_path: ...        # Pre-merged dataset (optional)
  output_disk_path: ...       # Processed dataset output
  eval_size: 0.05             # Eval split ratio
  max_seq_length: 4096        # Maximum sequence length
  reasoning_effort: medium    # Reasoning verbosity hint

model:
  base_model_name: ...        # HuggingFace model ID or local path
  load_in_4bit: false         # 4-bit loading (inference only)
  dtype: bf16                 # Model dtype

lora:
  r: 32                       # LoRA rank
  lora_alpha: 64              # LoRA alpha
  lora_dropout: 0.05          # Dropout rate
  target_modules: [...]       # Modules to adapt

qat:
  enabled: false              # Enable QAT mode
  qat_scheme: int4_weight_only
  learning_rate: 5.0e-5

train:
  per_device_train_batch_size: 8
  gradient_accumulation_steps: 8
  learning_rate: 1.0e-4
  num_train_epochs: 2
  warmup_ratio: 0.03
  lr_scheduler_type: cosine
  packing: true
  report_to: wandb
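
These YAML files are presumably loaded with omegaconf (listed in the dependencies), which makes the schema above directly addressable as attributes:

# Loading a config with omegaconf (an assumption based on the dependency list).
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/gpt_oss_20b_sft.yaml")
print(cfg.train.learning_rate)  # 0.0001
print(cfg.lora.r)               # 32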

Experiment Tracking

All training runs are logged to Weights & Biases:

  • Loss curves (train/eval)
  • Learning rate schedule
  • Gradient norms
  • Hyperparameters
  • System metrics (GPU utilization, memory)

Model Outputs

| Artifact | Path | Description |
|---|---|---|
| SFT Merged Model | outputs/gpt_oss_20b/merged/..._sft_base/ | Full merged BF16 model |
| QAT Merged Model | outputs/gpt_oss_20b/merged/..._qat/ | QAT-refined model |
| GGUF Quantized | outputs/gpt_oss_20b/gguf/ | Deployment-ready quantized models |
| Eval Results | outputs/evaluation_results.json | LLM judge rankings |

Deployment

With Ollama

# ollama create takes a Modelfile, not the .gguf path directly
echo "FROM ./outputs/gpt_oss_20b/gguf/model-q4_k_m.gguf" > Modelfile

ollama create clinical-cot -f Modelfile
ollama run clinical-cot

With llama.cpp

llama.cpp/build/bin/llama-cli -m outputs/gpt_oss_20b/gguf/model-q4_k_m.gguf \
  -p "What are the differential diagnoses for chest pain?" \
  -n 512

The CMake build produces llama-cli under build/bin/; the legacy ./main binary is no longer built.

Inference Server & Web UI

This project includes a production-ready inference server and a Streamlit-based web interface for interacting with the fine-tuned model.

Clinical CoT Medical Assistant Demo

Architecture

┌─────────────────────────────────────────────────────────────────────────────┐
│                         Inference Stack                                     │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  ┌──────────────┐    ┌──────────────┐    ┌──────────────┐                   │
│  │   Streamlit  │    │    FastAPI   │    │  llama.cpp   │                   │
│  │    Web UI    │───▶│    Server    │───▶│   Backend    │                   │
│  │  (app.py)    │    │ (server.py)  │    │  (GGUF Model)│                   │
│  └──────────────┘    └──────────────┘    └──────────────┘                   │
│       :8501              :8000                                              │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
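
For orientation, server.py's generation endpoint could look roughly like this, assuming the llama-cpp-python bindings sit in front of the GGUF model; the actual server implementation may differ.

# Minimal FastAPI sketch over the GGUF model via llama-cpp-python
# (an assumed backend; server.py may integrate llama.cpp differently).
from fastapi import FastAPI
from llama_cpp import Llama
from pydantic import BaseModel

app = FastAPI()
llm = Llama(model_path="outputs/gpt_oss_20b/gguf/model-q4_k_m.gguf")

class Query(BaseModel):
    prompt: str
    max_tokens: int = 512

@app.post("/generate")
def generate(q: Query) -> dict:
    out = llm(q.prompt, max_tokens=q.max_tokens)
    return {"response": out["choices"][0]["text"]}

Run with uvicorn server:app --port 8000, matching the :8000 port in the diagram.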

Citation

If you use this pipeline, please cite the source datasets:

@misc{medical-o1-reasoning,
  title={Medical-O1-Reasoning-SFT},
  author={FreedomIntelligence},
  year={2024},
  publisher={HuggingFace}
}

@misc{medical-r1-distill,
  title={Medical-R1-Distill-Data},
  author={FreedomIntelligence},
  year={2024},
  publisher={HuggingFace}
}

@misc{medreason,
  title={MedReason},
  author={UCSC-VLAA},
  year={2024},
  publisher={HuggingFace}
}

License

This project is for research and educational purposes. Please ensure compliance with:

  • Base model license terms
  • Dataset licenses
  • OpenAI API terms of service (for evaluation)

Built with Unsloth 🦥 + Hugging Face 🤗
