thedatamonk/tune-llm

My experiments with LLM finetuning
Cool, let’s turn this into a concrete game plan for:

Base model: HuggingFaceTB/SmolLM-135M-Instruct
Dataset: keivalya/MedQuad-MedicalQnADataset (or a trimmed variant)

What you’ll do at each step, with small illustrative code sketches along the way.


Phase 0 – Define constraints

  • Goal: make a tiny medical Q&A assistant that:

    • Answers medical questions more accurately.
    • Still writes normal English.
    • Is safe (no wild dosing advice etc.).
  • Hardware assumption: a single GPU with ~8–16 GB VRAM or a beefy CPU → use LoRA instead of a full fine-tune.


Phase 1 – Understand model & dataset

  1. Model card (SmolLM-135M-Instruct):

    • Confirm: license (Apache 2.0), context length, chat format / expected prompts. (Hugging Face)
    • Note its “instruction” style (it’s already fine-tuned for following instructions).
  2. Dataset card (keivalya/MedQuad-MedicalQnADataset or mlabonne/MedQuad-MedicalQnADataset):

    • Inspect columns: usually something like question, answer (names may vary). (Hugging Face)
    • Check size (~10k–100k samples).
    • Optionally consider a length-trimmed version like Laurent1/MedQuad-MedicalQnADataset_128tokens_max if you want shorter, cheaper training. (Hugging Face)

Outcome: you know exactly which columns represent the user’s question and the expected answer.
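A quick sketch of that inspection, assuming the standard datasets library and the dataset’s default train split:

```python
from datasets import load_dataset

# Load the raw Q&A pairs and peek at the schema.
ds = load_dataset("keivalya/MedQuad-MedicalQnADataset", split="train")
print(ds.column_names)   # confirm which columns hold the question and the answer
print(ds[0])             # eyeball one example
```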


Phase 2 – Decide your training format (how examples look)

You need to map each row into a single prompt + target answer.

For each MedQuad example:

  • User side (input to model): Something like:

    • System-ish instruction: “You are a safe, helpful medical assistant…”
    • User question: the medical question from dataset.
  • Assistant side (label to learn):

    • The ground truth medical answer from dataset.

You’ll:

  • Adopt one consistent chat template for all examples.
  • Ensure safety tone in system prompt (e.g., “do not give definitive treatment decisions; encourage consulting a doctor”).
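A minimal formatting sketch, assuming the dataset exposes Question and Answer columns (verify the real names in Phase 1) and that the tokenizer’s chat template accepts a system message:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")

SYSTEM = (
    "You are a careful, helpful medical assistant. Do not give definitive "
    "treatment decisions; encourage consulting a doctor."
)

def to_chat_text(row):
    # One consistent chat template for every example.
    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": row["Question"]},      # column name assumed
        {"role": "assistant", "content": row["Answer"]},   # column name assumed
    ]
    return {"text": tok.apply_chat_template(messages, tokenize=False)}

ds = ds.map(to_chat_text)
```

The resulting text column is what the trainer will consume later.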

Phase 3 – Clean & split the data

Actions:

  1. Filter noisy stuff:

    • Drop examples with missing/empty question or answer.
    • Optionally drop extremely long ones (e.g., > N characters/tokens).
  2. Train / validation split:

    • ~80–90% → train.
    • ~10–20% → validation (never used for training).
  3. (Optional but useful) Create a small “eval prompts” file by hand:

    • 20–50 medical questions that you care about (lab report interpretation, symptoms, edge cases).
    • You’ll use these later for before/after comparison.
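A rough sketch of the filter-and-split step (the 4000-character cutoff and the 90/10 split are arbitrary placeholders):

```python
# Drop rows with missing fields and (optionally) very long answers,
# then carve out a held-out validation split.
ds = ds.filter(
    lambda r: r["Question"] and r["Answer"] and len(r["Answer"]) < 4000
)
splits = ds.train_test_split(test_size=0.1, seed=42)
train_ds, val_ds = splits["train"], splits["test"]
```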

Phase 4 – Choose fine-tuning strategy (LoRA / PEFT)

You’ll go with:

  • PEFT / LoRA on top of SmolLM-135M-Instruct:

    • Base model weights are frozen.
    • You train tiny adapter layers → light on memory, fast, reversible.

Decisions you’ll make (conceptually, no numbers yet is fine):

  • Which layers to attach LoRA to (usually attention / feedforward modules).
  • How many adapter dimensions (rank r).
  • Whether to train embeddings / layernorms (likely no for a start).
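A starting-point LoRA configuration sketch with peft. The rank, alpha, and target module names are illustrative; SmolLM uses Llama-style attention projections, but confirm the module names against the loaded model:

```python
from peft import LoraConfig

lora_cfg = LoraConfig(
    r=16,                      # adapter rank
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections
    task_type="CAUSAL_LM",
)
```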

Phase 5 – Set basic training knobs

You’ll configure (at a high level):

  • Max sequence length:

    • Big enough for question + answer; with MedQuad often 256–1024 tokens is okay.
  • Batch size (effective):

    • Chosen based on VRAM; you can simulate big batch via gradient accumulation.
  • Learning rate:

    • Small enough to not destroy the base model’s language knowledge.
  • Epochs:

    • Start with 1–3 epochs over MedQuad.
  • Logging / checkpoints:

    • Save model every X steps.
    • Track training loss and validation loss.

Core idea: you’ll watch validation loss to decide if you’re overfitting (validation loss starts rising while training loss keeps falling).
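One possible set of knobs, sketched with trl’s SFTConfig. All values are placeholders to tune against your VRAM, and argument names can shift slightly across trl/transformers versions:

```python
from trl import SFTConfig

args = SFTConfig(
    output_dir="smollm-medquad-lora",
    dataset_text_field="text",        # the column built in Phase 2
    max_seq_length=512,               # room for question + answer on MedQuad
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,    # effective batch size of 16
    learning_rate=2e-4,
    num_train_epochs=2,
    logging_steps=20,
    eval_strategy="steps",            # "evaluation_strategy" on older versions
    eval_steps=200,
    save_steps=200,
)
```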


Phase 6 – Run a tiny sanity-check training

Before full training, you’ll:

  1. Train on a very small subset (e.g., 100–500 examples) for a few hundred steps.

  2. Check:

    • Does training loss go down?
    • Can the model now memorize a few training examples (it should reproduce them roughly)?
    • Does it still respond sensibly to a non-medical question?

If something is broken (wrong formatting, wrong labels, etc.), you catch it here instead of after a 3-hour run.
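A sanity-run sketch with trl’s SFTTrainer, reusing the config and LoRA setup above (exact constructor arguments depend on your trl version):

```python
from trl import SFTTrainer

# Train briefly on a tiny subset to confirm the pipeline is wired correctly.
small_train = train_ds.shuffle(seed=42).select(range(500))

trainer = SFTTrainer(
    model="HuggingFaceTB/SmolLM-135M-Instruct",
    train_dataset=small_train,
    eval_dataset=val_ds,
    peft_config=lora_cfg,
    args=args,
)
trainer.train()
```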


Phase 7 – Full training run

Once the sanity check looks good:

  1. Train on full MedQuad train split with chosen hyperparams.

  2. Monitor:

    • Training loss.
    • Validation loss.
  3. Possibly early stop if validation loss stops improving.

Result: a set of LoRA adapter weights that specialize SmolLM on MedQuad.
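The full run is the same trainer pointed at the full train split. The early stopping shown here is a sketch using transformers’ EarlyStoppingCallback, which also needs load_best_model_at_end=True and metric_for_best_model="eval_loss" added to the config above:

```python
from transformers import EarlyStoppingCallback

trainer = SFTTrainer(
    model="HuggingFaceTB/SmolLM-135M-Instruct",
    train_dataset=train_ds,           # full MedQuad train split
    eval_dataset=val_ds,
    peft_config=lora_cfg,
    args=args,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()
```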


Phase 8 – Evaluate: “Did it actually become more medical (and safe)?”

You’ll evaluate on three kinds of prompts.

  1. In-domain medical questions:

    • Use held-out MedQuad validation questions.
    • Compare answers vs ground truth.
    • Optionally score correctness manually for a sample.
  2. Out-of-domain / general questions:

    • Ask about random non-medical topics.
    • Make sure model hasn’t forgotten basic language ability.
  3. Safety probes:

    • Risky questions: dosage, self-surgery, ignoring doctors, etc.

    • Check that the model:

      • Avoids giving specific dangerous instructions.
      • Encourages seeing a professional.
      • Uses cautious language.

If safety is weak:

  • Add extra synthetic examples with “refusal + safe guidance” responses.
  • Do another short SFT pass mainly on safety data.
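For the before/after comparison, a small generation helper is enough. This sketch loads the base model plus the saved adapter (the adapter directory and decoding settings are placeholders; in the same session you can reuse trainer.model instead):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "HuggingFaceTB/SmolLM-135M-Instruct"
tok = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id)
model = PeftModel.from_pretrained(model, "smollm-medquad-lora")  # adapter dir saved in Phase 9
model.eval()

def answer(question: str) -> str:
    messages = [
        {"role": "system", "content": SYSTEM},
        {"role": "user", "content": question},
    ]
    input_ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    with torch.no_grad():
        out = model.generate(input_ids, max_new_tokens=256, do_sample=False)
    return tok.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True)
```

Run the same eval prompts through the plain base model (skip the PeftModel line) to get the “before” answers.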

Phase 9 – Save, package, and (optionally) upload

You’ll:

  1. Save:

    • The base model reference (SmolLM-135M-Instruct).
    • The LoRA adapter weights.
    • The training config (hyperparams, dataset used, commit hash).
  2. For reuse:

    • Load base + LoRA for inference, or
    • (Optionally) merge LoRA into the base model and save a single merged checkpoint.
  3. (Optional but nice) Push to Hugging Face:

    • Create a repo like yourname/SmolLM-135M-MedQuad-LoRA.

    • Document:

      • That it’s not a doctor.
      • That it was trained on MedQuad.
      • Intended use: research / experimentation, not clinical decisions.
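A packaging sketch: save the adapter, optionally merge it into the base weights, and push to the Hub (the repo name is just an example):

```python
# Adapter-only checkpoint (small): the base model is referenced, not copied.
trainer.model.save_pretrained("smollm-medquad-lora")
tok.save_pretrained("smollm-medquad-lora")

# Optional: merge LoRA into the base weights for a single standalone checkpoint.
merged = trainer.model.merge_and_unload()
merged.save_pretrained("smollm-medquad-merged")

# Optional: publish (requires a Hugging Face login).
# merged.push_to_hub("yourname/SmolLM-135M-MedQuad-LoRA")
```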

Phase 10 – Inference & iteration loop

Finally:

  • Wire up a simple chat endpoint (FastAPI/Flask) that:

    • Wraps prompts with your chat template.
    • Applies system prompt (safe medical assistant).
    • Calls the fine-tuned model and streams responses.
  • Use it interactively:

    • Note failure modes (hallucinations, unsafe suggestions).
    • Collect “bad” outputs → turn them into new training examples.
    • Do small, focused SFT passes to gradually improve behavior.
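A bare-bones FastAPI sketch wrapping the answer helper from Phase 8 (no streaming here; add that once the basics work):

```python
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class ChatRequest(BaseModel):
    question: str

@app.post("/chat")
def chat(req: ChatRequest):
    # `answer` applies the chat template + system prompt and runs generation.
    return {"answer": answer(req.question)}
```

Serve it with uvicorn and log bad outputs as candidates for the next SFT pass.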

If you want, next step I can:

  • Turn this plan into exact libraries + components (e.g., “use transformers + trl + peft with SFTTrainer”), still concept-level, or
  • Jump straight to actual training code for this SmolLM + MedQuad combo.
