Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Iterator, List, Optional
from typing import Any, Iterator, List, Optional
from urllib.parse import urlparse, urlunparse

import polars as pl
Expand Down Expand Up @@ -126,6 +126,8 @@ def __init__(
# Cached attributes
self._collected_global_event_df = None
self._unique_patient_ids = None
# Cache for sample datasets by task name
self._sample_dataset_cache = {}

@property
def collected_global_event_df(self) -> pl.DataFrame:
Expand Down Expand Up @@ -304,13 +306,13 @@ def get_patient(self, patient_id: str) -> Patient:
Raises:
AssertionError: If the patient ID is not found in the dataset.
"""
assert (
patient_id in self.unique_patient_ids
), f"Patient {patient_id} not found in dataset"
assert patient_id in self.unique_patient_ids, (
f"Patient {patient_id} not found in dataset"
)
df = self.collected_global_event_df.filter(pl.col("patient_id") == patient_id)
return Patient(patient_id=patient_id, data_source=df)

def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]:
def iter_patients(self, df: Optional[pl.DataFrame] = None) -> Iterator[Patient]:
"""Yields Patient objects for each unique patient in the dataset.

Yields:
Expand Down Expand Up @@ -342,7 +344,10 @@ def default_task(self) -> Optional[BaseTask]:
return None

def set_task(
self, task: Optional[BaseTask] = None, num_workers: int = 1
self,
task: Optional[BaseTask] = None,
num_workers: int = 1,
use_cache: bool = True,
) -> SampleDataset:
"""Processes the base dataset to generate the task-specific sample dataset.

Expand All @@ -351,6 +356,8 @@ def set_task(
num_workers (int): Number of workers for multi-threading. Default is 1.
This is because the task function is usually CPU-bound. And using
multi-threading may not speed up the task function.
use_cache (bool): Whether to cache the generated sample dataset and/or
use existing cached datasets for the task. Default is True.

Returns:
SampleDataset: The generated sample dataset.
Expand All @@ -362,6 +369,10 @@ def set_task(
assert self.default_task is not None, "No default tasks found"
task = self.default_task

if use_cache and task.task_name in self._sample_dataset_cache:
logger.info(f"Using cached sample dataset for task {task.task_name}")
return self._sample_dataset_cache[task.task_name]

logger.info(
f"Setting task {task.task_name} for {self.dataset_name} base dataset..."
)
Expand Down Expand Up @@ -395,8 +406,10 @@ def set_task(
input_schema=task.input_schema,
output_schema=task.output_schema,
dataset_name=self.dataset_name,
task_name=task,
task_name=task.task_name,
)

logger.info(f"Generated {len(samples)} samples for task {task.task_name}")
if use_cache:
self._sample_dataset_cache[task.task_name] = sample_dataset
return sample_dataset
19 changes: 9 additions & 10 deletions pyhealth/datasets/sample_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,22 @@ def __init__(
# Create patient_to_index and record_to_index mappings
self.patient_to_index = {}
self.record_to_index = {}

for i, sample in enumerate(samples):
# Create patient_to_index mapping
patient_id = sample.get('patient_id')
patient_id = sample.get("patient_id")
if patient_id is not None:
if patient_id not in self.patient_to_index:
self.patient_to_index[patient_id] = []
self.patient_to_index[patient_id].append(i)

# Create record_to_index mapping (optional)
record_id = sample.get('record_id', sample.get('visit_id'))
record_id = sample.get("record_id", sample.get("visit_id"))
if record_id is not None:
if record_id not in self.record_to_index:
self.record_to_index[record_id] = []
self.record_to_index[record_id].append(i)

self.validate()
self.build()

Expand All @@ -72,11 +72,10 @@ def validate(self) -> None:
input_keys = set(self.input_schema.keys())
output_keys = set(self.output_schema.keys())
for s in self.samples:
assert input_keys.issubset(s.keys()), \
"Input schema does not match samples."
assert output_keys.issubset(s.keys()), \
"Output schema does not match samples."
return
assert input_keys.issubset(s.keys()), "Input schema does not match samples."
assert output_keys.issubset(
s.keys()
), "Output schema does not match samples."

def build(self) -> None:
"""Builds the processors for input and output data based on schemas."""
Expand Down
4 changes: 2 additions & 2 deletions pyhealth/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class BaseTask(ABC):
input_schema: Dict[str, str]
output_schema: Dict[str, str]

def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
def pre_filter(self, df: pl.DataFrame) -> pl.DataFrame:
return df

@abstractmethod
def __call__(self, patient) -> List[Dict]:
raise NotImplementedError
raise NotImplementedError
28 changes: 9 additions & 19 deletions pyhealth/tasks/benchmark_ehrshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,23 @@ class BenchmarkEHRShot(BaseTask):
"""Benchmark predictive tasks using EHRShot."""

tasks = {
"operational_outcomes": [
"guo_los",
"guo_readmission",
"guo_icu"
],
"operational_outcomes": ["guo_los", "guo_readmission", "guo_icu"],
"lab_values": [
"lab_thrombocytopenia",
"lab_hyperkalemia",
"lab_hypoglycemia",
"lab_hyponatremia",
"lab_anemia"
"lab_anemia",
],
"new_diagnoses": [
"new_hypertension",
"new_hyperlipidemia",
"new_pancan",
"new_celiac",
"new_lupus",
"new_acutemi"
"new_acutemi",
],
"chexpert": [
"chexpert"
]
"chexpert": ["chexpert"],
}

def __init__(self, task: str, omop_tables: Optional[List[str]] = None) -> None:
Expand All @@ -53,13 +47,13 @@ def __init__(self, task: str, omop_tables: Optional[List[str]] = None) -> None:
self.output_schema = {"label": "binary"}
elif task in self.tasks["chexpert"]:
self.output_schema = {"label": "multilabel"}
def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:

def pre_filter(self, df: pl.DataFrame) -> pl.DataFrame:
if self.omop_tables is None:
return df
filtered_df = df.filter(
(pl.col("event_type") != "ehrshot") |
(pl.col("ehrshot/omop_table").is_in(self.omop_tables))
(pl.col("event_type") != "ehrshot")
| (pl.col("ehrshot/omop_table").is_in(self.omop_tables))
)
return filtered_df

Expand All @@ -81,9 +75,5 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
label_value = int(label_value)
label_value = [i for i in range(14) if (label_value >> i) & 1]
label_value = [13 - i for i in label_value[::-1]]
samples.append({
"feature": codes,
"label": label_value,
"split": split
})
samples.append({"feature": codes, "label": label_value, "split": split})
return samples
Loading