HEAL is a project-in-progress designed to explore and implement reinforcement learning techniques, namely Reinforcement Learning from Human Feedback (RLHF) and Reinforcement Learning from AI Feedback (RLAIF), for aligning Large Language Models (LLMs) in healthcare applications.
This repository includes tools for:
- Data preprocessing for healthcare datasets like MIMIC-IV and PubMedQA.
- Supervised fine-tuning (SFT) of policy models.
- Reward model training for RLHF and RLAIF.
- Reinforcement Learning using PPO.
- Evaluation of model robustness and adversarial testing (e.g., jailbreak prompts).
HEAL/
├── data/                    # Data directory
│   ├── processed/           # Preprocessed datasets
│   └── raw/                 # Raw datasets
├── experiments/             # Experiment configurations, logs, and results
│   ├── configs/             # Training configurations
│   ├── logs/                # Training logs
│   └── results/             # Evaluation results
├── models/                  # Trained models
│   ├── policy_model/        # Fine-tuned policy models
│   └── reward_model/        # Trained reward models
├── src/                     # Source code
│   ├── data_preprocessing/  # Data preprocessing scripts
│   ├── evaluation/          # Evaluation scripts
│   ├── training/            # Training scripts (SFT, PPO, reward model training)
│   └── utils/               # Helper functions
├── environment.yml          # Conda environment setup
├── requirements.txt         # List of dependencies
└── README.md                # Overview of the project
git clone https://github.com/seonokrkim/HEAL.git
cd HEAL
Using pip:
pip install -r requirements.txt
Using Conda (optional):
conda env create -f environment.yml
conda activate heal
Prepare datasets like MIMIC-IV (summarization) and PubMedQA (question-answering).
# Preprocess MIMIC-IV data
python src/data_preprocessing/preprocess_data.py --dataset mimic \
--input_path data/raw/mimic_iv_notes.csv \
--output_path data/processed/mimic_iv_summaries.csv
# Preprocess PubMedQA data
python src/data_preprocessing/preprocess_data.py --dataset pubmedqa \
--input_path data/raw/pubmedqa.json \
--output_path data/processed/pubmedqa.csv
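For reference, the preprocessing step boils down to loading the raw file, cleaning the text, and writing out aligned input/target pairs. Below is a minimal sketch for MIMIC-IV; the column names ("text", "summary") and the cleaning rules are illustrative assumptions, not necessarily what preprocess_data.py does.

```python
# Minimal sketch of MIMIC-IV preprocessing. Column names and cleaning
# rules are assumptions for illustration; the actual script may differ.
import pandas as pd

def preprocess_mimic(input_path: str, output_path: str) -> None:
    df = pd.read_csv(input_path)
    # Drop rows missing either the clinical note or its summary.
    df = df.dropna(subset=["text", "summary"])
    # Normalize whitespace in the notes.
    df["text"] = df["text"].str.replace(r"\s+", " ", regex=True).str.strip()
    df[["text", "summary"]].to_csv(output_path, index=False)

if __name__ == "__main__":
    preprocess_mimic(
        "data/raw/mimic_iv_notes.csv",
        "data/processed/mimic_iv_summaries.csv",
    )
```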
Fine-tune a pretrained LLM (e.g., GPT-4) on a processed dataset.
# Fine-tune on MIMIC-IV
python src/training/sft.py --dataset mimic \
--dataset_path data/processed/mimic_iv_summaries.csv \
--model_name gpt-4 \
--output_dir models/policy_model/mimic_sft
# Fine-tune on PubMedQA
python src/training/sft.py --dataset pubmedqa \
--dataset_path data/processed/pubmedqa.csv \
--model_name gpt-4 \
--output_dir models/policy_model/pubmedqa_sft
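Under the hood, SFT is standard causal-language-model fine-tuning on the processed pairs. A minimal sketch in the spirit of src/training/sft.py, using trl's SFTTrainer, follows; the base model ("gpt2" as a placeholder), the column names, and the prompt format are assumptions, and the trl API shifts somewhat between releases.

```python
# Minimal SFT sketch using trl. Base model, column names, and prompt
# format are illustrative assumptions.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Load processed summarization pairs (assumed columns: "text", "summary").
dataset = load_dataset("csv", data_files="data/processed/mimic_iv_summaries.csv")["train"]

# Fold each note/summary pair into a single training string in a "text"
# column, which SFTTrainer consumes by default.
def format_example(example):
    return {"text": f"Summarize: {example['text']}\nSummary: {example['summary']}"}

dataset = dataset.map(format_example)

trainer = SFTTrainer(
    model="gpt2",  # placeholder; substitute the policy model you actually use
    train_dataset=dataset,
    args=SFTConfig(output_dir="models/policy_model/mimic_sft"),
)
trainer.train()
```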
Train a reward model for RLHF or RLAIF workflows.
python src/training/reward_model.py --dataset pubmedqa \
--dataset_path data/processed/pubmedqa.csv \
--model_name gpt-4 \
--output_dir models/reward_model/pubmedqa
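The usual recipe here is pairwise preference learning: the reward model scores a preferred ("chosen") and a dispreferred ("rejected") response, and training minimizes the Bradley-Terry loss -log sigmoid(r_chosen - r_rejected). A hedged sketch of one such training step is below; the base model and example texts are placeholders, and reward_model.py may implement this differently.

```python
# Pairwise reward-model training sketch: maximize the score margin between
# chosen and rejected responses. Base model and inputs are placeholders.
import torch
from torch.nn.functional import logsigmoid
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "gpt2"  # placeholder base model
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
model.config.pad_token_id = tokenizer.pad_token_id
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

def reward_loss(chosen_texts, rejected_texts):
    chosen = tokenizer(chosen_texts, return_tensors="pt", padding=True, truncation=True)
    rejected = tokenizer(rejected_texts, return_tensors="pt", padding=True, truncation=True)
    r_chosen = model(**chosen).logits.squeeze(-1)
    r_rejected = model(**rejected).logits.squeeze(-1)
    # Bradley-Terry preference loss: prefer chosen over rejected.
    return -logsigmoid(r_chosen - r_rejected).mean()

optimizer.zero_grad()
loss = reward_loss(["a helpful, safe answer"], ["a harmful answer"])
loss.backward()
optimizer.step()
```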
Optimize the policy model with PPO, using the trained reward model to score generations.
python src/training/ppo.py --model_name gpt-4 \
--reward_model_path models/reward_model/pubmedqa \
--output_dir models/policy_model/pubmedqa_ppo
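Conceptually, PPO alternates between generating responses with the policy, scoring them with the reward model, and taking a clipped policy-gradient step regularized by a KL penalty against the reference (SFT) model. The sketch below follows the classic trl PPOTrainer interface; newer trl releases restructured this API, and the reward is stubbed with a constant, so treat it as a conceptual outline rather than the repo's actual ppo.py.

```python
# Conceptual PPO step following the classic trl interface. Model names are
# placeholders and the reward is stubbed; the real loop would score each
# response with the trained reward model.
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(PPOConfig(batch_size=1, mini_batch_size=1), model, ref_model, tokenizer)

query = tokenizer("Question: Is aspirin safe during pregnancy?", return_tensors="pt").input_ids[0]
response = ppo_trainer.generate(query, return_prompt=False, max_new_tokens=64)[0]
# Stub: a real run would compute this with the trained reward model.
reward = torch.tensor(1.0)
# One PPO optimization step on the (query, response, reward) triple.
stats = ppo_trainer.step([query], [response], [reward])
```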
Evaluate a fine-tuned model on a processed dataset.
python src/evaluation/evaluate.py --model_path models/policy_model/pubmedqa_sft \
    --dataset_path data/processed/pubmedqa.csv
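A minimal version of this loop generates an answer per question and compares it with the gold label. The sketch assumes the processed CSV has "question" and "final_decision" columns, mirroring PubMedQA's yes/no/maybe labels; evaluate.py may use different fields or metrics.

```python
# Minimal QA evaluation loop: generate an answer and compare against the
# gold label. Column names are assumptions about the processed CSV.
import pandas as pd
from transformers import pipeline

df = pd.read_csv("data/processed/pubmedqa.csv")
generator = pipeline("text-generation", model="models/policy_model/pubmedqa_sft")

correct = 0
for _, row in df.iterrows():
    output = generator(f"Question: {row['question']}\nAnswer:", max_new_tokens=8)
    prediction = output[0]["generated_text"].split("Answer:")[-1].strip().lower()
    correct += int(prediction.startswith(str(row["final_decision"]).lower()))

print(f"Accuracy: {correct / len(df):.3f}")
```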
Run adversarial testing with jailbreak-style prompts.
python src/evaluation/adversarial.py --model_path models/policy_model/pubmedqa_sft \
    --adversarial_prompts_path data/raw/adversarial_prompts.csv
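One simple robustness metric is the refusal rate: the fraction of adversarial prompts the model declines to answer. The sketch below assumes a "prompt" column in the CSV and a keyword-based refusal heuristic, both of which are illustrative rather than what adversarial.py necessarily implements.

```python
# Refusal-rate sketch for jailbreak testing. The "prompt" column and the
# keyword heuristic are assumptions for illustration.
import pandas as pd
from transformers import pipeline

REFUSAL_MARKERS = ("i cannot", "i can't", "i'm sorry", "i am not able")

prompts = pd.read_csv("data/raw/adversarial_prompts.csv")["prompt"]
generator = pipeline("text-generation", model="models/policy_model/pubmedqa_sft")

refusals = 0
for prompt in prompts:
    reply = generator(prompt, max_new_tokens=64)[0]["generated_text"].lower()
    refusals += any(marker in reply for marker in REFUSAL_MARKERS)

print(f"Refusal rate: {refusals / len(prompts):.2%}")
```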
Install the following Python packages:
transformers
trl
torch
datasets
pandas
numpy
scikit-learn
matplotlib
All dependencies are listed in requirements.txt.
This project is licensed under the MIT License. See the LICENSE file for details.
This repository is under active development, with regular updates expected. Stay tuned for new features and improvements!