154 changes: 147 additions & 7 deletions pyhealth/processors/timeseries_processor.py
@@ -1,13 +1,9 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Tuple

import numpy as np
import torch

from . import register_processor
from .base_processor import FeatureProcessor


@register_processor("timeseries")
class TimeseriesProcessor(FeatureProcessor):
"""
@@ -25,12 +21,20 @@ class TimeseriesProcessor(FeatureProcessor):
        - torch.Tensor of shape (S, F), where S is the number of sampled time steps.
    """

    def __init__(self, sampling_rate: timedelta = timedelta(hours=1), impute_strategy: str = "forward_fill"):
    def __init__(
        self,
        sampling_rate: timedelta = timedelta(hours=1),
        impute_strategy: str = "forward_fill",
        normalize: bool = False,
        norm_method: str = "z_score",
        norm_axis: str = "global",
    ):
        # Configurable sampling rate and imputation method
        self.sampling_rate = sampling_rate
        self.impute_strategy = impute_strategy
        self.size = None

        # Normalization configuration; statistics are computed later by fit()
        self.normalize_flag = normalize
        self.normalize_method = norm_method
        self.normalize_axis = norm_axis

    def process(self, value: Tuple[List[datetime], np.ndarray]) -> torch.Tensor:
        timestamps, values = value

@@ -70,9 +74,10 @@ def process(self, value: Tuple[List[datetime], np.ndarray]) -> torch.Tensor:

        if self.size is None:
            self.size = sampled_values.shape[1]
        # Apply normalization only after fit() has computed statistics; the
        # fitted attribute depends on the method ("mean", "min", or "median"),
        # so checking a single name such as "mean_" would never match.
        if self.normalize_flag and any(
            hasattr(self, attr) for attr in ("mean", "min", "median")
        ):
            sampled_values = self._apply_normalization(sampled_values)

        return torch.tensor(sampled_values, dtype=torch.float)

    def size(self):
        # Size equals number of features, unknown until first process
        return self.size
@@ -82,3 +87,138 @@ def __repr__(self):
            f"TimeseriesProcessor(sampling_rate={self.sampling_rate}, "
            f"impute_strategy='{self.impute_strategy}')"
        )

    def _compute_global_stats(self, data: np.ndarray) -> None:
        """
        Compute global statistics for normalization across the entire dataset.

        Depending on `self.normalize_method`, calculates:
        - "z_score": mean and standard deviation over all values.
        - "min_max": minimum and maximum over all values.
        - "robust": median and median absolute deviation (MAD) over all values.

        Parameters
        ----------
        data : np.ndarray
            The input array containing all values to compute statistics on.

        Raises
        ------
        ValueError
            If `self.normalize_method` is unsupported.
        """
        if self.normalize_method == "z_score":
            self.mean = np.mean(data)
            self.std = np.std(data)
        elif self.normalize_method == "min_max":
            self.min = np.min(data)
            self.max = np.max(data)
        elif self.normalize_method == "robust":
            self.median = np.median(data)
            self.mad = np.median(np.abs(data - self.median))
        else:
            raise ValueError(f"Unsupported normalization method: {self.normalize_method}")

    def _compute_per_feature_stats(self, data: np.ndarray) -> None:
        """
        Compute per-feature statistics for normalization.

        Calculates statistics independently for each column (feature) based on
        `self.normalize_method`:
        - "z_score": mean and standard deviation per feature.
        - "min_max": minimum and maximum per feature.
        - "robust": median and median absolute deviation (MAD) per feature.

        Parameters
        ----------
        data : np.ndarray
            The input 2D array where each column represents a feature.

        Raises
        ------
        ValueError
            If `self.normalize_method` is unsupported.
        """
        if self.normalize_method == "z_score":
            self.mean = np.mean(data, axis=0)
            self.std = np.std(data, axis=0)
        elif self.normalize_method == "min_max":
            self.min = np.min(data, axis=0)
            self.max = np.max(data, axis=0)
        elif self.normalize_method == "robust":
            self.median = np.median(data, axis=0)
            self.mad = np.median(np.abs(data - self.median), axis=0)
        else:
            raise ValueError(f"Unsupported normalization method: {self.normalize_method}")

    def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
        """
        Fit normalization statistics to a dataset.

        Extracts values from the given `samples`, processes them, and computes
        normalization statistics either globally or per feature, depending on
        `self.normalize_axis`.

        Parameters
        ----------
        samples : list of dict[str, Any]
            A list of sample dictionaries. Each dictionary must contain `field`
            mapping to a tuple of (timestamps, values).
        field : str
            The key in each sample dictionary from which to extract the data.

        Notes
        -----
        - Uses `self.process()` to preprocess each sample's values before
          computing statistics.
        - Does nothing if `self.normalize_flag` is False.

        Raises
        ------
        ValueError
            If `self.normalize_axis` is unsupported.
        """
        if not self.normalize_flag:
            return
        all_values = []
        for sample in samples:
            timestamps, values = sample[field]
            processed_values = self.process((timestamps, values))
            all_values.append(processed_values.numpy())
        combined_values = np.vstack(all_values)
        if self.normalize_axis == "global":
            self._compute_global_stats(combined_values)
        elif self.normalize_axis == "per_feature":
            self._compute_per_feature_stats(combined_values)
        else:
            raise ValueError(f"Unsupported normalization axis: {self.normalize_axis}")
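
    # Illustrative note: with the three-patient data in the test file below,
    # norm_axis="global" computes one scalar statistic over all readings,
    # while norm_axis="per_feature" computes one statistic per column, so
    # systolic and diastolic blood pressure are scaled independently.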

    def _apply_normalization(self, value: np.ndarray) -> np.ndarray:
        """
        Apply normalization to an array using precomputed statistics.

        Normalization method is determined by `self.normalize_method`:
        - "z_score": (value - mean) / std
        - "min_max": (value - min) / (max - min)
        - "robust": (value - median) / MAD

        Parameters
        ----------
        value : np.ndarray
            The array of values to normalize.

        Returns
        -------
        np.ndarray
            The normalized array.

        Raises
        ------
        ValueError
            If `self.normalize_method` is unsupported.
        """
        # A small epsilon guards against division by zero when a statistic
        # (std, range, or MAD) is exactly zero.
        if self.normalize_method == "z_score":
            return (value - self.mean) / (self.std + 1e-8)
        elif self.normalize_method == "min_max":
            return (value - self.min) / (self.max - self.min + 1e-8)
        elif self.normalize_method == "robust":
            return (value - self.median) / (self.mad + 1e-8)
        else:
            raise ValueError(f"Unsupported normalization method: {self.normalize_method}")
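
Taken together, the new options are meant to be used in three steps: construct the processor with normalize=True, call fit() on the training split, then call process(). A minimal usage sketch follows; the {"vitals": (timestamps, values)} sample layout mirrors the fit() docstring above and is illustrative, not a fixed PyHealth schema.

from datetime import datetime, timedelta

import numpy as np

from pyhealth.processors.timeseries_processor import TimeseriesProcessor

processor = TimeseriesProcessor(
    sampling_rate=timedelta(hours=1),
    impute_strategy="forward_fill",
    normalize=True,
    norm_method="z_score",
    norm_axis="per_feature",
)

# Hypothetical training split: one patient with two (systolic, diastolic) readings.
train = [{
    "vitals": (
        [datetime(2023, 1, 1, 8, 0), datetime(2023, 1, 1, 10, 0)],
        np.array([[120.0, 80.0], [130.0, 85.0]]),
    )
}]

processor.fit(train, "vitals")                   # computes per-feature mean/std
tensor = processor.process(train[0]["vitals"])   # resampled, imputed, normalized (S, F) tensor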


122 changes: 122 additions & 0 deletions pyhealth/unittests/test_advanced_normalization.py
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
"""
Advanced test of TimeseriesProcessor normalization methods
"""
import numpy as np
from datetime import datetime, timedelta
from pyhealth.processors.timeseries_processor import TimeseriesProcessor

def create_diverse_data():
    """Create more diverse test data"""
    # Sample 1: Patient with high blood pressure
    timestamps1 = [
        datetime(2023, 1, 1, 8, 0),
        datetime(2023, 1, 1, 12, 0),
        datetime(2023, 1, 1, 16, 0),
        datetime(2023, 1, 1, 20, 0),
    ]
    values1 = np.array([
        [160, 100],  # Very high BP
        [155, 98],
        [162, 102],
        [158, 99],
    ])

    # Sample 2: Patient with normal blood pressure
    timestamps2 = [
        datetime(2023, 1, 1, 9, 0),
        datetime(2023, 1, 1, 13, 0),
        datetime(2023, 1, 1, 17, 0),
    ]
    values2 = np.array([
        [120, 80],  # Normal BP
        [118, 78],
        [122, 82],
    ])

    # Sample 3: Patient with low blood pressure
    timestamps3 = [
        datetime(2023, 1, 1, 10, 0),
        datetime(2023, 1, 1, 14, 0),
    ]
    values3 = np.array([
        [90, 60],  # Low BP
        [95, 65],
    ])

    return [
        {"vitals": (timestamps1, values1)},
        {"vitals": (timestamps2, values2)},
        {"vitals": (timestamps3, values3)},
    ]

def test_normalization_methods():
    print("Testing different normalization methods...")

    samples = create_diverse_data()

    # Test Min-Max normalization
    print("\n1. Min-Max Normalization (Global):")
    processor_minmax = TimeseriesProcessor(
        sampling_rate=timedelta(hours=2),
        impute_strategy="forward_fill",
        normalize=True,
        norm_method="min_max",
        norm_axis="global",
    )

    processor_minmax.fit(samples, "vitals")
    print(f"   Global min: {processor_minmax.min:.2f}")
    print(f"   Global max: {processor_minmax.max:.2f}")

    result_minmax = processor_minmax.process(samples[0]["vitals"])
    print(f"   Min-max normalized (should be in [0,1]):\n{result_minmax}")
    print(f"   Range: [{result_minmax.min():.3f}, {result_minmax.max():.3f}]")

    # Test Robust normalization
    print("\n2. Robust Normalization (Per-feature):")
    processor_robust = TimeseriesProcessor(
        sampling_rate=timedelta(hours=2),
        impute_strategy="forward_fill",
        normalize=True,
        norm_method="robust",
        norm_axis="per_feature",
    )

    processor_robust.fit(samples, "vitals")
    print(f"   Per-feature medians: {processor_robust.median}")
    print(f"   Per-feature MADs: {processor_robust.mad}")

    result_robust = processor_robust.process(samples[0]["vitals"])
    print(f"   Robust normalized:\n{result_robust}")

    # Test that statistics come from the training set only
    print("\n3. Training vs Test Set Statistics:")
    train_samples = samples[:2]  # First 2 samples as "training"
    test_sample = samples[2]     # Last sample as "test"

    processor_train = TimeseriesProcessor(
        normalize=True,
        norm_method="z_score",
        norm_axis="per_feature",
    )

    # Fit only on training data
    processor_train.fit(train_samples, "vitals")
    train_mean = processor_train.mean.copy()
    train_std = processor_train.std.copy()

    print("   Training set statistics:")
    print(f"   Mean: {train_mean}")
    print(f"   Std: {train_std}")

    # Process test sample with training statistics
    test_normalized = processor_train.process(test_sample["vitals"])
    print("   Test sample (low BP) normalized with training stats:")
    print(f"   {test_normalized}")
    print("   -> Should show negative values (below training mean)")

    print("\n✓ All advanced normalization tests completed!")

if __name__ == "__main__":
    test_normalization_methods()
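
As a quick hand check of the global min-max path in test 1 (a sketch, assuming forward-fill resampling only repeats observed readings, so the fitted min/max equal the raw extremes):

import numpy as np

# All raw readings from the three patients above, flattened.
raw = np.array([160, 100, 155, 98, 162, 102, 158, 99,
                120, 80, 118, 78, 122, 82, 90, 60, 95, 65], dtype=float)
lo, hi = raw.min(), raw.max()            # 60.0 and 162.0
normed = (raw - lo) / (hi - lo + 1e-8)   # same formula as _apply_normalization
assert 0.0 <= normed.min() and normed.max() <= 1.0
print(round(float(normed[0]), 3))        # ≈ 0.98 for the first systolic reading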