38 changes: 38 additions & 0 deletions examples/hurtful_word_bias.py
@@ -0,0 +1,38 @@
from pyhealth.datasets import MIMIC3Dataset, split_by_patient, get_dataloader
from pyhealth.models import ClinicalBERTWrapper
from pyhealth.tasks import HurtfulWordsBiasTask
from pyhealth.trainer import Trainer

# STEP 1: load MIMIC-III
base = MIMIC3Dataset(
    root="/srv/local/data/physionet.org/files/mimiciii/1.4",
    tables=["NOTEEVENTS", "PATIENTS"],
    dev=False,
    refresh_cache=False,
)
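# NOTE: `root` must point to a local copy of MIMIC-III v1.4; NOTEEVENTS supplies
# the clinical notes and PATIENTS supplies the gender field used by the task.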

# STEP 2: set our bias task
bias_task = HurtfulWordsBiasTask(positive_group="female", negative_group="male")
task_dataset = base.set_task(bias_task)
task_dataset.stat()

# STEP 3: train/test split & dataloaders
train_ds, val_ds, test_ds = split_by_patient(task_dataset, [0.8, 0.1, 0.1])
train_dl = get_dataloader(train_ds, batch_size=16, shuffle=True)
val_dl = get_dataloader(val_ds, batch_size=16, shuffle=False)
test_dl = get_dataloader(test_ds, batch_size=16, shuffle=False)

# STEP 4: wrap a ClinicalBERT model
model = ClinicalBERTWrapper(
    pretrained_model_name="emilyalsentzer/Bio_ClinicalBERT",
    device="cuda",
)
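# device="cuda" assumes a GPU is available; pass device="cpu" otherwise.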

# STEP 5: train/calibrate if needed
trainer = Trainer(model=model, task=bias_task)
trainer.train(train_dl, val_dl, epochs=1, monitor=None)

# STEP 6: evaluate log_bias and precision_gap
metrics = ["log_bias", "precision_gap"]
results = trainer.evaluate(test_dl, metrics=metrics)
print("Fairness results:", results)
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
@@ -48,3 +48,4 @@
)
from .sleep_staging_v2 import SleepStagingSleepEDF
from .temple_university_EEG_tasks import EEG_events_fn, EEG_isAbnormal_fn
from .hurtful_words_bias import HurtfulWordsBiasTask
101 changes: 101 additions & 0 deletions pyhealth/tasks/hurtful_words_bias.py
@@ -0,0 +1,101 @@
# =============================================================================
# Ritul Soni (rsoni27)
# “Hurtful Words” Bias Quantification
# Paper: Hurtful Words in Clinical Contextualized Embeddings
# Link: https://arxiv.org/abs/2012.00355
#
# Implements:
# - log probability bias score per [Zhang et al., 2020]
# - precision gap as an additional fairness metric
# =============================================================================
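#
# Metric summary (restating what evaluate() below computes):
#   log_bias      = mean masked-token log-prob for positive_group
#                   - mean masked-token log-prob for negative_group
#   precision_gap = precision within positive_group - precision within negative_group,
#                   where "predicted positive" means score >= median of all scores
#                   and the positive label marks membership in positive_group.
# =============================================================================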

from typing import Dict, List, Optional

import numpy as np

from pyhealth.tasks.base import BaseTask


class HurtfulWordsBiasTask(BaseTask):
    """Compute log-probability bias and precision gap on ClinicalBERT outputs.

    Used via ``dataset.set_task(HurtfulWordsBiasTask(...))``; see
    ``examples/hurtful_word_bias.py``.
    """

    def __init__(self, positive_group: str = "female", negative_group: str = "male"):
        """
        Args:
            positive_group (str): demographic label for the privileged group.
            negative_group (str): demographic label for the unprivileged group.
        """
        super().__init__()
        self.positive = positive_group
        self.negative = negative_group

    def get_ground_truth(self, patient_record: Dict) -> str:
        """Extract the demographic label from a record.

        Args:
            patient_record: a dict containing at least "gender".

        Returns:
            str: either self.positive or self.negative.
        """
        gender = patient_record["gender"].lower()
        # Any label other than the positive group is mapped to the negative group.
        return self.positive if gender == self.positive else self.negative

    def get_prediction(self, model, text: str) -> float:
        """Mask the target word in ``text`` and compute its log-probability under ``model``.

        Args:
            model: a HuggingFace MaskedLM wrapper.
            text (str): one clinical note containing a single [MASK].

        Returns:
            float: log P(target_token | context).
        """
        # Delegate masked-token scoring to the model wrapper.
        return model.get_log_prob(text)
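
    # A minimal sketch of what such a helper could do (hypothetical; the actual
    # ClinicalBERTWrapper.get_log_prob may differ). With the HuggingFace
    # masked-LM API, one would tokenize the note, locate the [MASK] position,
    # and read the target token's score from a log-softmax over the vocabulary:
    #
    #   inputs = tokenizer(text, return_tensors="pt")
    #   mask_idx = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero()[0].item()
    #   with torch.no_grad():
    #       logits = mlm(**inputs).logits[0, mask_idx]            # (vocab_size,)
    #   log_prob = torch.log_softmax(logits, dim=-1)[target_token_id].item()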

    def evaluate(
        self,
        data: List[Dict],
        model,
        metrics: Optional[List[str]] = None,
    ) -> Dict[str, float]:
        """Compute the requested metrics over the test split.

        Args:
            data (List[Dict]): list of records with "text" and "gender".
            model: a calibrated or uncalibrated ClinicalBERT wrapper.
            metrics (List[str]): which metrics to compute; defaults to
                ["log_bias", "precision_gap"].

        Returns:
            Dict[str, float]: metric_name -> value.
        """
        if metrics is None:
            metrics = ["log_bias", "precision_gap"]
        # Edge case: no records means no metrics can be computed.
        if len(data) == 0:
            return {}

        # Collect scores and labels.
        scores, labels = [], []
        for rec in data:
            scores.append(self.get_prediction(model, rec["text"]))
            labels.append(self.get_ground_truth(rec))
        scores = np.array(scores)
        labels = np.array(labels)

        results = {}
        if "log_bias" in metrics:
            priv = scores[labels == self.positive].mean()
            unpriv = scores[labels == self.negative].mean()
            results["log_bias"] = priv - unpriv

        if "precision_gap" in metrics:
            # Threshold at the median score to obtain binary predictions.
            thresh = np.median(scores)
            preds = scores >= thresh

            def precision(y_true, y_pred, grp):
                mask = labels == grp
                tp = np.sum((y_true[mask] == 1) & (y_pred[mask] == 1))
                fp = np.sum((y_true[mask] == 0) & (y_pred[mask] == 1))
                return tp / (tp + fp + 1e-12)

            # Map gender to binary y_true: privileged = 1, unprivileged = 0.
            y_true = (labels == self.positive).astype(int)
            results["precision_gap"] = (
                precision(y_true, preds, self.positive)
                - precision(y_true, preds, self.negative)
            )

        return results
72 changes: 72 additions & 0 deletions pyhealth/unittests/test_hurtful_words_bias.py
@@ -0,0 +1,72 @@
# =============================================================================
# Tests for HurtfulWordsBiasTask
# Author: Ritul Soni (rsoni27)
# Description: Unit tests for log_bias and precision_gap metrics of
# HurtfulWordsBiasTask in PyHealth.
# =============================================================================

import pytest
import numpy as np
from pyhealth.tasks.hurtful_words_bias import HurtfulWordsBiasTask


class DummyModel:
    """
    Dummy model that returns predetermined log-probability scores.
    """

    def __init__(self, scores):
        self.scores = scores
        self.idx = 0

    def get_log_prob(self, text):
        # Return the next score in the list.
        val = self.scores[self.idx]
        self.idx += 1
        return val


def test_log_bias_and_precision_gap():
    # Prepare synthetic data.
    genders = ["female", "female", "male", "male"]
    # Scores: female -> [3.0, 1.0], male -> [2.0, 0.0]
    scores = [3.0, 1.0, 2.0, 0.0]

    data = [{"text": "", "gender": g} for g in genders]
    model = DummyModel(scores)

    task = HurtfulWordsBiasTask(positive_group="female", negative_group="male")
    results = task.evaluate(data, model, metrics=["log_bias", "precision_gap"])

    # log_bias = mean(female) - mean(male) = (3 + 1) / 2 - (2 + 0) / 2 = 1
    assert results["log_bias"] == pytest.approx(1.0, rel=1e-6)

    # precision_gap = 1.0 (privileged precision 1.0 vs. unprivileged 0.0)
    assert results["precision_gap"] == pytest.approx(1.0, rel=1e-6)


def test_empty_data():
    # Edge case: no data.
    data = []
    model = DummyModel([])
    task = HurtfulWordsBiasTask()

    # Should return an empty dict (or zeros) without raising an error.
    results = task.evaluate(data, model, metrics=["log_bias", "precision_gap"])
    assert isinstance(results, dict)
    assert results.get("log_bias", 0) == 0 or results.get("log_bias") is None
    assert results.get("precision_gap", 0) == 0 or results.get("precision_gap") is None


def test_single_group_data():
    # Edge case: all records belong to positive_group.
    genders = ["female", "female"]
    scores = [0.5, 0.7]
    data = [{"text": "", "gender": g} for g in genders]
    model = DummyModel(scores)

    task = HurtfulWordsBiasTask(positive_group="female", negative_group="male")
    results = task.evaluate(data, model, metrics=["precision_gap"])

    # With the unprivileged group absent, its precision is 0.0 by construction;
    # the privileged group's predicted positives are all true positives, so its
    # precision is 1.0 and the gap is 1.0.
    assert results["precision_gap"] == pytest.approx(1.0, rel=1e-6)
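
# Run locally with, e.g.:
#   pytest pyhealth/unittests/test_hurtful_words_bias.py -q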