2 changes: 1 addition & 1 deletion docs/conf.py
@@ -12,7 +12,7 @@

version = importlib.metadata.version("pySEQTarget")
if not version:
version = "0.11.0"
version = "0.12.0"
sys.path.insert(0, os.path.abspath("../"))

project = "pySEQTarget"
182 changes: 89 additions & 93 deletions pySEQTarget/analysis/_risk_estimates.py
@@ -2,36 +2,105 @@
from scipy import stats


def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
"""
Compute Risk Difference and Risk Ratio from a comparison dataframe.
Consolidates the repeated calculation logic.
"""
if group_cols is None:
group_cols = []

if has_bootstrap:
rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
rd_comp = comp.with_columns(
[
(pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"),
(pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias("RD 95% LCI"),
(pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias("RD 95% UCI"),
]
)
rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
col_order = group_cols + [
"A_x",
"A_y",
"Risk Difference",
"RD 95% LCI",
"RD 95% UCI",
]
rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])

rr_log_se = (
(pl.col("se_x") / pl.col("risk_x")).pow(2)
+ (pl.col("se_y") / pl.col("risk_y")).pow(2)
).sqrt()
rr_comp = comp.with_columns(
[
(pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"),
(
(pl.col("risk_x") / pl.col("risk_y")) * (-z * rr_log_se).exp()
).alias("RR 95% LCI"),
(
(pl.col("risk_x") / pl.col("risk_y")) * (z * rr_log_se).exp()
).alias("RR 95% UCI"),
]
)
rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
col_order = group_cols + ["A_x", "A_y", "Risk Ratio", "RR 95% LCI", "RR 95% UCI"]
rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
else:
rd_comp = comp.with_columns(
(pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
)
rd_comp = rd_comp.drop(["risk_x", "risk_y"])
col_order = group_cols + ["A_x", "A_y", "Risk Difference"]
rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])

rr_comp = comp.with_columns(
(pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
)
rr_comp = rr_comp.drop(["risk_x", "risk_y"])
col_order = group_cols + ["A_x", "A_y", "Risk Ratio"]
rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])

return rd_comp, rr_comp


def _risk_estimates(self):
    last_followup = self.km_data["followup"].max()
    risk = self.km_data.filter(
        (pl.col("followup") == last_followup) & (pl.col("estimate") == "risk")
    )

    group_cols = [self.subgroup_colname] if self.subgroup_colname else []
    rd_comparisons = []
    rr_comparisons = []
    has_bootstrap = self.bootstrap_nboot > 0

    if self.bootstrap_nboot > 0:
    if has_bootstrap:
        alpha = 1 - self.bootstrap_CI
        z = stats.norm.ppf(1 - alpha / 2)
    else:
        z = None

    # Pre-extract data for each treatment level once (avoid repeated filtering)
    risk_by_level = {}
    for tx in self.treatment_level:
        level_data = risk.filter(pl.col(self.treatment_col) == tx)
        risk_by_level[tx] = {
            "pred": level_data.select(group_cols + ["pred"]),
        }
        if has_bootstrap:
            risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"])

    rd_comparisons = []
    rr_comparisons = []

    for tx_x in self.treatment_level:
        for tx_y in self.treatment_level:
            if tx_x == tx_y:
                continue

            risk_x = (
                risk.filter(pl.col(self.treatment_col) == tx_x)
                .select(group_cols + ["pred"])
                .rename({"pred": "risk_x"})
            )

            risk_y = (
                risk.filter(pl.col(self.treatment_col) == tx_y)
                .select(group_cols + ["pred"])
                .rename({"pred": "risk_y"})
            )
            # Use pre-extracted data instead of filtering again
            risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
            risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})

            if group_cols:
                comp = risk_x.join(risk_y, on=group_cols, how="left")
@@ -42,18 +111,9 @@ def _risk_estimates(self):
[pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
)

if self.bootstrap_nboot > 0:
se_x = (
risk.filter(pl.col(self.treatment_col) == tx_x)
.select(group_cols + ["SE"])
.rename({"SE": "se_x"})
)

se_y = (
risk.filter(pl.col(self.treatment_col) == tx_y)
.select(group_cols + ["SE"])
.rename({"SE": "se_y"})
)
if has_bootstrap:
se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})

if group_cols:
comp = comp.join(se_x, on=group_cols, how="left")
@@ -62,73 +122,9 @@
comp = comp.join(se_x, how="cross")
comp = comp.join(se_y, how="cross")

rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
rd_comp = comp.with_columns(
[
(pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"),
(pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias(
"RD 95% LCI"
),
(pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias(
"RD 95% UCI"
),
]
)
rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
col_order = group_cols + [
"A_x",
"A_y",
"Risk Difference",
"RD 95% LCI",
"RD 95% UCI",
]
rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
rd_comparisons.append(rd_comp)

rr_log_se = (
(pl.col("se_x") / pl.col("risk_x")).pow(2)
+ (pl.col("se_y") / pl.col("risk_y")).pow(2)
).sqrt()
rr_comp = comp.with_columns(
[
(pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"),
(
(pl.col("risk_x") / pl.col("risk_y"))
* (-z * rr_log_se).exp()
).alias("RR 95% LCI"),
(
(pl.col("risk_x") / pl.col("risk_y"))
* (z * rr_log_se).exp()
).alias("RR 95% UCI"),
]
)
rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
col_order = group_cols + [
"A_x",
"A_y",
"Risk Ratio",
"RR 95% LCI",
"RR 95% UCI",
]
rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
rr_comparisons.append(rr_comp)

else:
rd_comp = comp.with_columns(
(pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
)
rd_comp = rd_comp.drop(["risk_x", "risk_y"])
col_order = group_cols + ["A_x", "A_y", "Risk Difference"]
rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
rd_comparisons.append(rd_comp)

rr_comp = comp.with_columns(
(pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
)
rr_comp = rr_comp.drop(["risk_x", "risk_y"])
col_order = group_cols + ["A_x", "A_y", "Risk Ratio"]
rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
rr_comparisons.append(rr_comp)
rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols)
rd_comparisons.append(rd_comp)
rr_comparisons.append(rr_comp)

risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame()
risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame()
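
The consolidated _compute_rd_rr helper encodes the standard large-sample interval arithmetic: the risk-difference SE adds the two arms' bootstrap SEs in quadrature, and the risk-ratio CI is built on the log scale (delta method) and then exponentiated. A minimal sketch of that math with made-up risks and SEs (illustrative numbers, not package output):

    import math

    from scipy import stats

    risk_x, se_x = 0.30, 0.02  # arm x: risk and bootstrap SE
    risk_y, se_y = 0.25, 0.03  # arm y: risk and bootstrap SE
    z = stats.norm.ppf(1 - 0.05 / 2)  # ~1.96 for a 95% CI

    # Risk difference: SEs of independent arms add in quadrature
    rd = risk_x - risk_y
    rd_se = math.sqrt(se_x**2 + se_y**2)
    print(f"RD {rd:.3f} ({rd - z * rd_se:.3f}, {rd + z * rd_se:.3f})")

    # Risk ratio: delta-method SE on the log scale, then exponentiate back
    rr = risk_x / risk_y
    rr_log_se = math.sqrt((se_x / risk_x) ** 2 + (se_y / risk_y) ** 2)
    lo, hi = rr * math.exp(-z * rr_log_se), rr * math.exp(z * rr_log_se)
    print(f"RR {rr:.3f} ({lo:.3f}, {hi:.3f})")
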
16 changes: 6 additions & 10 deletions pySEQTarget/analysis/_survival_pred.py
@@ -46,24 +46,20 @@ def _calculate_risk(self, data, idx=None, val=None):
    lci = a / 2
    uci = 1 - lci

    # Pre-compute the followup range once (starts at 1, not 0)
    followup_range = list(range(1, self.followup_max + 1))

    SDT = (
        data.with_columns(
            [
                (
                    pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
                ).alias("TID")
            ]
            [pl.concat_str([pl.col(self.id_col), pl.col("trial")]).alias("TID")]
        )
        .group_by("TID")
        .first()
        .drop(["followup", f"followup{self.indicator_squared}"])
        .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")])
        .with_columns([pl.lit(followup_range).alias("followup")])
        .explode("followup")
        .with_columns(
            [
                (pl.col("followup") + 1).alias("followup"),
                (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
            ]
            [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
        )
    ).sort([self.id_col, "trial", "followup"])

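
Two small changes here: pl.concat_str builds the TID key without the explicit Utf8 casts (it stringifies its inputs itself), and the follow-up grid is generated once as 1..followup_max, removing the per-row `+ 1` shift. A toy polars sketch of the new pattern (hypothetical column names; `followup_sq` stands in for the f-string-derived squared column):

    import polars as pl

    df = pl.DataFrame({"id": ["a", "b"], "trial": [0, 1]})
    followup_max = 3

    out = (
        df.with_columns(pl.concat_str([pl.col("id"), pl.col("trial")]).alias("TID"))
        .with_columns(pl.lit(list(range(1, followup_max + 1))).alias("followup"))
        .explode("followup")
        .with_columns((pl.col("followup") ** 2).alias("followup_sq"))
    )
    print(out)  # each TID repeated for followup 1..3, plus its square
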
15 changes: 4 additions & 11 deletions pySEQTarget/expansion/_mapper.py
@@ -13,17 +13,10 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.inf):
        .with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")])
        .with_columns(
            [
                pl.struct(
                    [
                        pl.col(time_col),
                        pl.col(time_col).max().over(id_col).alias("max_time"),
                    ]
                )
                .map_elements(
                    lambda x: list(range(x[time_col], x["max_time"] + 1)),
                    return_dtype=pl.List(pl.Int64),
                )
                .alias("period")
                pl.int_ranges(
                    pl.col(time_col),
                    pl.col(time_col).max().over(id_col) + 1,
                ).alias("period")
            ]
        )
        .explode("period")
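
pl.int_ranges builds the per-row integer range inside the native engine, replacing a pl.struct(...).map_elements(...) round-trip through a Python lambda; the end bound is exclusive, hence the `+ 1`. A toy sketch of the same expression (hypothetical `id`/`time` columns):

    import polars as pl

    df = pl.DataFrame({"id": [1, 1, 2], "time": [3, 4, 0]})

    # Per row: integers from time up to the id's max time (end-exclusive, so + 1)
    out = df.with_columns(
        pl.int_ranges(pl.col("time"), pl.col("time").max().over("id") + 1).alias("period")
    ).explode("period")
    print(out)  # id 1 expands to periods 3-4 and 4; id 2 to period 0
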
20 changes: 16 additions & 4 deletions pySEQTarget/helpers/_bootstrap.py
@@ -35,7 +35,13 @@ def _prepare_boot_data(self, data, boot_id):


def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs):
    obj = copy.deepcopy(obj)
    # Shallow copy the object and only copy the mutable state that changes per bootstrap
    obj = copy.copy(obj)
    # Rebind only the mutable attributes that get modified during fitting
    obj.outcome_model = []
    obj.numerator_model = copy.copy(obj.numerator_model) if hasattr(obj, "numerator_model") and obj.numerator_model else []
    obj.denominator_model = copy.copy(obj.denominator_model) if hasattr(obj, "denominator_model") and obj.denominator_model else []

    obj._rng = (
        np.random.RandomState(seed + i) if seed is not None else np.random.RandomState()
    )
@@ -104,13 +110,19 @@ def wrapper(self, *args, **kwargs):
            self._rng = original_rng
            self.DT = self._offloader.load_dataframe(original_DT_ref)
        else:
            original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
            del original_DT
            # Keep original data in memory if offloading is disabled to avoid unnecessary I/O
            if self._offloader.enabled:
                original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
                del original_DT
            else:
                original_DT_ref = original_DT

            for i in tqdm(range(nboot), desc="Bootstrapping..."):
                self._current_boot_idx = i + 1
                tmp = self._offloader.load_dataframe(original_DT_ref)
                self.DT = _prepare_boot_data(self, tmp, i)
                del tmp
                if self._offloader.enabled:
                    del tmp
                self.bootstrap_nboot = 0
                boot_fit = method(self, *args, **kwargs)
                results.append(boot_fit)
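
The worker now takes a shallow copy.copy of the estimator and rebinds just the attributes a bootstrap fit mutates, so large read-only members are shared across workers instead of duplicated by copy.deepcopy. A stripped-down illustration of why the rebinding matters (toy class, not the package's estimator):

    import copy

    class Estimator:  # hypothetical stand-in
        def __init__(self):
            self.big = list(range(1_000_000))  # large, read-only during a fit
            self.outcome_model = []            # appended to during a fit

    base = Estimator()
    worker = copy.copy(base)   # attribute references are shared, nothing duplicated
    worker.outcome_model = []  # rebind mutable state so each fit stays isolated

    worker.outcome_model.append("boot-0 fit")
    assert base.outcome_model == []  # parent untouched
    assert worker.big is base.big    # the memory saving versus deepcopy
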
31 changes: 30 additions & 1 deletion pySEQTarget/helpers/_offloader.py
@@ -1,3 +1,4 @@
from functools import lru_cache
from pathlib import Path
from typing import Any, Optional, Union

@@ -12,6 +13,25 @@ def __init__(self, enabled: bool, dir: str, compression: int = 3):
        self.enabled = enabled
        self.dir = Path(dir)
        self.compression = compression
        # Create a cached loader bound to this instance
        self._init_cache()

    def _init_cache(self):
        """Initialize the LRU cache for model loading."""
        self._cached_load = lru_cache(maxsize=32)(self._load_from_disk)

    def __getstate__(self):
        """Prepare state for pickling - exclude the unpicklable cache."""
        state = self.__dict__.copy()
        # Remove the cache wrapper which can't be pickled
        del state["_cached_load"]
        return state

    def __setstate__(self, state):
        """Restore state after unpickling - recreate the cache."""
        self.__dict__.update(state)
        # Recreate the cache after unpickling
        self._init_cache()

    def save_model(
        self, model: Any, name: str, boot_idx: Optional[int] = None
@@ -29,11 +49,20 @@ def save_model(

        return str(filepath)

    def _load_from_disk(self, filepath: str) -> Any:
        """Internal method to load a model from disk (cached)."""
        return joblib.load(filepath)

    def load_model(self, ref: Union[Any, str]) -> Any:
        """Load a model, using cache for repeated loads of the same file."""
        if not self.enabled or not isinstance(ref, str):
            return ref

        return joblib.load(ref)
        return self._cached_load(ref)

    def clear_cache(self) -> None:
        """Clear the model loading cache. Call between bootstrap iterations if needed."""
        self._cached_load.cache_clear()

    def save_dataframe(self, df: pl.DataFrame, name: str) -> Union[pl.DataFrame, str]:
        if not self.enabled:
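
The __getstate__/__setstate__ pair exists because a functools.lru_cache wrapper is a closure that pickle cannot serialize, which matters if the offloader rides along when the estimator is copied to worker processes. Dropping the wrapper before pickling and rebuilding it (empty) afterwards is the standard workaround; a self-contained toy of the round-trip (stand-in _load, not the real joblib path):

    import pickle
    from functools import lru_cache

    class Loader:  # toy mirror of the pattern, not the package's class
        def __init__(self):
            self._init_cache()

        def _init_cache(self):
            # The wrapper closes over self, so it is rebuilt rather than pickled
            self._cached_load = lru_cache(maxsize=32)(self._load)

        def _load(self, path):
            return f"loaded:{path}"  # stands in for joblib.load

        def __getstate__(self):
            state = self.__dict__.copy()
            del state["_cached_load"]  # remove the unpicklable closure
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self._init_cache()  # fresh, empty cache on the receiving side

    clone = pickle.loads(pickle.dumps(Loader()))
    assert clone._cached_load("model.pkl") == "loaded:model.pkl"
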