26 changes: 11 additions & 15 deletions pySEQ/SEQopts.py
@@ -21,7 +21,7 @@ class SEQopts:
followup_include: bool = True
followup_max: int = None
followup_min: int = 0
followup_spline: bool = False #TODO
followup_spline: bool = False
hazard: bool = False # TODO
indicator_baseline: str = "_bas"
indicator_squared: str = "_sq"
@@ -37,7 +37,7 @@ class SEQopts:
selection_first_trial: bool = False
selection_probability: float = 0.8
selection_random: bool = False
subgroup_colname: str = None #TODO
subgroup_colname: str = None
treatment_level: List[int] = field(default_factory=lambda: [0, 1])
trial_include: bool = True
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
@@ -50,11 +50,11 @@ class SEQopts:

def __post_init__(self):
bools = [
'excused', 'followup_class', 'followup_include',
'followup_spline', 'hazard', 'km_curves',
'parallel', 'selection_first_trial', 'selection_random',
'trial_include', 'weight_lag_condition', 'weight_p99',
'weight_preexpansion', 'weighted'
"excused", "followup_class", "followup_include",
"followup_spline", "hazard", "km_curves",
"parallel", "selection_first_trial", "selection_random",
"trial_include", "weight_lag_condition", "weight_p99",
"weight_preexpansion", "weighted"
]
for i in bools:
if not isinstance(getattr(self, i), bool):
@@ -73,18 +73,14 @@ def __post_init__(self):
if not (0.0 <= self.selection_probability <= 1.0):
raise ValueError("selection_probability must be between 0 and 1.")

if self.plot_type not in ['risk', 'survival']:
if self.plot_type not in ["risk", "survival"]:
raise ValueError("plot_type must be either 'risk' or 'survival'.")

if self.bootstrap_CI_method not in ['se', 'percentile']:
if self.bootstrap_CI_method not in ["se", "percentile"]:
raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'")

lists = [

]

for i in ('covariates', 'numerator', 'denominator',
'cense_numerator', 'cense_denominator'):
for i in ("covariates", "numerator", "denominator",
"cense_numerator", "cense_denominator"):
attr = getattr(self, i)
if attr is not None and not isinstance(attr, list):
setattr(self, i, "".join(attr.split()))
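
As a quick reference for the stricter option handling above, here is a minimal, self-contained sketch of the same `__post_init__` pattern. The class and option subset below are illustrative stand-ins based on the visible hunks, not the full `SEQopts` API, and the exact exception types are assumptions.

```python
from dataclasses import dataclass
from typing import Optional

# Minimal sketch of the SEQopts.__post_init__ validation pattern.
# Class name and option subset are illustrative, not the real API.
@dataclass
class OptsSketch:
    weighted: bool = False
    hazard: bool = False
    selection_probability: float = 0.8
    covariates: Optional[str] = None

    def __post_init__(self):
        # Flag options must be real booleans, not merely truthy values.
        for name in ("weighted", "hazard"):
            if not isinstance(getattr(self, name), bool):
                raise ValueError(f"{name} must be a bool")
        if not (0.0 <= self.selection_probability <= 1.0):
            raise ValueError("selection_probability must be between 0 and 1.")
        # Whitespace is stripped from formula-style strings, as in the diff.
        if self.covariates is not None and not isinstance(self.covariates, list):
            self.covariates = "".join(self.covariates.split())

opts = OptsSketch(covariates="age + sex * visit")
print(opts.covariates)  # age+sex*visit
```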
8 changes: 6 additions & 2 deletions pySEQ/SEQuential.py
@@ -12,7 +12,7 @@
from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator
from .expansion import _mapper, _binder, _dynamic, _random_selection
from .weighting import _weight_setup, _fit_LTFU, _fit_numerator, _fit_denominator, _weight_bind, _weight_predict, _weight_stats
from .analysis import _outcome_fit, _calculate_risk, _calculate_survival
from .analysis import _outcome_fit, _pred_risk, _calculate_survival, _subgroup_fit
from .plot import _survival_plot


@@ -164,6 +164,10 @@ def fit(self):

end = time.perf_counter()
self.model_time = _format_time(start, end)

if self.subgroup_colname is not None:
return _subgroup_fit(self)

return _outcome_fit(self,
self.DT,
self.outcome_col,
@@ -177,7 +181,7 @@ def survival(self):

start = time.perf_counter()

risk_data = _calculate_risk(self)
risk_data = _pred_risk(self)
surv_data = _calculate_survival(self, risk_data)
self.km_data = pl.concat([risk_data, surv_data])

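
For context on the `survival()` change, here is a small sketch of how the risk and survival frames are stacked into `km_data`. The toy values and the `1 - risk` conversion are illustrative assumptions; only the `pred`/`estimate` column names and the `pl.concat` step come from the diff.

```python
import polars as pl

# Toy risk frame shaped like the output of _pred_risk after the rename.
risk_data = pl.DataFrame({
    "followup": [0, 1, 2],
    "pred": [0.0, 0.05, 0.09],
    "estimate": ["risk"] * 3,
})

# Illustrative survival frame: assumed here to be 1 - risk with the same schema.
surv_data = risk_data.with_columns([
    (1 - pl.col("pred")).alias("pred"),
    pl.lit("survival").alias("estimate"),
])

# Matching schemas allow a plain vertical concat into km_data.
km_data = pl.concat([risk_data, surv_data])
print(km_data)
```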
3 changes: 2 additions & 1 deletion pySEQ/analysis/__init__.py
@@ -1,2 +1,3 @@
from ._outcome_fit import _outcome_fit
from ._survival_pred import _get_outcome_predictions, _calculate_risk, _calculate_survival
from ._survival_pred import _get_outcome_predictions, _pred_risk, _calculate_survival
from ._subgroup_fit import _subgroup_fit
14 changes: 7 additions & 7 deletions pySEQ/analysis/_outcome_fit.py
@@ -33,14 +33,14 @@ def _outcome_fit(
if self.followup_spline:
spline = f"cr(followup, df=3)"

formula = re.sub(r'(\w+)\s*\*\s*followup\b', rf'\1*{spline}', formula)
formula = re.sub(r'\bfollowup\s*\*\s*(\w+)', rf'{spline}*\1', formula)
formula = re.sub(rf'\bfollowup{re.escape(self.indicator_squared)}\b', '', formula)
formula = re.sub(r'\bfollowup\b', '', formula)
formula = re.sub(r"(\w+)\s*\*\s*followup\b", rf"\1*{spline}", formula)
formula = re.sub(r"\bfollowup\s*\*\s*(\w+)", rf"{spline}*\1", formula)
formula = re.sub(rf"\bfollowup{re.escape(self.indicator_squared)}\b", "", formula)
formula = re.sub(r"\bfollowup\b", "", formula)

formula = re.sub(r'\s+', ' ', formula)
formula = re.sub(r'\+\s*\+', '+', formula)
formula = re.sub(r'^\s*\+\s*|\s*\+\s*$', '', formula).strip()
formula = re.sub(r"\s+", " ", formula)
formula = re.sub(r"\+\s*\+", "+", formula)
formula = re.sub(r"^\s*\+\s*|\s*\+\s*$", "", formula).strip()

if formula:
formula = f"{formula} + I({spline}**2)"
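
To make the regex rewrite above easier to follow, here is a trimmed, runnable sketch covering the non-interaction path: the squared and linear `followup` terms are removed, stray `+` signs are cleaned up, and the squared spline term is appended. The input formula is hypothetical and `indicator_squared` is assumed to be the default `"_sq"`.

```python
import re

spline = "cr(followup, df=3)"
formula = "followup + age + followup_sq + treatment_bas"  # hypothetical input

# Drop the squared and linear followup terms.
formula = re.sub(r"\bfollowup_sq\b", "", formula)
formula = re.sub(r"\bfollowup\b", "", formula)

# Collapse whitespace and dangling '+' signs left behind by the removals.
formula = re.sub(r"\s+", " ", formula)
formula = re.sub(r"\+\s*\+", "+", formula)
formula = re.sub(r"^\s*\+\s*|\s*\+\s*$", "", formula).strip()

# Append the squared spline term, as in the diff.
formula = f"{formula} + I({spline}**2)"
print(formula)  # age + treatment_bas + I(cr(followup, df=3)**2)
```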
20 changes: 20 additions & 0 deletions pySEQ/analysis/_subgroup_fit.py
@@ -0,0 +1,20 @@
import polars as pl
from ._outcome_fit import _outcome_fit

def _subgroup_fit(self):
subgroups = sorted(self.DT[self.subgroup_colname].unique().to_list())
self._unique_subgroups = subgroups

models = []
for val in subgroups:
subDT = self.DT.filter(pl.col(self.subgroup_colname) == val)

model = _outcome_fit(self,
subDT,
self.outcome_col,
self.covariates,
self.weighted,
"weight")
models.append(model)

return models
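
A hedged sketch of what the new `_subgroup_fit` loop does: split the expanded data on the subgroup column and fit one outcome model per level. `fit_one_level` is a placeholder for the real `_outcome_fit` call, and the data frame is toy data.

```python
import polars as pl

# Toy expanded data with a binary subgroup column.
DT = pl.DataFrame({
    "sex": [0, 0, 1, 1],
    "outcome": [0, 1, 0, 1],
    "followup": [1, 2, 1, 2],
})

def fit_one_level(sub_dt: pl.DataFrame) -> str:
    # Placeholder for the real model fit on one subgroup's rows.
    return f"model(n={sub_dt.height})"

# One model per sorted subgroup level, mirroring the loop in _subgroup_fit.
subgroups = sorted(DT["sex"].unique().to_list())
models = [fit_one_level(DT.filter(pl.col("sex") == val)) for val in subgroups]
print(models)  # ['model(n=2)', 'model(n=2)']
```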
157 changes: 72 additions & 85 deletions pySEQ/analysis/_survival_pred.py
@@ -1,39 +1,44 @@
import concurrent.futures
import polars as pl
import numpy as np
from ..helpers import _predict_model

def _get_outcome_predictions(self, newdata):
if self.compevent_colname is not None:
pass
def _get_outcome_predictions(self, TxDT, idx=None):
data = TxDT.to_pandas()
predictions = []
for boot_model in self.outcome_model:
model = boot_model[idx] if idx is not None else boot_model
pred = model.predict(data)
predictions.append(pred)
return predictions

def _pred_risk(self):
has_subgroups = (isinstance(self.outcome_model[0], list) if self.outcome_model else False)

if not has_subgroups:
return _calculate_risk(self, self.DT, idx=None, val=None)

all_risks = []
original_DT = self.DT

if self.parallel and self.ncores > 1:
with concurrent.futures.ProcessPoolExecutor(max_workers=self.ncores) as executor:
preds = list(executor.map(_predict_model, self, self.outcome_model, newdata))
else:
preds = [_predict_model(self, model, newdata) for model in self.outcome_model]
for i, val in enumerate(self._unique_subgroups):
subgroup_DT = original_DT.filter(pl.col(self.subgroup_colname) == val)
risk = _calculate_risk(self, subgroup_DT, i, val)
all_risks.append(risk)

return preds
self.DT = original_DT
return pl.concat(all_risks)

def _calculate_risk(self):
def _calculate_risk(self, data, idx=None, val=None):
a = 1 - self.bootstrap_CI
lci = a / 2
uci = 1 - lci

if self.followup_max is None:
self.followup_max = self.DT.select(pl.col("followup").max()).to_numpy()[0][0]

SDT = (
self.DT
.with_columns([
(pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)).alias("TID")
])
.group_by("TID")
.first()
data
.with_columns([(pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)).alias("TID")])
.group_by("TID").first()
.drop(["followup", f"followup{self.indicator_squared}"])
.with_columns([
pl.lit(list(range(self.followup_max))).alias("followup")
])
.with_columns([pl.lit(list(range(self.followup_max))).alias("followup")])
.explode("followup")
.with_columns([
(pl.col("followup") + 1).alias("followup"),
@@ -42,100 +42,82 @@ def _calculate_risk(self):
).sort([self.id_col, "trial", "followup"])

risks = []
for i in self.treatment_level:
TxDT = SDT.with_columns([
pl.lit(i).alias(f"{self.treatment_col}{self.indicator_baseline}")
])
for treatment_val in self.treatment_level:
TxDT = SDT.with_columns([pl.lit(treatment_val).alias(f"{self.treatment_col}{self.indicator_baseline}")])

if self.method == "dose-response":
pass
print("REMEMBER TO FIX DOSE-REPONSE RISK CALCULATION TODO")
if self.compevent_colname is not None:
pass

preds = _get_outcome_predictions(self, TxDT)
full = [pl.Series("pred_risk", preds[0])]
preds = _get_outcome_predictions(self, TxDT, idx=idx)
pred_series = [pl.Series("pred_risk", preds[0])]

if self.bootstrap_nboot > 0:
for idx, pred in enumerate(preds[1:], start=1):
full.append(pl.Series(f"pred_risk_{idx}", pred))
for boot_idx, pred in enumerate(preds[1:], start=1):
pred_series.append(pl.Series(f"pred_risk_{boot_idx}", pred))

names = [col.name for col in full]
names = [s.name for s in pred_series]

TxDT = TxDT.with_columns(full).with_columns([
(1 - pl.col(col)).cum_prod().over("TID").alias(col) for col in names
]).group_by("followup").agg([
pl.col(col).mean() for col in names
]).sort("followup").with_columns([
(1 - pl.col(col)).alias(col) for col in names
])
TxDT = (
TxDT.with_columns(pred_series)
.with_columns([(1 - pl.col(col)).cum_prod().over("TID").alias(col) for col in names])
.group_by("followup").agg([pl.col(col).mean() for col in names])
.sort("followup")
.with_columns([(1 - pl.col(col)).alias(col) for col in names])
)

boots = [col for col in names if col != "pred_risk"]
boot_cols = [col for col in names if col != "pred_risk"]

if len(boots) > 0:
risk = TxDT.select(["followup"] + boots).unpivot(
index="followup",
on=boots,
variable_name="bootID",
value_name="risk"
).group_by("followup").agg([
pl.col("risk").std().alias("SE"),
pl.col("risk").quantile(lci).alias("LCI"),
pl.col("risk").quantile(uci).alias("UCI")
])
risk = risk.join(TxDT.select(["followup", "pred_risk"]), on="followup")

if boot_cols:
risk = (
TxDT.select(["followup"] + boot_cols)
.unpivot(index="followup", on=boot_cols, variable_name="bootID", value_name="risk")
.group_by("followup").agg([
pl.col("risk").std().alias("SE"),
pl.col("risk").quantile(lci).alias("LCI"),
pl.col("risk").quantile(uci).alias("UCI")
])
.join(TxDT.select(["followup", "pred_risk"]), on="followup")
)

if self.bootstrap_CI_method == "se":
from scipy.stats import norm
z = norm.ppf(1 - a / 2)
risk = risk.with_columns([
(pl.col("pred_risk") - z * pl.col("SE")).alias("LCI"),
(pl.col("pred_risk") + z * pl.col("SE")).alias("UCI")
])
risk = risk.with_columns(pl.lit(i).alias(self.treatment_col)).select([

risk = risk.select([
"followup",
self.treatment_col,
"pred_risk",
"SE",
"LCI",
"UCI"
pl.lit(treatment_val).alias(self.treatment_col),
"pred_risk", "SE", "LCI", "UCI"
])

fup0 = pl.DataFrame({
"followup": [0],
self.treatment_col: [i],
"pred_risk": [0.0],
"SE": [0.0],
"LCI": [0.0],
"UCI": [0.0]
"followup": [0], self.treatment_col: [treatment_val],
"pred_risk": [0.0], "SE": [0.0], "LCI": [0.0], "UCI": [0.0]
}).with_columns([
pl.col("followup").cast(pl.Int64),
pl.col(self.treatment_col).cast(pl.Int32)
])
else:
risk = TxDT.select(["followup", "pred_risk"]).sort("followup").with_columns([
pl.lit(i).alias(self.treatment_col)
]).select([
"followup",
self.treatment_col,
"pred_risk"
])

fup0 = pl.DataFrame({
"followup": [0],
self.treatment_col: [i],
"pred_risk": [0.0]
}).with_columns([
risk = TxDT.select(["followup", pl.lit(treatment_val).alias(self.treatment_col), "pred_risk"])
fup0 = pl.DataFrame({"followup": [0],
self.treatment_col: [treatment_val], "pred_risk": [0.0]}
).with_columns([
pl.col("followup").cast(pl.Int64),
pl.col(self.treatment_col).cast(pl.Int32)
])

risk = pl.concat([fup0, risk])
risks.append(risk)
risks.append(pl.concat([fup0, risk]))

out = pl.concat(risks).with_columns(pl.lit("risk").alias("estimate"))
if val is not None:
out = out.with_columns(pl.lit(val).alias(self.subgroup_colname))

out = pl.concat(risks) \
.with_columns(pl.lit("risk").alias("estimate")) \
.rename({"pred_risk": "pred"})
return out
return out.rename({"pred_risk": "pred"})


def _calculate_survival(self, risk_data):
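
The core transform in `_calculate_risk` can be hard to read inside the diff, so here is a stripped-down sketch of the pattern: per-period predicted event probabilities become within-TID cumulative survival via a cumulative product, are averaged over person-trials at each follow-up time, and are then converted back to cumulative risk. The data are toy values and the bootstrap columns are omitted.

```python
import polars as pl

# Two person-trials ("TID") with constant per-period event probabilities.
df = pl.DataFrame({
    "TID": ["a", "a", "a", "b", "b", "b"],
    "followup": [1, 2, 3, 1, 2, 3],
    "pred_risk": [0.10, 0.10, 0.10, 0.05, 0.05, 0.05],
})

risk = (
    # Cumulative survival within each TID: product of (1 - hazard).
    df.with_columns((1 - pl.col("pred_risk")).cum_prod().over("TID").alias("pred_risk"))
    # Average survival across TIDs at each follow-up time.
    .group_by("followup").agg(pl.col("pred_risk").mean())
    .sort("followup")
    # Convert mean survival back to cumulative risk.
    .with_columns((1 - pl.col("pred_risk")).alias("pred_risk"))
)
print(risk)
```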
2 changes: 1 addition & 1 deletion pySEQ/data/__init__.py
@@ -12,4 +12,4 @@ def load_data(name: str = "SEQdata") -> pl.DataFrame:
data_path = loc.joinpath("SEQdata_LTFU.csv")
return pl.read_csv(data_path)
else:
raise ValueError(f"Dataset '{name}' not available. Options: ['SEQdata']")
raise ValueError(f"Dataset '{name}' not available. Options: ['SEQdata', 'SEQdata_multitreatment', 'SEQdata_LTFU']")
10 changes: 5 additions & 5 deletions pySEQ/expansion/_binder.py
@@ -39,9 +39,9 @@ def _binder(self, kept_cols):
DT = _mapper(self.data, self.id_col, self.time_col, self.followup_min, self.followup_max)
DT = DT.join(
self.data.select([self.id_col, self.time_col] + kept),
left_on=[self.id_col, 'period'],
left_on=[self.id_col, "period"],
right_on=[self.id_col, self.time_col],
how='left'
how="left"
)
DT = DT.sort([self.id_col, "trial", "followup"]) \
.with_columns([
@@ -52,16 +52,16 @@
if squared:
squares = []
for sq in squared:
col = sq.replace(self.indicator_squared, '')
col = sq.replace(self.indicator_squared, "")
squares.append((pl.col(col) ** 2).alias(f"{col}{self.indicator_squared}"))
DT = DT.with_columns(squares)

baseline_cols = {bas.replace(self.indicator_baseline, '') for bas in baseline}
baseline_cols = {bas.replace(self.indicator_baseline, "") for bas in baseline}
needed = {self.eligible_col, self.treatment_col}
baseline_cols.update({c for c in needed})

bas = [
pl.col(c).first().over([self.id_col, 'trial']).alias(f"{c}{self.indicator_baseline}")
pl.col(c).first().over([self.id_col, "trial"]).alias(f"{c}{self.indicator_baseline}")
for c in baseline_cols
]

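
A short sketch of the baseline-column step in `_binder`: for each baseline covariate, the value on the first row of every (id, trial) group is broadcast across that group and stored under a `_bas`-suffixed name (the `indicator_baseline` default). Column names and values here are toy examples.

```python
import polars as pl

# Toy expanded data: one subject contributing two trials.
DT = pl.DataFrame({
    "id": [1, 1, 1, 1],
    "trial": [0, 0, 1, 1],
    "followup": [1, 2, 1, 2],
    "treatment": [0, 1, 1, 1],
})

# Carry the first-row treatment value across each (id, trial) group.
DT = DT.with_columns(
    pl.col("treatment").first().over(["id", "trial"]).alias("treatment_bas")
)
print(DT)
# treatment_bas is 0 for trial 0 (its first-row value) and 1 for trial 1.
```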