TPT Project

Welcome to TPT – Think • Prune • Train! A framework for teaching large language models to solve math problems by learning from (and improving on) their own reasoning traces.


🚀 What is TPT?

TPT is a three‑step, iterative workflow:

  1. Think – The model generates multiple, detailed solution traces.
  2. Prune – We automatically keep only the traces that reach the correct answer.
  3. Train – The model fine‑tunes on this high‑quality synthetic data to boost its skills.

Loop the cycle → watch the model level up. ✨
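
As a mental model, one TPT cycle can be written as a few lines of Python. This is a hedged sketch, not code from this repo: generate_traces, keep_correct, and finetune are hypothetical stand-ins for gen_synth.py, evmath.py/make_json.py, and sft_math.py respectively.

# Hypothetical pseudocode for the TPT loop (helper names are stand-ins,
# not functions from this repository).
model = "google/gemma-2-2b-it"
for cycle in range(num_cycles):
    traces = generate_traces(model, questions, num_samples=5)  # Think
    kept   = keep_correct(traces, gold_answers)                # Prune
    model  = finetune(model, kept)                             # Train -> new checkpoint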


🛠️ Workflow & Commands

Below is the minimal command‑line recipe for each stage. Adjust paths/flags to taste.

1. Think – Generate Synthetic Traces (💡 gen_synth.py)

Produce N solution attempts per question.

python gen_synth.py \
  --model_name    google/gemma-2-2b-it \
  --max_model_len 1500 \
  --num_samples   5 \
  --math          data/gsm8ktrain.json \
  --output_dir    samples/math_train/ft

Outputs go to samples/math_train/ft/ as e0.json … e5.json.
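
At its core, this stage just samples several completions per question. Below is a minimal sketch of that idea using vLLM's public API; it is not the actual contents of gen_synth.py, and the prompt field name and sampling settings are assumptions.

import json
from vllm import LLM, SamplingParams

# Load questions (the "question" field name is an assumption about the JSON schema).
questions = [ex["question"] for ex in json.load(open("data/gsm8ktrain.json"))]

llm = LLM(model="google/gemma-2-2b-it", max_model_len=1500)
params = SamplingParams(n=5, temperature=0.8, max_tokens=1024)  # 5 traces per question

# One RequestOutput per question, each holding n=5 sampled completions.
outputs = llm.generate(questions, params)
traces = [[c.text for c in out.outputs] for out in outputs]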

2. Prune & Split (✂️ evmath.py → 📄 make_json.py)

  1. Score correctness with evmath.py (example):
    python evmath.py \
      --samples_dir samples/math_train/ft \
      --answer_path data/gsm8ktrain.json \
      --num_samples 5
    This writes correct_answers.json and pass_at_k_results.json; a sketch of the pruning logic follows this list.
  2. Create new train/eval JSON:
    python make_json.py \
      --input        samples/math_train/correct_answers.json \
      --train_output data/next/train2k.json \
      --eval_output  data/next/evnext.json \
      --train_size   2000
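
Conceptually, pruning extracts each trace's final answer, compares it with the gold answer, and keeps only the traces that match. The following is a minimal sketch, not evmath.py itself; the field names and the "last number in the text is the answer" heuristic are assumptions (GSM8K gold answers end in "#### <number>", so the same extractor works for both).

import json
import re

def final_number(text: str):
    """Return the last number appearing in a solution trace, or None."""
    nums = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return nums[-1] if nums else None

correct = []
for question, gold, samples in zip(questions, gold_answers, traces):
    kept = [s for s in samples if final_number(s) == final_number(gold)]
    if kept:
        correct.append({"question": question, "answers": kept})

# Simplified pass@k for k = num_samples: did any of the k traces hit the gold answer?
pass_at_k = len(correct) / len(questions)

with open("correct_answers.json", "w") as f:
    json.dump(correct, f)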

Use the new train/eval files in the Train stage below; after training, start the next TPT cycle from Think.

3. Train (🚂 sft_math.py)

Fine‑tune the same base model that generated the traces on the newly created dataset.

python sft_math.py \
  --model_name_or_path google/gemma-2-2b-it \
  --train_data_path data/next/train2k.json \
  --eval_data_path  data/next/evnext.json \
  --learning_rate   1e-6 \
  --output_dir      gemma-tpt

This produces a checkpoint under gemma-tpt/ and logs to W&B (set your project and name inside the script).
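
sft_math.py wraps standard supervised fine-tuning on the pruned traces. The sketch below shows the equivalent using the Hugging Face transformers Trainer; it is an assumption about the script's internals, and the "question"/"answer" field names are hypothetical.

import json
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          Trainer, TrainingArguments)

model_name = "google/gemma-2-2b-it"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def encode(example):
    # Concatenate question and pruned trace into one causal-LM sequence (assumed format).
    text = example["question"] + "\n" + example["answer"] + tok.eos_token
    ids = tok(text, truncation=True, padding="max_length", max_length=1500)
    # Standard next-token objective; ignore padding positions in the loss.
    ids["labels"] = [t if t != tok.pad_token_id else -100 for t in ids["input_ids"]]
    return ids

train_data = [encode(ex) for ex in json.load(open("data/next/train2k.json"))]

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="gemma-tpt",
        learning_rate=1e-6,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        report_to="wandb",  # matches the W&B logging mentioned above
    ),
    train_dataset=train_data,
)
trainer.train()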


📂 Repository Structure

TPT/
├── data/             # Datasets (initial + generated)
├── gemma-tpt/        # Model checkpoints & artifacts
├── samples/          # Synthetic traces
├── wandb/            # Experiment tracking
├── evmath.py         # Scoring / pruning script
├── gen_eval.py       # Generates evaluation questions
├── gen_synth.py      # Synthetic generation script (Think)
├── make_json.py      # Builds new train/eval JSON (Prune)
├── sft_math.py       # Supervised fine‑tune (Train)
├── README.md         # You are here
└── requirements.txt  # Python deps

⚙️ Setup Guide

Prerequisites

  • Python 3.10
  • pip

Installation

git clone <repository-url>
cd <repository-folder>

# Create & activate venv
python3.10 -m venv tpt_env
source tpt_env/bin/activate   # Windows: tpt_env\Scripts\activate

# Install deps
python3.10 -m pip install -r requirements.txt

# Extra: flashinfer wheel (for vLLM‑FlashAttention)
python3.10 -m pip install \
  https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.3-cp310-cp310-linux_x86_64.whl

Activate later with:

source tpt_env/bin/activate   

Ready? Time to Think → Prune → Train and watch your model improve!
