β€οΈβ€πŸ©Ή HEAL: Harmonizing Efficient Alignment with RLAIF and RLHF for Health AI

Project Overview

HEAL is an in-progress project that explores and implements reinforcement learning techniques for aligning Large Language Models (LLMs) in healthcare applications: Reinforcement Learning from Human Feedback (RLHF) and Reinforcement Learning from AI Feedback (RLAIF).

This repository includes tools for:

  • Data preprocessing for healthcare datasets like MIMIC-IV and PubMedQA.
  • Supervised fine-tuning (SFT) of policy models.
  • Reward model training for RLHF and RLAIF.
  • Reinforcement Learning using PPO.
  • Evaluation of model robustness, including adversarial testing (e.g., jailbreak prompts).

Repository Structure

HEAL/
├── data/                     # Data directory
│   ├── processed/            # Preprocessed datasets
│   └── raw/                  # Raw datasets
├── experiments/              # Experiment configurations, logs, and results
│   ├── configs/              # Training configurations
│   ├── logs/                 # Training logs
│   └── results/              # Evaluation results
├── models/                   # Trained models
│   ├── policy_model/         # Fine-tuned policy models
│   └── reward_model/         # Trained reward models
├── src/                      # Source code
│   ├── data_preprocessing/   # Data preprocessing scripts
│   ├── evaluation/           # Evaluation scripts
│   ├── training/             # Training scripts (SFT, PPO, reward model training)
│   └── utils/                # Helper functions
├── environment.yml           # Conda environment setup
├── requirements.txt          # List of dependencies
└── README.md                 # Overview of the project

Installation

1. Clone the Repository

git clone https://github.com/seonokkim/HEAL.git
cd HEAL

2. Install Dependencies

Using pip:

pip install -r requirements.txt

Using Conda (optional):

conda env create -f environment.yml
conda activate heal

Usage

1. Data Preprocessing

Prepare datasets like MIMIC-IV (summarization) and PubMedQA (question-answering).

# Preprocess MIMIC-IV data
python src/data_preprocessing/preprocess_data.py --dataset mimic \
    --input_path data/raw/mimic_iv_notes.csv \
    --output_path data/processed/mimic_iv_summaries.csv

# Preprocess PubMedQA data
python src/data_preprocessing/preprocess_data.py --dataset pubmedqa \
    --input_path data/raw/pubmedqa.json \
    --output_path data/processed/pubmedqa.csv
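
The conversion itself lives in src/data_preprocessing/preprocess_data.py. As an illustration only, a minimal sketch of the PubMedQA JSON-to-CSV step might look like the following; the QUESTION, CONTEXTS, and LONG_ANSWER fields follow the published PubMedQA format, while the output column names are assumptions:

import json

import pandas as pd

# Load the raw PubMedQA file: a JSON object keyed by PubMed ID.
with open("data/raw/pubmedqa.json") as f:
    raw = json.load(f)

# Flatten each record into a (question, context, answer) row.
rows = [
    {
        "id": pmid,
        "question": item["QUESTION"],
        "context": " ".join(item["CONTEXTS"]),
        "answer": item["LONG_ANSWER"],
    }
    for pmid, item in raw.items()
]

pd.DataFrame(rows).to_csv("data/processed/pubmedqa.csv", index=False)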

2. Supervised Fine-Tuning (SFT)

Fine-tune a pretrained LLM (e.g., GPT-4) on a processed dataset.

# Fine-tune on MIMIC-IV
python src/training/sft.py --dataset mimic \
    --dataset_path data/processed/mimic_iv_summaries.csv \
    --model_name gpt-4 \
    --output_dir models/policy_model/mimic_sft

# Fine-tune on PubMedQA
python src/training/sft.py --dataset pubmedqa \
    --dataset_path data/processed/pubmedqa.csv \
    --model_name gpt-4 \
    --output_dir models/policy_model/pubmedqa_sft
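
src/training/sft.py is the authoritative implementation. As a rough sketch of what the supervised fine-tuning step does, the snippet below runs a causal-LM fine-tune with the Hugging Face Trainer; the gpt2 checkpoint is a stand-in for the model selected by --model_name, and the prompt template reuses the column names assumed in the preprocessing sketch above:

from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

# A small open checkpoint stands in for the project's policy model.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

dataset = load_dataset("csv", data_files="data/processed/pubmedqa.csv")["train"]

# Build one training string per example (the template is an assumption).
dataset = dataset.map(
    lambda ex: {"text": f"Question: {ex['question']}\nAnswer: {ex['answer']}"}
)
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=dataset.column_names,
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="models/policy_model/pubmedqa_sft",
        per_device_train_batch_size=2,
        num_train_epochs=1,
    ),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()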

3. Train Reward Models

Train a reward model for RLHF or RLAIF workflows.

python src/training/reward_model.py --dataset pubmedqa \
    --dataset_path data/processed/pubmedqa.csv \
    --model_name gpt-4 \
    --output_dir models/reward_model/pubmedqa
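
A reward model for RLHF is usually trained on preference pairs with a Bradley-Terry objective: the preferred response should score higher than the rejected one. src/training/reward_model.py holds the project's actual implementation; the sketch below shows only the core loss on a single pair, with gpt2 standing in for the backbone and the example strings invented for illustration:

import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# A scalar-output classification head serves as the reward model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
reward_model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
reward_model.config.pad_token_id = tokenizer.pad_token_id

def pairwise_loss(prompt, chosen, rejected):
    # Score both candidate responses for the same prompt.
    batch = tokenizer([prompt + chosen, prompt + rejected],
                      return_tensors="pt", padding=True, truncation=True)
    scores = reward_model(**batch).logits.squeeze(-1)
    # Bradley-Terry objective: the preferred response should outscore the other.
    return -F.logsigmoid(scores[0] - scores[1])

loss = pairwise_loss("Q: Is aspirin an anticoagulant? A: ",
                     "Aspirin is primarily an antiplatelet agent, not an anticoagulant.",
                     "Yes, and it is safe at any dose.")
loss.backward()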

4. Reinforcement Learning with PPO

Perform policy optimization using a reward model.

python src/training/ppo.py --model_name gpt-4 \
    --reward_model_path models/reward_model/pubmedqa \
    --output_dir models/policy_model/pubmedqa_ppo
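
src/training/ppo.py drives this stage. The sketch below shows a single PPO update using trl's classic interface (PPOConfig, PPOTrainer.step), which was current up to roughly trl 0.11; later releases changed this API. The gpt2 checkpoint and the constant reward are placeholders, since in the real pipeline the trained reward model scores each response:

import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

# PPO needs a policy with a value head; gpt2 is a stand-in checkpoint.
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(PPOConfig(batch_size=1, mini_batch_size=1),
                         model, ref_model=None, tokenizer=tokenizer)

query = tokenizer.encode("Summarize: The patient presented with chest pain.",
                         return_tensors="pt")[0]
response = ppo_trainer.generate(query, return_prompt=False, max_new_tokens=32,
                                pad_token_id=tokenizer.eos_token_id)[0]

# A constant stands in for the reward model's score of this response.
reward = torch.tensor(1.0)
stats = ppo_trainer.step([query], [response], [reward])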

5. Evaluation

Evaluate Model Performance

python src/evaluation/evaluate.py --model_path models/policy_model/pubmedqa_sft \
    --dataset_path data/processed/pubmedqa.csv
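
The project's metrics are implemented in src/evaluation/evaluate.py. As a generic illustration of the scoring step, the sketch below generates an answer for each question and compares it to the reference with a SQuAD-style token-overlap F1; the column names follow the preprocessing sketch above and are assumptions:

from collections import Counter

import pandas as pd
from transformers import pipeline

generator = pipeline("text-generation", model="models/policy_model/pubmedqa_sft")
data = pd.read_csv("data/processed/pubmedqa.csv")

def token_f1(prediction, reference):
    # Token-overlap F1 between the generated and reference answers.
    pred, ref = prediction.lower().split(), reference.lower().split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

scores = []
for _, row in data.iterrows():
    prompt = f"Question: {row['question']}\nAnswer:"
    output = generator(prompt, max_new_tokens=64, return_full_text=False)[0]
    scores.append(token_f1(output["generated_text"], row["answer"]))

print(f"Mean token F1: {sum(scores) / len(scores):.3f}")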

Adversarial Testing (e.g., jailbreaks)

python src/evaluation/adversarial.py --model_path models/policy_model/pubmedqa_sft \
    --adversarial_prompts_path data/raw/adversarial_prompts.csv
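
src/evaluation/adversarial.py defines the repository's actual tests. One simple, commonly reported measure is the refusal rate on jailbreak prompts; the sketch below flags responses containing any of a small set of refusal phrases, where the prompt column name and the marker list are assumptions:

import pandas as pd
from transformers import pipeline

generator = pipeline("text-generation", model="models/policy_model/pubmedqa_sft")
prompts = pd.read_csv("data/raw/adversarial_prompts.csv")["prompt"]

# Crude refusal detection: any of these phrases counts as a refusal.
REFUSAL_MARKERS = ("i can't", "i cannot", "i'm sorry", "i am not able")

refusals = 0
for prompt in prompts:
    text = generator(prompt, max_new_tokens=64,
                     return_full_text=False)[0]["generated_text"]
    if any(marker in text.lower() for marker in REFUSAL_MARKERS):
        refusals += 1

print(f"Refusal rate: {refusals / len(prompts):.1%}")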

Dependencies

Install the following Python packages:

  • transformers
  • trl
  • torch
  • datasets
  • pandas
  • numpy
  • scikit-learn
  • matplotlib

All dependencies are listed in requirements.txt.

License

This project is licensed under the MIT License. See the LICENSE file for details.

Work in Progress 🚧

This repository is under active development, with regular updates expected. Stay tuned for new features and improvements!
