Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pyhealth/datasets/configs/echo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Dataset configuration that echo.py passes to BaseDataset as config_path.
version: "1.0"

tables:
# Table name matches the tables=["echo_images"] argument used in echo.py.
echo_images:
# Summary CSV file name; presumably resolved relative to the dataset root — TODO confirm against BaseDataset.
file_path: "TMED2SummaryTable.csv"
# No patient-id or timestamp column is declared for this table.
patient_id: null
timestamp: null
# Columns of the summary table that EchoBagDataset reads.
attributes:
- "patient_study"
- "diagnosis_label"
109 changes: 109 additions & 0 deletions pyhealth/datasets/echo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import os
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
from pyhealth.datasets import BaseDataset
import os
# Collapses the five diagnosis strings into three integer classes:
# 0 = no AS, 1 = mild / mild-to-moderate AS, 2 = moderate / severe AS.
# Each inner tuple lists the strings that share the class given by its index.
DiagnosisStr_to_Int_Mapping = {
    diagnosis: severity
    for severity, diagnoses in enumerate(
        (
            ("no_AS",),
            ("mild_AS", "mildtomod_AS"),
            ("moderate_AS", "severe_AS"),
        )
    )
    for diagnosis in diagnoses
}

class EchoBagDataset(BaseDataset):
    """Echocardiogram aortic stenosis bag dataset for multiple instance learning.

    Echocardiogram images are grouped into one "bag" per patient study, and
    each bag carries an integer severity label obtained by passing the study's
    ``diagnosis_label`` through ``DiagnosisStr_to_Int_Mapping``. The bags can
    be used for multiple instance learning (MIL) tasks such as study-level
    classification.

    All bags are loaded into memory when the dataset is constructed.

    Dataset is available at:
        https://tmed.cs.tufts.edu/data_access.html

    Args:
        root_dir: Directory holding the ``.png`` image files. It is listed
            non-recursively, and the images of a study are the files named
            ``<patient_study>_<suffix>.png``.
        summary_table: DataFrame with a ``patient_study`` column and a
            ``diagnosis_label`` column. When a study appears on several rows,
            the label of its first row is used.
        transform_fn: Optional callable applied to every image of a bag (as a
            PIL image) in ``__getitem__``. Its outputs are combined with
            ``torch.stack``, so it has to return tensors of equal shape.
        sampling_strategy: Stored on the instance. It is not consulted when
            the bags are built; every matching image of a study is loaded.

    Attributes:
        root_dir: Directory containing the raw images.
        summary_table: DataFrame with study IDs and diagnosis labels.
        transform_fn: Image transformation function, or None.
        sampling_strategy: The value given at construction time.
        patient_studies: Unique study IDs in order of first appearance.
        data: List with one ``(num_images, 112, 112, 3)`` array per study
            that has at least one image.
        labels: Integer label for each entry of ``data``.

    Raises:
        ValueError: If an image does not decode to shape ``(112, 112, 3)``.
        KeyError: If a diagnosis label is missing from
            ``DiagnosisStr_to_Int_Mapping``.

    Examples:
        >>> import pandas as pd
        >>> from pyhealth.datasets import EchoBagDataset
        >>> summary = pd.read_csv("path/to/echobag/TMED2SummaryTable.csv")
        >>> dataset = EchoBagDataset(
        ...     root_dir="path/to/echobag",
        ...     summary_table=summary,
        ...     transform_fn=some_transform,
        ... )
        >>> bag, label = dataset[0]
        >>> print(bag.shape, label)
    """

    # Shape every image must have after conversion to RGB.
    _IMAGE_SHAPE = (112, 112, 3)

    def __init__(
        self,
        root_dir: str,
        summary_table: pd.DataFrame,
        transform_fn=None,
        sampling_strategy: str = "first_frame",
    ) -> None:
        # The table configuration ships next to this module.
        config_path = os.path.join(os.path.dirname(__file__), "configs", "echo.yaml")
        super().__init__(
            root=root_dir,
            tables=["echo_images"],
            dataset_name="EchoBagDataset",
            config_path=config_path,
        )
        self.root_dir = root_dir
        self.summary_table = summary_table
        self.transform_fn = transform_fn
        self.sampling_strategy = sampling_strategy
        self.patient_studies = self.summary_table['patient_study'].unique()
        self.data, self.labels = self._create_bags()

    @classmethod
    def _load_image(cls, img_path):
        """Read one image as an RGB uint8 array and validate its shape."""
        # The context manager closes the file handle as soon as it is decoded.
        with Image.open(img_path) as img:
            pixels = np.array(img.convert("RGB"))
        # Raise instead of assert so the check survives ``python -O``.
        if pixels.shape != cls._IMAGE_SHAPE:
            raise ValueError(f"Image shape error: {pixels.shape}")
        return pixels

    def _create_bags(self):
        """Build the per-study image bags and their integer labels.

        Returns:
            A ``(data, labels)`` pair of equally long lists. Studies without
            any image file are skipped.
        """
        # List the directory once rather than once per study.
        png_files = sorted(
            name for name in os.listdir(self.root_dir) if name.endswith(".png")
        )

        # First recorded label per study, computed once for all studies.
        first_rows = self.summary_table.drop_duplicates(subset="patient_study")
        label_by_study = dict(
            zip(first_rows["patient_study"], first_rows["diagnosis_label"])
        )

        data, labels = [], []
        for study in self.patient_studies:
            # Match on the "<study>_" prefix: a plain substring test would let
            # study "a_1" also collect the files of study "a_10".
            prefix = f"{study}_"
            images = [name for name in png_files if name.startswith(prefix)]
            if not images:
                continue

            bag_images = np.array(
                [
                    self._load_image(os.path.join(self.root_dir, name))
                    for name in images
                ]
            )
            # Resolve the label before appending so that an unknown label
            # cannot leave ``data`` and ``labels`` with different lengths.
            label = DiagnosisStr_to_Int_Mapping[label_by_study[study]]

            data.append(bag_images)
            labels.append(label)
        return data, labels

    def __len__(self):
        """Return the number of bags (studies with at least one image)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(bag, label)`` for the bag at position ``idx``.

        Without ``transform_fn`` the bag is the stored NumPy array; with it,
        the bag is the ``torch.stack`` of the transformed images.
        """
        bag = self.data[idx]
        if self.transform_fn:
            bag = torch.stack([self.transform_fn(Image.fromarray(img)) for img in bag])
        label = self.labels[idx]
        return bag, label
62 changes: 62 additions & 0 deletions pyhealth/unittests/test_datasets/test_echo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import shutil
import tempfile
import unittest
import numpy as np
import pandas as pd
from PIL import Image
from pyhealth.datasets import EchoBagDataset
class TestEchoDataset(unittest.TestCase):
    """Exercise EchoBagDataset against a temporary directory of synthetic images."""

    def setUp(self):
        """Create two studies of three 112x112 RGB images plus a summary table."""
        self.temp_dir = tempfile.mkdtemp()

        # A seeded generator keeps the fixture identical from run to run.
        rng = np.random.default_rng(0)

        self.study_ids = ["patient1_study1", "patient2_study1"]
        for study in self.study_ids:
            for i in range(3):
                pixels = rng.integers(0, 256, size=(112, 112, 3), dtype=np.uint8)
                img = Image.fromarray(pixels)
                img.save(os.path.join(self.temp_dir, f"{study}_{i}.png"))

        data = {
            "patient_study": self.study_ids,
            "diagnosis_label": ["no_AS", "severe_AS"],
        }
        self.summary_table = pd.DataFrame(data)

        self.summary_csv_path = os.path.join(self.temp_dir, "TMED2SummaryTable.csv")
        self.summary_table.to_csv(self.summary_csv_path, index=False)

    def tearDown(self):
        """Remove the temporary directory and everything written into it."""
        shutil.rmtree(self.temp_dir)

    def test_dataset_loading(self):
        """Both studies load as bags of three images with the expected labels."""
        dataset = EchoBagDataset(
            root_dir=self.temp_dir,
            summary_table=self.summary_table,
            transform_fn=None,
            sampling_strategy="first_frame",
        )

        self.assertEqual(len(dataset), 2)

        images, label = dataset[0]
        self.assertEqual(images.shape[0], 3)
        self.assertEqual(images.shape[1:], (112, 112, 3))
        # The first study is "no_AS", which maps to class 0.
        self.assertEqual(label, 0)

        # The second study is "severe_AS", which maps to class 2.
        _, second_label = dataset[1]
        self.assertEqual(second_label, 2)

    def test_get_item(self):
        """Without a transform, an item is a NumPy bag and an int label."""
        dataset = EchoBagDataset(
            root_dir=self.temp_dir,
            summary_table=self.summary_table,
        )

        images, label = dataset[0]

        self.assertIsInstance(images, np.ndarray)
        self.assertEqual(images.shape, (3, 112, 112, 3))
        self.assertIsInstance(label, int)
        self.assertIn(label, [0, 2])
        self.assertTrue(np.all(images >= 0) and np.all(images <= 255))


# Allow running this test module directly; verbosity=2 lists each test by name.
if __name__ == "__main__":
    unittest.main(verbosity=2)