Merged
Changes from all commits
28 commits
2d2a01f
fix to compevent models
ryan-odea Dec 25, 2025
30c14db
prepare gitignore
ryan-odea Dec 26, 2025
2ca636a
add joblib
ryan-odea Dec 26, 2025
e5a8fe7
add offload (and docs for visit)
ryan-odea Dec 26, 2025
5c99c93
create offloader
ryan-odea Dec 26, 2025
ac2e88d
setup offloading in primary API
ryan-odea Dec 26, 2025
2239d41
offloader to init
ryan-odea Dec 26, 2025
2bf49ec
add boot idx
ryan-odea Dec 26, 2025
1f4e23c
add intake from unloaded models
ryan-odea Dec 26, 2025
c8af80a
setup weight offloading
ryan-odea Dec 26, 2025
46fcf7a
weight offload to init
ryan-odea Dec 26, 2025
c589967
add weight offloading to SEQ + weight inloading
ryan-odea Dec 26, 2025
9974721
test offload
ryan-odea Dec 26, 2025
5bb0d23
bump version
ryan-odea Dec 26, 2025
24581ad
Merge branch 'main' into compevent-fix
ryan-odea Dec 26, 2025
953ca85
setup offloading for dataframes
ryan-odea Dec 28, 2025
9f1827e
offload original DT while bootstrapping
ryan-odea Dec 28, 2025
f47d738
adjust test nboot
ryan-odea Dec 28, 2025
904b354
skipping compevent tests
ryan-odea Dec 28, 2025
9045ad7
formatted
ryan-odea Dec 30, 2025
a27a358
Use smf.logit() for binary treatment vars
remlapmot Jan 1, 2026
96fb1f5
Handle intercept-only formula when numerator is "1" or empty
remlapmot Jan 1, 2026
0315a60
Allow specifying fitting method
remlapmot Jan 1, 2026
90dca14
Obtain expected category levels from fitted model
remlapmot Jan 1, 2026
360dd6a
Improve handling of categories for predictions
remlapmot Jan 1, 2026
8d3aba0
Account for NaNs in predicted probs
remlapmot Jan 2, 2026
27ffd78
Make survival preds safe
remlapmot Jan 2, 2026
c211376
Move _safe_predict into a helper file
remlapmot Jan 2, 2026
3 changes: 3 additions & 0 deletions .gitignore
@@ -166,3 +166,6 @@ cython_debug/

# uv lock file
uv.lock

# offloaded data files (offload test)
_seq_models/
19 changes: 17 additions & 2 deletions pySEQTarget/SEQopts.py
@@ -1,4 +1,5 @@
import multiprocessing
import os
from dataclasses import dataclass, field
from typing import List, Literal, Optional

@@ -18,7 +19,7 @@ class SEQopts:
:type bootstrap_CI_method: str
:param cense_colname: Column name for censoring effect (LTFU, etc.)
:type cense_colname: str
:param cense_denominator: Override to specify denominator patsy formula for censoring models
:param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicates an intercept-only model
:type cense_denominator: Optional[str] or None
:param cense_numerator: Override to specify numerator patsy formula for censoring models
:type cense_numerator: Optional[str] or None
@@ -54,8 +55,12 @@ class SEQopts:
:type km_curves: bool
:param ncores: Number of cores to use if running in parallel
:type ncores: int
:param numerator: Override to specify the outcome patsy formula for numerator models
:param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicates an intercept-only model
:type numerator: str
:param offload: Boolean to offload intermediate model data to disk
:type offload: bool
:param offload_dir: Directory to offload intermediate model data
:type offload_dir: str
:param parallel: Boolean to run model fitting in parallel
:type parallel: bool
:param plot_colors: List of colors for KM plots, if applicable
@@ -80,8 +85,12 @@ class SEQopts:
:type treatment_level: List[int]
:param trial_include: Boolean to force trial values into model covariates
:type trial_include: bool
:param visit_colname: Column name specifying visit number
:type visit_colname: str
:param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting
:type weight_eligible_colnames: List[str]
:param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton"
:type weight_fit_method: str
:param weight_min: Minimum weight
:type weight_min: float
:param weight_max: Maximum weight
@@ -120,6 +129,8 @@ class SEQopts:
km_curves: bool = False
ncores: int = multiprocessing.cpu_count()
numerator: Optional[str] = None
offload: bool = False
offload_dir: str = "_seq_models"
parallel: bool = False
plot_colors: List[str] = field(
default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]
@@ -136,6 +147,7 @@ class SEQopts:
trial_include: bool = True
visit_colname: str = None
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
weight_fit_method: Literal["newton", "bfgs", "lbfgs", "nm"] = "newton"
weight_min: float = 0.0
weight_max: float = None
weight_lag_condition: bool = True
@@ -195,3 +207,6 @@ def __post_init__(self):
attr = getattr(self, i)
if attr is not None and not isinstance(attr, list):
setattr(self, i, "".join(attr.split()))

if self.offload:
os.makedirs(self.offload_dir, exist_ok=True)
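As a quick orientation for the new options, a minimal configuration sketch follows; the values are illustrative and the remaining SEQopts fields keep their defaults.

from pySEQTarget.SEQopts import SEQopts

# Illustrative values only; offload_dir is created by __post_init__ when offload=True.
opts = SEQopts(
    offload=True,               # write intermediate models/dataframes to disk
    offload_dir="_seq_models",  # default location, also ignored via .gitignore
    weight_fit_method="bfgs",   # one of "newton", "bfgs", "lbfgs", "nm"
    parallel=True,
    ncores=4,
)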
19 changes: 16 additions & 3 deletions pySEQTarget/SEQuential.py
@@ -12,15 +12,15 @@
_subgroup_fit)
from .error import _data_checker, _param_checker
from .expansion import _binder, _diagnostics, _dynamic, _random_selection
from .helpers import _col_string, _format_time, bootstrap_loop
from .helpers import Offloader, _col_string, _format_time, bootstrap_loop
from .initialization import (_cense_denominator, _cense_numerator,
_denominator, _numerator, _outcome)
from .plot import _survival_plot
from .SEQopts import SEQopts
from .SEQoutput import SEQoutput
from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
_fit_visit, _weight_bind, _weight_predict,
_weight_setup, _weight_stats)
_fit_visit, _offload_weights, _weight_bind,
_weight_predict, _weight_setup, _weight_stats)


class SEQuential:
@@ -84,6 +84,8 @@ def __init__(
np.random.RandomState(self.seed) if self.seed is not None else np.random
)

self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir)

if self.covariates is None:
self.covariates = _outcome(self)

@@ -201,6 +203,9 @@ def fit(self) -> None:
raise ValueError(
"Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
)
boot_idx = None
if hasattr(self, "_current_boot_idx"):
boot_idx = self._current_boot_idx

if self.weighted:
WDT = _weight_setup(self)
@@ -217,6 +222,9 @@ def fit(self) -> None:
_fit_numerator(self, WDT)
_fit_denominator(self, WDT)

if self.offload:
_offload_weights(self, boot_idx)

WDT = pl.from_pandas(WDT)
WDT = _weight_predict(self, WDT)
_weight_bind(self, WDT)
@@ -244,6 +252,11 @@ def fit(self) -> None:
self.weighted,
"weight",
)
if self.offload:
offloaded_models = {}
for key, model in models.items():
offloaded_models[key] = self._offloader.save_model(model, key, boot_idx)
return offloaded_models
return models

def survival(self, **kwargs) -> None:
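With offload enabled, the model dictionary assembled at the end of fit() carries joblib file paths instead of fitted model objects. A sketch of the two shapes downstream code has to handle; key names and the bootstrap suffix follow Offloader.save_model, and the values shown are illustrative.

# offload=False: values are fitted model objects
# offload=True:  values are paths like "<offload_dir>/<key>_boot<idx>.pkl"
offloaded_models = {
    "outcome": "_seq_models/outcome_boot1.pkl",
    "compevent": "_seq_models/compevent_boot1.pkl",
}
# Consumers resolve either shape through self._offloader.load_model(...),
# which deserializes string references and passes real objects straight through.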
13 changes: 9 additions & 4 deletions pySEQTarget/analysis/_hazard.py
@@ -4,6 +4,8 @@
import polars as pl
from lifelines import CoxPHFitter

from ..helpers._predict_model import _safe_predict


def _calculate_hazard(self):
if self.subgroup_colname is None:
@@ -93,8 +95,10 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
else:
model_dict = self.outcome_model[boot_idx]

outcome_model = model_dict["outcome"]
ce_model = model_dict.get("compevent", None) if self.compevent_colname else None
outcome_model = self._offloader.load_model(model_dict["outcome"])
ce_model = None
if self.compevent_colname and "compevent" in model_dict:
ce_model = self._offloader.load_model(model_dict["compevent"])

all_treatments = []
for val in self.treatment_level:
@@ -103,13 +107,14 @@
)

tmp_pd = tmp.to_pandas()
outcome_prob = outcome_model.predict(tmp_pd)
outcome_prob = _safe_predict(outcome_model, tmp_pd)
outcome_sim = rng.binomial(1, outcome_prob)

tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])

if ce_model is not None:
ce_prob = ce_model.predict(tmp_pd)
ce_tmp_pd = tmp.to_pandas()
ce_prob = _safe_predict(ce_model, ce_tmp_pd)
ce_sim = rng.binomial(1, ce_prob)
tmp = tmp.with_columns([pl.Series("ce", ce_sim)])

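The simulation step that consumes these predictions is a per-row Bernoulli draw; a small standalone sketch of that pattern, with made-up probabilities.

import numpy as np

rng = np.random.RandomState(42)
outcome_prob = np.array([0.05, 0.40, 0.95])   # what _safe_predict returns, clipped to [0, 1]
outcome_sim = rng.binomial(1, outcome_prob)   # one draw per row, as in _hazard_handler
print(outcome_sim)                            # e.g. array([0, 1, 1])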
9 changes: 7 additions & 2 deletions pySEQTarget/analysis/_survival_pred.py
@@ -1,5 +1,7 @@
import polars as pl

from ..helpers._predict_model import _safe_predict


def _get_outcome_predictions(self, TxDT, idx=None):
data = TxDT.to_pandas()
@@ -9,9 +11,12 @@ def _get_outcome_predictions(self, TxDT, idx=None):

for boot_model in self.outcome_model:
model_dict = boot_model[idx] if idx is not None else boot_model
predictions["outcome"].append(model_dict["outcome"].predict(data))
outcome_model = self._offloader.load_model(model_dict["outcome"])
predictions["outcome"].append(_safe_predict(outcome_model, data.copy()))

if self.compevent_colname is not None:
predictions["compevent"].append(model_dict["compevent"].predict(data))
compevent_model = self._offloader.load_model(model_dict["compevent"])
predictions["compevent"].append(_safe_predict(compevent_model, data.copy()))

return predictions

1 change: 1 addition & 0 deletions pySEQTarget/helpers/__init__.py
@@ -1,6 +1,7 @@
from ._bootstrap import bootstrap_loop as bootstrap_loop
from ._col_string import _col_string as _col_string
from ._format_time import _format_time as _format_time
from ._offloader import Offloader as Offloader
from ._output_files import _build_md as _build_md
from ._output_files import _build_pdf as _build_pdf
from ._pad import _pad as _pad
21 changes: 18 additions & 3 deletions pySEQTarget/helpers/_bootstrap.py
@@ -39,7 +39,10 @@ def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs):
obj._rng = (
np.random.RandomState(seed + i) if seed is not None else np.random.RandomState()
)
original_DT = obj._offloader.load_dataframe(original_DT)
obj.DT = _prepare_boot_data(obj, original_DT, i)
del original_DT
obj._current_boot_idx = i + 1

# Disable bootstrapping to prevent recursion
obj.bootstrap_nboot = 0
@@ -60,6 +63,7 @@ def wrapper(self, *args, **kwargs):
results = []
original_DT = self.DT

self._current_boot_idx = None
full = method(self, *args, **kwargs)
results.append(full)

@@ -71,17 +75,20 @@
seed = getattr(self, "seed", None)
method_name = method.__name__

original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")

if getattr(self, "parallel", False):
original_rng = getattr(self, "_rng", None)
self._rng = None
self.DT = None

with ProcessPoolExecutor(max_workers=ncores) as executor:
futures = [
executor.submit(
_bootstrap_worker,
self,
method_name,
original_DT,
original_DT_ref,
i,
seed,
args,
@@ -95,13 +102,21 @@
results.append(j.result())

self._rng = original_rng
self.DT = self._offloader.load_dataframe(original_DT_ref)
else:
original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
del original_DT
for i in tqdm(range(nboot), desc="Bootstrapping..."):
self.DT = _prepare_boot_data(self, original_DT, i)
self._current_boot_idx = i + 1
tmp = self._offloader.load_dataframe(original_DT_ref)
self.DT = _prepare_boot_data(self, tmp, i)
del tmp
self.bootstrap_nboot = 0
boot_fit = method(self, *args, **kwargs)
results.append(boot_fit)

self.DT = original_DT
self.bootstrap_nboot = nboot
self.DT = self._offloader.load_dataframe(original_DT_ref)

end = time.perf_counter()
self._model_time = _format_time(start, end)
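The dataframe round-trip used above (save once, reload per bootstrap replicate) can be exercised on its own; a minimal sketch, assuming the offload directory already exists as SEQopts.__post_init__ would ensure.

from pathlib import Path

import polars as pl

from pySEQTarget.helpers import Offloader

Path("_seq_models").mkdir(exist_ok=True)        # normally handled by SEQopts.__post_init__
off = Offloader(enabled=True, dir="_seq_models")

DT = pl.DataFrame({"id": [1, 2, 3], "treatment": [0, 1, 1]})
ref = off.save_dataframe(DT, "_DT")             # writes _seq_models/_DT.parquet, returns the path
DT_again = off.load_dataframe(ref)              # pl.read_parquet on the path
DT_same = off.load_dataframe(DT)                # non-string refs pass through untouched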
15 changes: 15 additions & 0 deletions pySEQTarget/helpers/_fix_categories.py
@@ -0,0 +1,15 @@
def _fix_categories_for_predict(model, newdata):
"""
Fix categorical column ordering in newdata to match what the model expects.
"""
if hasattr(model, 'model') and hasattr(model.model, 'data') and hasattr(model.model.data, 'design_info'):
design_info = model.model.data.design_info
for factor, factor_info in design_info.factor_infos.items():
if factor_info.type == 'categorical':
col_name = factor.name()
if col_name in newdata.columns:
expected_categories = list(factor_info.categories)
newdata[col_name] = newdata[col_name].astype(str)
newdata[col_name] = newdata[col_name].astype('category')
newdata[col_name] = newdata[col_name].cat.set_categories(expected_categories)
return newdata
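The pandas operation at the heart of this helper can be seen in isolation; a sketch with made-up levels, where the model was fit with a category level that the prediction data happens not to contain.

import pandas as pd

expected_categories = ["0", "1", "2"]            # levels seen at fit time (assumed)
newdata = pd.DataFrame({"dose": ["1", "2"]})     # prediction data is missing level "0"

newdata["dose"] = newdata["dose"].astype(str).astype("category")
newdata["dose"] = newdata["dose"].cat.set_categories(expected_categories)

print(list(newdata["dose"].cat.categories))      # ['0', '1', '2'], matching the design info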
53 changes: 53 additions & 0 deletions pySEQTarget/helpers/_offloader.py
@@ -0,0 +1,53 @@
from pathlib import Path
from typing import Any, Optional, Union

import joblib
import polars as pl


class Offloader:
"""Manages disk-based storage for models and intermediate data"""

def __init__(self, enabled: bool, dir: str, compression: int = 3):
self.enabled = enabled
self.dir = Path(dir)
self.compression = compression

def save_model(
self, model: Any, name: str, boot_idx: Optional[int] = None
) -> Union[Any, str]:
"""Save a fitted model to disk and return a reference"""
if not self.enabled:
return model

filename = (
f"{name}_boot{boot_idx}.pkl" if boot_idx is not None else f"{name}.pkl"
)
filepath = self.dir / filename

joblib.dump(model, filepath, compress=self.compression)

return str(filepath)

def load_model(self, ref: Union[Any, str]) -> Any:
if not self.enabled or not isinstance(ref, str):
return ref

return joblib.load(ref)

def save_dataframe(self, df: pl.DataFrame, name: str) -> Union[pl.DataFrame, str]:
if not self.enabled:
return df

filename = f"{name}.parquet"
filepath = self.dir / filename

df.write_parquet(filepath, compression="zstd")

return str(filepath)

def load_dataframe(self, ref: Union[pl.DataFrame, str]) -> pl.DataFrame:
if not self.enabled or not isinstance(ref, str):
return ref

return pl.read_parquet(ref)
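Round-tripping a fitted model through the offloader is a joblib dump/load keyed by name and bootstrap index; a minimal sketch using a stand-in object, since anything picklable works.

from pathlib import Path

from pySEQTarget.helpers import Offloader

Path("_seq_models").mkdir(exist_ok=True)        # the directory is created elsewhere (SEQopts)
off = Offloader(enabled=True, dir="_seq_models")

ref = off.save_model({"stub": "model"}, "outcome", boot_idx=1)
print(ref)                                      # e.g. _seq_models/outcome_boot1.pkl
restored = off.load_model(ref)                  # joblib.load on the stored path
passthrough = off.load_model(restored)          # non-string references are returned as-is

disabled = Offloader(enabled=False, dir="_seq_models")
same_obj = disabled.save_model({"stub": "model"}, "outcome")  # no file written, object returned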
50 changes: 49 additions & 1 deletion pySEQTarget/helpers/_predict_model.py
@@ -1,9 +1,57 @@
import warnings

import numpy as np

from ._fix_categories import _fix_categories_for_predict


def _safe_predict(model, data, clip_probs=True):
"""
Predict with category fix fallback if needed.

Parameters
----------
model : statsmodels model
Fitted model object
data : pandas DataFrame
Data to predict on
clip_probs : bool
If True, clip probabilities to [0, 1] and replace NaN with 0.5
"""
data = data.copy()

try:
probs = model.predict(data)
except Exception as e:
if "mismatching levels" in str(e):
data = _fix_categories_for_predict(model, data)
probs = model.predict(data)
else:
raise

if clip_probs:
probs = np.array(probs)
if np.any(np.isnan(probs)):
warnings.warn("NaN values in predicted probabilities, replacing with 0.5")
probs = np.where(np.isnan(probs), 0.5, probs)
probs = np.clip(probs, 0, 1)

return probs


def _predict_model(self, model, newdata):
newdata = newdata.to_pandas()

# Original behavior - convert fixed_cols to category
for col in self.fixed_cols:
if col in newdata.columns:
newdata[col] = newdata[col].astype("category")
return np.array(model.predict(newdata))

try:
return np.array(model.predict(newdata))
except Exception as e:
if "mismatching levels" in str(e):
newdata = _fix_categories_for_predict(model, newdata)
return np.array(model.predict(newdata))
else:
raise
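The clipping branch of _safe_predict can be demonstrated without a real statsmodels fit; a sketch with a stand-in model whose predictions deliberately contain a NaN and an out-of-range value.

import numpy as np
import pandas as pd

from pySEQTarget.helpers._predict_model import _safe_predict


class StubModel:
    """Stand-in with the only method _safe_predict needs here."""

    def predict(self, data):
        return pd.Series([0.2, np.nan, 1.4])


probs = _safe_predict(StubModel(), pd.DataFrame({"x": [1, 2, 3]}))
print(probs)   # [0.2 0.5 1. ]  (NaN replaced with 0.5, then clipped into [0, 1])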