Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions pyhealth/datasets/eeg_seizure_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# ------------------------------------------------------------------------------
# Author: Subin Pradeep & Utkarsh Prasad
# NetID: subinpp2 & uprasad3
# Description: BaseDataset wrapper for synthetic EEG seizure windows, with
# labels (1=seizure, 0=background) and subject IDs.
# ------------------------------------------------------------------------------

import os
import pickle
from typing import Tuple

import numpy as np

from pyhealth.datasets.base_dataset_v2 import BaseDatasetV2


class EEGSeizureDataset(BaseDatasetV2):
    """
    EEG Seizure Detection Dataset.

    Loads preprocessed EEG windows from a pickle file generated by
    preprocess.py in Real-Time-EEG-Seizure-Detection.

    Each sample is expected to be a 19×T float array, accompanied by a
    binary label and a subject ID.
    """

    # Keys the pickled dict must provide; checked right after loading so a
    # malformed file fails with a clear message instead of a bare KeyError.
    _REQUIRED_KEYS = ("eeg", "label", "tag", "subj")

    def __init__(self, preproc_file: str):
        """
        Args:
            preproc_file (str): Path to pickled dict with keys:
                - 'eeg'   : np.ndarray, shape (N, 19, T)
                - 'label' : np.ndarray, shape (N,), {0,1}
                - 'tag'   : np.ndarray, shape (N,), fine-grained onset/offset
                - 'subj'  : np.ndarray, shape (N,), subject IDs

        Raises:
            FileNotFoundError: If ``preproc_file`` does not exist.
            TypeError: If the pickle does not contain a dict.
            KeyError: If any of the required keys is missing.
            ValueError: If the four arrays do not share the same length N.
        """
        super().__init__()
        if not os.path.exists(preproc_file):
            raise FileNotFoundError(f"Cannot find {preproc_file}")

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point this at preprocessing output you produced or trust.
        with open(preproc_file, "rb") as f:
            data = pickle.load(f)

        if not isinstance(data, dict):
            raise TypeError(
                f"Expected a dict in {preproc_file}, got {type(data).__name__}"
            )

        missing = [key for key in self._REQUIRED_KEYS if key not in data]
        if missing:
            raise KeyError(f"{preproc_file} is missing required keys: {missing}")

        # All arrays are indexed by the same window index, so their leading
        # dimensions must agree; otherwise __len__ / __getitem__ would disagree.
        lengths = {key: len(data[key]) for key in self._REQUIRED_KEYS}
        if len(set(lengths.values())) > 1:
            raise ValueError(
                f"Inconsistent number of windows in {preproc_file}: {lengths}"
            )

        self.X: np.ndarray = data["eeg"]
        self.y: np.ndarray = data["label"]
        self.tags: np.ndarray = data["tag"]
        self.subj: np.ndarray = data["subj"]

    def __len__(self) -> int:
        """Total number of windows."""
        return len(self.y)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, int, int]:
        """
        Args:
            idx (int): Index of the sample.
        Returns:
            tuple:
                - X (np.ndarray): EEG window (19×T), in the dtype stored in
                  the pickle (no conversion is applied here).
                - y (int): Label, 1 for seizure, 0 for background.
                - subj (int): Subject ID of this window.
        """
        return self.X[idx], int(self.y[idx]), int(self.subj[idx])
71 changes: 71 additions & 0 deletions pyhealth/models/seizure_crnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# ------------------------------------------------------------------------------
# Author: Subin Pradeep & Utkarsh Prasad
# NetID: subinpp2 & uprasad3
# Description: CRNN for EEG Seizure Detection (Conv → GRU → FC)
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn


class SeizureCRNN(nn.Module):
    """
    A simple convolutional-recurrent network for binary seizure detection.

    Pipeline: two Conv1d/BatchNorm/ReLU/MaxPool stages (each pool halves the
    time axis), a bidirectional GRU over the pooled sequence, and a linear
    classifier on the concatenated final hidden states of both directions.
    """

    def __init__(
        self,
        in_channels: int = 19,
        num_classes: int = 2,
        hidden_size: int = 128,
        num_layers: int = 1,
    ):
        """
        Args:
            in_channels: Number of input (EEG) channels.
            num_classes: Number of output logits.
            hidden_size: GRU hidden size per direction.
            num_layers: Number of stacked GRU layers.
        """
        super().__init__()

        # Convolutional encoder: in_channels -> 32 -> 64 feature maps.
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )

        # Bidirectional GRU over the encoded sequence.
        self.gru = nn.GRU(
            input_size=64,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Classification head; input is forward + backward hidden state.
        self.classifier = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch, channels, time_steps).
        Returns:
            logits: Tensor of shape (batch, num_classes)
        """
        # Conv -> (batch, 64, T')
        feats = self.encoder(x)

        # Repack for RNN: (batch, T', features)
        feats = feats.permute(0, 2, 1)

        # h_n: (num_layers * 2, batch, hidden), ordered layer-major with the
        # forward direction before the backward one inside each layer.
        _, h_n = self.gru(feats)

        # Summarise the sequence with the final state of EACH direction of the
        # top layer. Slicing the output at the last time step would be wrong
        # for the backward direction: at t = T'-1 it has only seen one step,
        # its fully-informed state lives at t = 0.
        final = torch.cat((h_n[-2], h_n[-1]), dim=1)

        # Classify
        return self.classifier(final)
111 changes: 111 additions & 0 deletions pyhealth/tasks/seizure_detection_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# ------------------------------------------------------------------------------
# Author: Subin Pradeep & Utkarsh Prasad
# NetID: subinpp2 & uprasad3
# Description: Seizure detection classification task using EEGSeizureDataset.
# ------------------------------------------------------------------------------

from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

from pyhealth.tasks.task_template import TaskTemplate
from pyhealth.datasets.eeg_seizure_dataset import EEGSeizureDataset
from pyhealth.models.seizure_crnn import SeizureCRNN


class SeizureDetectionTask(TaskTemplate):
    """
    Task: Binary classification of seizure vs. background EEG windows.

    Splits by subject: the last K subject IDs (in the sorted order returned by
    ``np.unique``) are held out as the test set; all other windows are used
    for training.
    """

    def __init__(
        self,
        preproc_file: str,
        holdout_subjects: int = 4,
        batch_size: int = 64,
        lr: float = 1e-3,
        weight_decay: float = 5e-5,
        epochs: int = 30,
        device: Optional[str] = None,
    ):
        """
        Args:
            preproc_file: Path to the pickle consumed by ``EEGSeizureDataset``.
            holdout_subjects: Number of subjects (K) reserved for testing.
            batch_size: Batch size for both train and test loaders.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            epochs: Number of training epochs run by ``train()``.
            device: Torch device string; defaults to "cuda" when available,
                otherwise "cpu".

        Raises:
            ValueError: If ``holdout_subjects`` would leave the train or the
                test split without any subject.
        """
        super().__init__()

        # Load dataset
        self.dataset = EEGSeizureDataset(preproc_file)
        subj = np.asarray(self.dataset.subj)
        subs = np.unique(subj)

        # Guard the slice below: ``subs[-0:]`` would select EVERY subject, and
        # K >= len(subs) would leave nothing to train on.
        if not 1 <= holdout_subjects < len(subs):
            raise ValueError(
                f"holdout_subjects must be between 1 and {len(subs) - 1} "
                f"for {len(subs)} subjects, got {holdout_subjects}"
            )
        test_subs = subs[-holdout_subjects:]

        # Train/test split by subject (vectorised membership test).
        is_test = np.isin(subj, test_subs)
        train_idx = np.flatnonzero(~is_test).tolist()
        test_idx = np.flatnonzero(is_test).tolist()

        self.train_loader = DataLoader(
            Subset(self.dataset, train_idx), batch_size=batch_size, shuffle=True
        )
        self.test_loader = DataLoader(
            Subset(self.dataset, test_idx), batch_size=batch_size
        )

        # Device
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Model, loss, optimizer
        self.model = SeizureCRNN().to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=lr, weight_decay=weight_decay
        )

        self.epochs = epochs

    def _to_device(self, X: torch.Tensor, y: torch.Tensor):
        """Move a batch to the task device with the dtypes the model needs."""
        # The dataset yields numpy arrays in their stored dtype (float64 for
        # default numpy floats); the model parameters are float32, so cast
        # explicitly. CrossEntropyLoss expects integer class indices.
        X = X.to(self.device, dtype=torch.float32)
        y = y.to(self.device, dtype=torch.long)
        return X, y

    def train(self) -> Dict[str, Any]:
        """Train for the configured number of epochs; return training history.

        Returns:
            Dict with keys "loss" and "acc", each a list holding one
            sample-weighted average per epoch.
        """
        self.model.train()
        history = {"loss": [], "acc": []}
        n_train = len(self.train_loader.dataset)

        for epoch in range(1, self.epochs + 1):
            epoch_loss = 0.0
            correct = 0

            for X, y, _ in self.train_loader:
                X, y = self._to_device(X, y)
                self.optimizer.zero_grad()

                logits = self.model(X)
                loss = self.criterion(logits, y)
                loss.backward()
                self.optimizer.step()

                # Weight by batch size so a smaller last batch is not
                # over-represented in the epoch average.
                epoch_loss += loss.item() * X.size(0)
                preds = logits.argmax(dim=1)
                correct += (preds == y).sum().item()

            avg_loss = epoch_loss / n_train
            acc = correct / n_train

            history["loss"].append(avg_loss)
            history["acc"].append(acc)
            print(
                f"[Train] Epoch {epoch}/{self.epochs} — "
                f"loss={avg_loss:.4f} acc={acc:.4f}"
            )

        return history

    def evaluate(self) -> Dict[str, Any]:
        """Evaluate on held-out subjects; return test metrics.

        Returns:
            Dict with key "test_acc": fraction of correctly classified
            test windows.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for X, y, _ in self.test_loader:
                X, y = self._to_device(X, y)
                preds = self.model(X).argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

        return {"test_acc": correct / total}
32 changes: 32 additions & 0 deletions pyhealth/unittests/test_eeg_seizure_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pickle
import numpy as np
import pytest

from pyhealth.datasets.eeg_seizure_dataset import EEGSeizureDataset


def make_dummy(tmp_path):
    """Write a tiny all-zero EEG pickle into *tmp_path* and return its path.

    The fixture holds 5 windows of shape (19, 800) spread over subjects
    1, 1, 2, 2, 3 with alternating 0/1 labels.
    """
    n_windows = 5
    payload = {
        "eeg": np.zeros((n_windows, 19, 800), dtype=float),
        "label": np.array([0, 1, 0, 1, 0], dtype=int),
        "tag": np.zeros(n_windows),
        "subj": np.array([1, 1, 2, 2, 3], dtype=int),
    }
    target = tmp_path / "dummy.pkl"
    with open(target, "wb") as handle:
        handle.write(pickle.dumps(payload))
    return str(target)


def test_length_and_getitem(tmp_path):
    """The dataset reports 5 windows and returns (window, label, subject)."""
    dataset = EEGSeizureDataset(make_dummy(tmp_path))

    assert len(dataset) == 5

    # Index 1 is the second window: label 1, subject 1 in the dummy fixture.
    window, label, subject = dataset[1]
    assert isinstance(window, np.ndarray)
    assert window.shape == (19, 800)
    assert label == 1
    assert subject == 1
30 changes: 30 additions & 0 deletions pyhealth/unittests/test_seizure_detection_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pickle
import numpy as np
import pytest

from pyhealth.tasks.seizure_detection_task import SeizureDetectionTask


def make_dummy(tmp_path):
    """Write a random 20-window EEG pickle into *tmp_path* and return its path.

    Windows are (19, 800) Gaussian noise with alternating 0/1 labels,
    grouped as 4 subjects × 5 consecutive samples each.
    """
    n_subjects, per_subject = 4, 5
    n_windows = n_subjects * per_subject
    payload = {
        "eeg": np.random.randn(n_windows, 19, 800).astype(float),
        "label": np.array([0, 1] * 10, dtype=int),
        "tag": np.zeros(n_windows),
        "subj": np.repeat(np.arange(n_subjects), per_subject),
    }
    target = tmp_path / "dummy.pkl"
    with open(target, "wb") as handle:
        handle.write(pickle.dumps(payload))
    return str(target)


def test_task_runs(tmp_path):
    """Smoke test: one training epoch and one evaluation complete sanely."""
    # Hold out 1 subject for test, train for 1 epoch only.
    task = SeizureDetectionTask(
        make_dummy(tmp_path), holdout_subjects=1, epochs=1, batch_size=4
    )

    history = task.train()
    assert "loss" in history and "acc" in history

    metrics = task.evaluate()
    assert "test_acc" in metrics
    test_acc = metrics["test_acc"]
    assert 0.0 <= test_acc <= 1.0