Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions examples/README.markdown
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Chest X-Ray Analysis with PyHealth

## Introduction

This example demonstrates how to use PyHealth to perform chest X-ray classification, focusing on detecting abnormalities such as pneumonia and edema. It builds on the reproducibility efforts for the UniXGen model (Lee et al., 2023), a vision-language generative model for view-specific chest X-ray generation. The example leverages the CheXpert dataset and introduces two new PyHealth contributions:

- A task function (`chest_xray_classification_fn`) to label chest X-rays based on diagnoses.
- A metrics module (`radiographic_agreement`) to evaluate inter-rater agreement for radiographic findings.

## Setup

First, ensure PyHealth and its dependencies are installed. Then, import the required modules and set up logging.

```python
import pyhealth
from pyhealth.datasets import CheXpertDataset
from pyhealth.tasks import chest_xray_classification_fn # Import from pyhealth.tasks
from pyhealth.metrics import radiographic_agreement # Import from pyhealth.metrics
import matplotlib.pyplot as plt
import logging
import pandas as pd
import cv2

# Set up logging
logging.basicConfig(level=logging.INFO)

# Load CheXpert dataset
dataset = CheXpertDataset(
root="/path/to/chexpert", # Update with actual path
dev=False,
refresh_cache=True
)
```

### Notes
- Replace `/path/to/chexpert` with the actual path to your CheXpert dataset directory.
- Ensure the dataset is downloaded and formatted as expected by PyHealth (see [CheXpert documentation](https://stanfordmlgroup.github.io/competitions/chexpert)).

## Data Preprocessing

Use the `chest_xray_classification_fn` task to process the dataset and label X-ray images based on the presence of pneumonia or edema.

```python
samples = []
for patient in dataset.patients:
patient_samples = chest_xray_classification_fn(patient)
samples.extend(patient_samples)

# Convert to DataFrame for analysis
df = pd.DataFrame(samples)
print(df.head())
```

### Expected Output
The `df` DataFrame will contain columns like `patient_id`, `visit_id`, `xray_path`, `view_position`, and `label` (1 if pneumonia or edema is present, 0 otherwise).

## Visualization

Visualize a sample X-ray image along with its label and view position.

```python
sample = df.iloc[0]
img = cv2.imread(sample["xray_path"])
if img is not None:
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title(f"Label: {sample['label']}, View: {sample['view_position']}")
plt.show()
else:
print(f"Failed to load image at {sample['xray_path']}")
```

### Notes
- Ensure `cv2` (OpenCV) is installed to load and display images.
- If the image fails to load, verify the `xray_path` points to a valid file.

## Model Training (Simple Example)

Train a basic PyHealth model (e.g., `LogisticRegression`) to classify X-rays. Note that this is a placeholder; in practice, you’d likely use a deep learning model (e.g., a CNN) to extract features from X-ray images.

```python
from pyhealth.models import LogisticRegression

model = LogisticRegression(
feature_keys=["xray_path"],
label_key="label",
feature_dims=512 # Placeholder for image feature dimension
)
model.fit(dataset, batch_size=32, epochs=5)
```

### To-Do
- Replace `LogisticRegression` with a more suitable model (e.g., a CNN like ResNet) and preprocess X-ray images into feature vectors.
- Adjust `feature_dims` based on your feature extraction method.

## Evaluation

Evaluate the model’s predictions using the `radiographic_agreement` metric to measure inter-rater agreement between true and predicted labels.

```python
# Placeholder for predictions (replace with actual model predictions)
y_true = [sample["label"] for sample in samples]
y_pred = model.predict(dataset) # Adjust based on actual model

# Compute agreement
agreement_metrics = radiographic_agreement(y_true, y_pred)
print(f"Cohen's Kappa: {agreement_metrics['kappa']:.3f}")
print(f"Percent Agreement: {agreement_metrics['percent_agreement']:.2f}%")
```

### Expected Output
- `Cohen's Kappa`: A value between -1 and 1, where 1 indicates perfect agreement.
- `Percent Agreement`: Percentage of matching labels (0-100%).

## Conclusion

This example demonstrates how PyHealth can be used to classify chest X-ray abnormalities using the CheXpert dataset. The `chest_xray_classification_fn` task simplifies data preprocessing, while the `radiographic_agreement` metric provides a robust evaluation of model performance. Future work could integrate these outputs with UniXGen’s generated X-rays for enhanced analysis.

## References

- Lee, Hyungyung, et al. "Vision-Language Generative Model for View-Specific Chest X-Ray Generation." arXiv preprint arXiv:2302.12172, 2023.
- PyHealth Documentation: [https://pyhealth.readthedocs.io/](https://pyhealth.readthedocs.io/)
54 changes: 54 additions & 0 deletions pyhealth/metrics/radiographic_agreement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
from sklearn.metrics import cohen_kappa_score
from typing import List, Union

def radiographic_agreement(y_true: List[Union[int, List[int]]], y_pred: List[Union[int, List[int]]]) -> Dict[str, float]:
"""Calculates inter-rater agreement metrics for radiographic findings.

This function computes Cohen's Kappa and percentage agreement between true and predicted
radiographic labels (e.g., presence of pneumonia, edema). It supports both single-label
and multi-label cases.

Args:
y_true (List[Union[int, List[int]]]): Ground truth labels (0 or 1 for single-label,
list of 0/1 for multi-label per finding).
y_pred (List[Union[int, List[int]]]): Predicted labels matching y_true format.

Returns:
Dict[str, float]: Dictionary containing:
- "kappa": Cohen's Kappa score (range [-1, 1], 1 = perfect agreement).
- "percent_agreement": Percentage of matching labels (range [0, 100]).

Raises:
ValueError: If y_true and y_pred lengths or formats mismatch.
"""
if len(y_true) != len(y_pred):
raise ValueError("Length of y_true and y_pred must match.")

# Flatten multi-label lists if present
y_true_flat = [item if isinstance(item, int) else int(np.any(np.array(item))) for item in y_true]
y_pred_flat = [item if isinstance(item, int) else int(np.any(np.array(item))) for item in y_pred]

# Compute Cohen's Kappa
kappa = cohen_kappa_score(y_true_flat, y_pred_flat)

# Compute percentage agreement
agreement = sum(1 for t, p in zip(y_true_flat, y_pred_flat) if t == p) / len(y_true_flat) * 100

return {
"kappa": kappa,
"percent_agreement": agreement
}

if __name__ == "__main__":
# Test with single-label data
y_true_single = [1, 0, 1, 0]
y_pred_single = [1, 0, 0, 1]
result_single = radiographic_agreement(y_true_single, y_pred_single)
print("Single-label results:", result_single)

# Test with multi-label data (e.g., multiple findings)
y_true_multi = [[1, 0], [0, 1], [1, 1], [0, 0]] # [pneumonia, edema]
y_pred_multi = [[1, 0], [0, 0], [1, 0], [0, 1]]
result_multi = radiographic_agreement(y_true_multi, y_pred_multi)
print("Multi-label results:", result_multi)
88 changes: 88 additions & 0 deletions pyhealth/tasks/chest_xray_classification_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import logging
from typing import List, Dict, Any

def chest_xray_classification_fn(patient: Any) -> List[Dict[str, Any]]:
"""Task function for classifying abnormalities in chest X-rays.

This function processes patient visit data to identify chest X-ray images and labels them
based on the presence of specific diagnoses (e.g., pneumonia, edema). It follows the
template structure of readmission_prediction tasks in PyHealth.

Args:
patient: A PyHealth Patient object containing visit and event data.
Expected to have `visits` attribute with event data including
'xray_images', 'view_position', and 'diagnoses'.

Returns:
List[Dict[str, Any]]: A list of samples, each containing:
- patient_id (str): Unique patient identifier.
- visit_id (str): Unique visit identifier.
- xray_path (str): Path to the chest X-ray image.
- view_position (str): View position of the X-ray (e.g., PA, AP, LATERAL, or UNKNOWN).
- label (int): Binary label (1 if pneumonia or edema is present, 0 otherwise).

Raises:
KeyError: If required event data (e.g., 'xray_images') is missing.
ValueError: If the patient object structure is invalid.
"""
samples = []
try:
if not hasattr(patient, 'visits') or not patient.visits:
raise ValueError("Patient object has no visits or visits is empty.")

for visit in patient.visits:
if not hasattr(visit, 'events'):
logging.warning(f"Visit {visit.visit_id} has no events data, skipping.")
continue

if "xray_images" not in visit.events:
logging.warning(f"No xray_images in visit {visit.visit_id}, skipping.")
continue

diagnoses = visit.events.get("diagnoses", [])
label = 1 if any(d in diagnoses for d in ["pneumonia", "edema"]) else 0

sample = {
"patient_id": patient.patient_id,
"visit_id": visit.visit_id,
"xray_path": visit.events["xray_images"],
"view_position": visit.events.get("view_position", "UNKNOWN"),
"label": label
}
samples.append(sample)

if not samples:
logging.warning(f"No valid samples generated for patient {patient.patient_id}.")
return samples

except KeyError as e:
logging.error(f"Missing required key in patient data: {e}")
raise
except Exception as e:
logging.error(f"Unexpected error in chest_xray_classification_fn: {e}")
raise

# Example usage (for testing)
if __name__ == "__main__":
# Mock patient object for testing
class MockVisit:
def __init__(self, visit_id, events):
self.visit_id = visit_id
self.events = events

class MockPatient:
def __init__(self, patient_id, visits):
self.patient_id = patient_id
self.visits = visits

# Test data
test_visits = [
MockVisit("v1", {"xray_images": "/path/to/xray1.jpg", "view_position": "PA", "diagnoses": ["pneumonia"]}),
MockVisit("v2", {"xray_images": "/path/to/xray2.jpg", "view_position": "AP", "diagnoses": ["fever"]}),
MockVisit("v3", {"xray_images": "/path/to/xray3.jpg", "diagnoses": ["edema"]})
]
test_patient = MockPatient("p1", test_visits)

samples = chest_xray_classification_fn(test_patient)
for sample in samples:
print(sample)