Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pySEQ/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class SEQopts:
cense_denominator: Optional[str] = None
cense_numerator: Optional[str] = None
cense_eligible_colname: Optional[str] = None
compevent_colname: Optional[str] = None #TODO
compevent_colname: Optional[str] = None
covariates: Optional[str] = None
denominator: Optional[str] = None
excused: bool = False
Expand All @@ -22,7 +22,7 @@ class SEQopts:
followup_max: int = None
followup_min: int = 0
followup_spline: bool = False
hazard: bool = False # TODO
hazard_estimate: bool = False
indicator_baseline: str = "_bas"
indicator_squared: str = "_sq"
km_curves: bool = False
Expand All @@ -32,7 +32,7 @@ class SEQopts:
plot_colors: List[str] = field(default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"])
plot_labels: List[str] = field(default_factory=lambda: [])
plot_title: str = None
plot_type: Literal["risk", "survival", "inc"] = "risk" # add inc (compevent)
plot_type: Literal["risk", "survival", "incidence"] = "risk"
seed: Optional[int] = None
selection_first_trial: bool = False
selection_probability: float = 0.8
Expand All @@ -51,7 +51,7 @@ class SEQopts:
def __post_init__(self):
bools = [
"excused", "followup_class", "followup_include",
"followup_spline", "hazard", "km_curves",
"followup_spline", "hazard_estimate", "km_curves",
"parallel", "selection_first_trial", "selection_random",
"trial_include", "weight_lag_condition", "weight_p99",
"weight_preexpansion", "weighted"
Expand All @@ -73,8 +73,8 @@ def __post_init__(self):
if not (0.0 <= self.selection_probability <= 1.0):
raise ValueError("selection_probability must be between 0 and 1.")

if self.plot_type not in ["risk", "survival"]:
raise ValueError("plot_type must be either 'risk' or 'survival'.")
if self.plot_type not in ["risk", "survival", "incidence"]:
raise ValueError("plot_type must be either 'risk', 'survival', or 'incidence'.")

if self.bootstrap_CI_method not in ["se", "percentile"]:
raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'")
Expand Down
84 changes: 84 additions & 0 deletions pySEQ/SEQoutput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from dataclasses import dataclass
from typing import List, Optional, Literal
from .SEQopts import SEQopts
import statsmodels.formula.api as smf
import polars as pl
import matplotlib.figure

@dataclass
class SEQoutput:
options: SEQopts = None
method: str = None
numerator_models: List[smf.MNLogit] = None
denominator_models: List[smf.MNLogit] = None
outcome_models: List[List[smf.glm]] = None
compevent_models: List[List[smf.glm]] = None
weight_statistics: dict = None
hazard: pl.DataFrame = None
km_data: pl.DataFrame = None
km_graph: matplotlib.figure.Figure = None
risk_ratio: pl.DataFrame = None
risk_difference: pl.DataFrame = None
time: dict = None
diagnostic_tables: dict = None

def plot(self):
print(self.km_graph)

def summary(self,
type = Optional[Literal[
"numerator",
"denominator",
"outcome",
"compevent"]]):
match type:
case "numerator":
models = self.numerator_models
case "denominator":
models = self.denominator_models
case "compevent":
models = self.compevent_models
case _:
models = self.outcome_models

return [model.summary() for model in models]

def retrieve_data(self,
type = Optional[Literal[
"km_data",
"hazard",
"risk_ratio",
"risk_difference",
"unique_outcomes",
"nonunique_outcomes",
"unique_switches",
"nonunique_switches"
]]):
match type:
case "hazard":
data = self.hazard
case "risk_ratio":
data = self.risk_ratio
case "risk_difference":
data = self.risk_difference
case "unique_outcomes":
data = self.diagnostic_tables["unique_outcomes"]
case "nonunique_outcomes":
data = self.diagnostic_tables["nonunique_outcomes"]
case "unique_switches":
if self.diagnostic_tables.has_key("unique_switches"):
data = self.diagnostic_tables["unique_switches"]
else:
data = None
case "nonunique_switches":
if self.diagnostic_tables.has_key("nonunique_switches"):
data = self.diagnostic_tables["nonunique_switches"]
else:
data = None
case _:
data = self.km_data
if data is None:
ValueError("Data {type} was not created in the SEQuential process")
return data


112 changes: 95 additions & 17 deletions pySEQ/SEQuential.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from typing import Optional, List, Literal
import sys
import time
from dataclasses import asdict
from collections import Counter
import polars as pl
import numpy as np
import datetime

from .SEQopts import SEQopts
from .SEQoutput import SEQoutput
from .error import _param_checker
from .helpers import _col_string, bootstrap_loop, _format_time
from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator
from .expansion import _mapper, _binder, _dynamic, _random_selection
from .expansion import _binder, _dynamic, _random_selection, _diagnostics
from .weighting import _weight_setup, _fit_LTFU, _fit_numerator, _fit_denominator, _weight_bind, _weight_predict, _weight_stats
from .analysis import _outcome_fit, _pred_risk, _calculate_survival, _subgroup_fit
from .analysis import _outcome_fit, _pred_risk, _calculate_survival, _subgroup_fit, _calculate_hazard, _risk_estimates
from .plot import _survival_plot


Expand Down Expand Up @@ -40,11 +41,15 @@ def __init__(
self.fixed_cols = fixed_cols
self.method = method

self._time_initialized = datetime.datetime.now()

if parameters is None:
parameters = SEQopts()

for name, value in asdict(parameters).items():
setattr(self, name, value)

self._rng = np.random.RandomState(self.seed) if self.seed is not None else np.random

if self.covariates is None:
self.covariates = _outcome(self)
Expand Down Expand Up @@ -113,9 +118,10 @@ def expand(self):
_dynamic(self)
if self.selection_random:
_random_selection(self)
_diagnostics(self)

end = time.perf_counter()
self.expansion_time = _format_time(start, end)
self._expansion_time = _format_time(start, end)

def bootstrap(self, **kwargs):
allowed = {"bootstrap_nboot", "bootstrap_sample",
Expand All @@ -126,7 +132,6 @@ def bootstrap(self, **kwargs):
else:
raise ValueError(f"Unknown argument: {key}")

self._rng = np.random.RandomState(self.seed) if self.seed is not None else np.random
UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list()
NIDs = len(UIDs)

Expand All @@ -142,7 +147,6 @@ def fit(self):
if self.bootstrap_nboot > 0 and not hasattr(self, "_boot_samples"):
raise ValueError("Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping.")

start = time.perf_counter()
if self.weighted:
WDT = _weight_setup(self)
if not self.weight_preexpansion and not self.excused:
Expand All @@ -162,18 +166,21 @@ def fit(self):
_weight_bind(self, WDT)
self.weight_stats = _weight_stats(self)

end = time.perf_counter()
self.model_time = _format_time(start, end)

if self.subgroup_colname is not None:
return _subgroup_fit(self)

return _outcome_fit(self,
self.DT,
self.outcome_col,
self.covariates,
self.weighted,
"weight")
models = {'outcome': _outcome_fit(self, self.DT,
self.outcome_col,
self.covariates,
self.weighted,
"weight")}
if self.compevent_colname is not None:
models['compevent'] = _outcome_fit(self, self.DT,
self.compevent_colname,
self.covariates,
self.weighted,
"weight")
return models

def survival(self):
if not hasattr(self, "outcome_model") or not self.outcome_model:
Expand All @@ -184,14 +191,85 @@ def survival(self):
risk_data = _pred_risk(self)
surv_data = _calculate_survival(self, risk_data)
self.km_data = pl.concat([risk_data, surv_data])
self.risk_estimates = _risk_estimates(self)

end = time.perf_counter()
self.survival_time = _format_time(start, end)
self._survival_time = _format_time(start, end)

def hazard(self):
start = time.perf_counter()

if not hasattr(self, "outcome_model") or not self.outcome_model:
raise ValueError("Outcome model not found. Please run the 'fit' method before calculating hazard ratio.")
self.hazard_ratio = _calculate_hazard(self)

end = time.perf_counter()
self._hazard_time = _format_time(start, end)

def plot(self):
self.km_graph = _survival_plot(self)
print(self.km_graph)

def collect(self):
self._time_collected = datetime.datetime.now()

generated = [
"numerator_model", "denominator_model",
"outcome_model",
"hazard_ratio", "risk_estimates",
"km_data", "diagnostics",
"_survival_time", "_hazard_time",
"_model_time", "_expansion_time",
"weight_stats"
]
for attr in generated:
if not hasattr(self, attr):
setattr(self, attr, None)

# Options ==========================
base = SEQopts()

for name, value in vars(self).items():
if name in asdict(base).keys():
setattr(base, name, value)

# Timing =========================
time = {"start_time": self._time_initialized,
"expansion_time": self._expansion_time,
"model_time": self._model_time,
"survival_time": self._survival_time,
"hazard_time": self._hazard_time,
"collection_time": self._time_collected}

if self.compevent_colname is not None:
compevent_models = [model["compevent"] for model in self.outcome_models]
else:
compevent_models = None

if self.outcome_model is not None:
outcome_models = [model["outcome"] for model in self.outcome_model]

if self.risk_estimates is None:
risk_ratio = risk_difference = None
else:
risk_ratio = self.risk_estimates["risk_ratio"]
risk_difference = self.risk_estimates["risk_difference"]

output = SEQoutput(
options=base,
method=self.method,
numerator_models=self.numerator_model,
denominator_models=self.denominator_model,
outcome_models=outcome_models,
compevent_models=compevent_models,
weight_statistics=self.weight_stats,
hazard=self.hazard_ratio,
km_data=self.km_data,
km_graph=self.km_graph,
risk_ratio=risk_ratio,
risk_difference=risk_difference,
time=time,
diagnostic_tables=self.diagnostics
)

return output

1 change: 1 addition & 0 deletions pySEQ/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .SEQuential import SEQuential
from .SEQopts import SEQopts
from .SEQoutput import SEQoutput

__all__ = [
"SEQuential",
Expand Down
4 changes: 3 additions & 1 deletion pySEQ/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from ._outcome_fit import _outcome_fit
from ._survival_pred import _get_outcome_predictions, _pred_risk, _calculate_survival
from ._subgroup_fit import _subgroup_fit
from ._subgroup_fit import _subgroup_fit
from ._hazard import _calculate_hazard
from ._risk_estimates import _risk_estimates
Loading