Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions pySEQ/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import numpy as np

from .SEQopts import SEQopts
from .helpers import _col_string, bootstrap_loop, _format_time
from .helpers import _col_string, bootstrap_loop, _format_time, _prepare_data
from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator
from .expansion import _mapper, _binder, _dynamic, _randomSelection
from .weighting import _weight_setup, _fit_LTFU, _fit_numerator, _fit_denominator, _weight_bind
from .weighting import _weight_setup, _fit_LTFU, _fit_numerator, _fit_denominator, _weight_bind, _weight_predict, _weight_stats
from .analysis import _outcome_fit, _calculate_risk, _calculate_survival
from .plot import _survival_plot

Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(
if self.denominator is None:
self.denominator = _denominator(self)

if self.cense is not None:
if self.cense_colname is not None:
if self.cense_numerator is None:
self.cense_numerator = _cense_numerator()

Expand All @@ -69,27 +69,45 @@ def expand(self):
self.subgroup_colname,
self.weight_eligible_colnames]

self.data = self.data.with_columns(
self.data = self.data.with_columns([
pl.when(pl.col(self.treatment_col).is_in(self.treatment_level))
.then(self.eligible_col)
.otherwise(0)
.alias(self.eligible_col)
)
.alias(self.eligible_col),
pl.col(self.treatment_col)
.shift(1)
.over([self.id_col])
.alias("tx_lag"),
pl.lit(False).alias("switch")
]).with_columns([
pl.when(pl.col(self.time_col) == 0)
.then(pl.lit(False))
.otherwise(
(pl.col("tx_lag").is_not_null()) &
(pl.col("tx_lag") != pl.col(self.treatment_col))
).cast(pl.Int8)
.alias("switch")
])

self.DT = _binder(_mapper(self.data, self.id_col, self.time_col), self.data,
self.id_col, self.time_col, self.eligible_col, self.outcome_col,
_col_string([self.covariates,
self.numerator, self.denominator,
self.cense_numerator, self.cense_denominator]).union(kept),
self.indicator_baseline, self.indicator_squared)
self.indicator_baseline, self.indicator_squared) \
.with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))
self.data = self.data.with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))

#self.data = _prepare_data(self, self.data)
#self.DT = _prepare_data(self, self.DT)

if self.method != "ITT":
self.DT = _dynamic(self.DT)
_dynamic(self)
if self.selection_random:
self.DT = _randomSelection(self.DT)
_randomSelection(self.DT)
end = time.perf_counter()
self.expansion_time = _format_time(start, end)

def bootstrap(self, **kwargs):
allowed = {"bootstrap_nboot", "bootstrap_sample",
"bootstrap_CI", "bootstrap_method"}
Expand Down Expand Up @@ -117,17 +135,26 @@ def fit(self):
WDT = _weight_setup(self)
if not self.weight_preexpansion and not self.excused:
WDT = WDT.filter(pl.col("followup") > 0)

WDT = WDT.to_pandas()
for col in self.fixed_cols:
if col in WDT.columns:
WDT[col] = WDT[col].astype("category")

_fit_LTFU(self, WDT)
_fit_numerator(self, WDT)
_fit_denominator(self, WDT)
self.DT = _weight_bind(self, WDT)

WDT = pl.from_pandas(WDT)
WDT = _weight_predict(self, WDT)
_weight_bind(self, WDT)

self.weight_stats = _weight_stats(self)

end = time.perf_counter()
self.model_time = _format_time(start, end)
return _outcome_fit(self.DT,
return _outcome_fit(self,
self.DT,
self.outcome_col,
self.covariates,
self.weighted,
Expand Down
2 changes: 1 addition & 1 deletion pySEQ/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ._outcome_fit import _outcome_fit
from ._survival_pred import _get_predictions, _calculate_risk, _calculate_survival
from ._survival_pred import _get_outcome_predictions, _calculate_risk, _calculate_survival
12 changes: 11 additions & 1 deletion pySEQ/analysis/_outcome_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,24 @@
import statsmodels.formula.api as smf
import polars as pl

def _outcome_fit(
def _outcome_fit(self,
df: pl.DataFrame,
outcome: str,
formula: str,
weighted: bool = False,
weight_col: str = "weight",
):
if weighted:
df = self.DT.with_columns(
pl.col(weight_col).clip(
lower_bound=self.weight_min,
upper_bound=self.weight_max
)
)
df_pd = df.to_pandas()
for col in self.fixed_cols:
if col in df_pd.columns:
df_pd[col] = df_pd[col].astype("category")
formula = f"{outcome}~{formula}"

if weighted:
Expand Down
5 changes: 2 additions & 3 deletions pySEQ/analysis/_survival_pred.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import concurrent.futures
import polars as pl
import numpy as np
from ._outcome_fit import _outcome_fit
from ..helpers import _predict_model

def _get_outcome_predictions(self, newdata):
Expand All @@ -10,9 +9,9 @@ def _get_outcome_predictions(self, newdata):

if self.parallel and self.ncores > 1:
with concurrent.futures.ProcessPoolExecutor(max_workers=self.ncores) as executor:
preds = list(executor.map(_predict_model, self.outcome_model, newdata))
preds = list(executor.map(_predict_model, self, self.outcome_model, newdata))
else:
preds = [_predict_model(model, newdata) for model in self.outcome_model]
preds = [_predict_model(self, model, newdata) for model in self.outcome_model]

return preds

Expand Down
4 changes: 3 additions & 1 deletion pySEQ/expansion/_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ def _binder(DT, data, id_col, time_col, eligible_col, outcome_col, kept_cols,
f"followup{squared_indicator}",
"tx_lag",
"trial",
f"trial{squared_indicator}"}
f"trial{squared_indicator}",
time_col,
f"{time_col}{squared_indicator}"}

cols = kept_cols.union({eligible_col, outcome_col})
cols = {col for col in cols if col is not None}
Expand Down
79 changes: 54 additions & 25 deletions pySEQ/expansion/_dynamic.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,61 @@
import polars as pl
def _dynamic(self):
    """
    Handles special cases for the data from the __mapper -> __binder pipeline.

    Reads ``self.DT`` and, for the two methods handled here, writes the
    transformed frame back to ``self.DT``. Nothing is returned. Any other
    ``self.method`` leaves ``self.DT`` untouched.

    ``"dose-response"``
        Adds ``dose`` (cumulative count of the treatment column within each
        id/trial group) and ``dose{squared_indicator}`` (its square).

    ``"censoring"``
        Adds ``switch`` (Int8) and ``isExcused`` columns and drops every row
        that comes after the first treatment switch within an id/trial group.
        A row counts as a switch when ``followup`` is not 0 and the treatment
        differs from the previous row's treatment in the same group. When
        ``self.excused`` is set, rows whose excused column equals 1 for the
        treatment level currently held are flagged ``isExcused`` and are not
        counted as switches.
    """
    group = [self.id_col, "trial"]

    if self.method == "dose-response":
        self.DT = self.DT.with_columns(
            pl.col(self.treatment_col)
            .cum_count()
            .over(group)
            .alias("dose")
        ).with_columns([
            (pl.col("dose") ** 2)
            .alias(f"dose{self.squared_indicator}")
        ])

    elif self.method == "censoring":
        # Sort first so the lag below is the previous follow-up row of the
        # same id/trial group.
        DT = self.DT.sort([self.id_col, "trial", "followup"]).with_columns(
            pl.col(self.treatment_col).shift(1).over(group).alias("tx_lag")
        )

        switch = (
            pl.when(pl.col("followup") == 0)
            .then(pl.lit(False))
            .otherwise(
                (pl.col("tx_lag").is_not_null()) &
                (pl.col("tx_lag") != pl.col(self.treatment_col))
            )
        )

        # Default to "never excused" so the expression is always defined,
        # including when self.excused is set but no excused column is given.
        excused = pl.lit(False)
        if self.excused:
            conditions = [
                (pl.col(self.excused_colnames[i]) == 1) &
                (pl.col(self.treatment_col) == level)
                for i, level in enumerate(self.treatment_level)
                if self.excused_colnames[i] is not None
            ]

            if conditions:
                excused = pl.any_horizontal(conditions)
                switch = (
                    pl.when(excused)
                    .then(pl.lit(False))
                    .otherwise(switch)
                )

        DT = DT.with_columns([
            switch.alias("switch"),
            excused.alias("isExcused")
        ]).filter(
            # Keep rows up to and including the first switch. The shift has
            # to happen inside the window: shifting after .over() would run
            # across the whole frame and leak the previous group's final
            # state into the first row of the next group.
            pl.col("switch").cum_max()
            .shift(1, fill_value=False)
            .over(group)
            == 0
        ).with_columns(
            pl.col("switch").cast(pl.Int8).alias("switch")
        )

        self.DT = DT.drop(["tx_lag"])
3 changes: 2 additions & 1 deletion pySEQ/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._col_string import _col_string
from ._bootstrap import bootstrap_loop
from ._format_time import _format_time
from ._predict_model import _predict_model
from ._predict_model import _predict_model
from ._prepare_data import _prepare_data
6 changes: 5 additions & 1 deletion pySEQ/helpers/_predict_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import numpy as np

def _predict_model(model, newdata):
def _predict_model(self, model, newdata):
    """
    Predict with ``model`` on ``newdata`` and return a flat NumPy array.

    ``newdata`` is converted with its ``to_pandas()`` method. Every name in
    ``self.fixed_cols`` that is a column of the converted frame is cast to
    the pandas ``"category"`` dtype before ``model.predict`` is called.
    """
    frame = newdata.to_pandas()

    # Names listed in fixed_cols but absent from the frame are ignored.
    categorical = [name for name in self.fixed_cols if name in frame.columns]
    for name in categorical:
        frame[name] = frame[name].astype("category")

    predictions = model.predict(frame)
    return np.array(predictions).flatten()
14 changes: 14 additions & 0 deletions pySEQ/helpers/_prepare_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import polars as pl

def _prepare_data(self, DT):
    """
    Return ``DT`` with its columns cast to the dtypes used downstream.

    * columns named in ``self.fixed_cols`` -> ``pl.Categorical``
    * ``self.eligible_col``, ``self.outcome_col`` and ``self.cense_colname``
      -> ``pl.Int8``
    * ``self.id_col`` -> ``pl.Utf8``

    Column names that are ``None`` or not present in ``DT`` are skipped, so
    the same helper can be applied to frames that carry only a subset of
    the configured columns. ``self.id_col`` is always cast and must exist.
    """
    present = set(DT.columns)

    # TODO: self.excused_colnames and self.weight_eligible_colnames are
    # not covered yet.
    binaries = [self.eligible_col, self.outcome_col, self.cense_colname]
    binary_colnames = [
        col for col in binaries if col is not None and col in present
    ]
    fixed_colnames = [col for col in self.fixed_cols if col in present]

    DT = DT.with_columns(
        [
            *[pl.col(col).cast(pl.Categorical) for col in fixed_colnames],
            *[pl.col(col).cast(pl.Int8) for col in binary_colnames],
            pl.col(self.id_col).cast(pl.Utf8),
        ]
    )
    return DT
3 changes: 2 additions & 1 deletion pySEQ/weighting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._weight_fit import _fit_LTFU, _fit_numerator, _fit_denominator
from ._weight_pred import _weight_predict
from ._weight_bind import _weight_bind
from ._weight_data import _weight_setup
from ._weight_data import _weight_setup
from ._weight_stats import _weight_stats
46 changes: 44 additions & 2 deletions pySEQ/weighting/_weight_bind.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,44 @@
def _weight_bind():
pass
import polars as pl

def _weight_bind(self, WDT):
    """
    Join the weight frame ``WDT`` onto ``self.DT`` and derive the weights.

    The joined frame is stored back on ``self.DT`` with two new columns:
    ``wt``, the per-row weight, and ``weight``, the cumulative product of
    ``wt`` within each id/trial group ordered by ``followup``. Nothing is
    returned.

    With ``self.weight_preexpansion`` the time column of ``WDT`` is renamed
    to ``"period"`` and the join is an inner join on id and period.
    Otherwise it is a left join on id, trial and followup.

    ``wt`` is ``numerator / denominator`` except on rows where it is forced
    to 1.0: the starting rows picked by the branch below, rows at or after
    an excused flag (pre-expansion with excused only), rows with a null
    outcome, rows with a denominator below 1e-15, and rows with a null
    numerator.
    """
    window = [self.id_col, "trial"]

    if self.weight_preexpansion:
        how, keys = "inner", [self.id_col, "period"]
        WDT = WDT.rename({self.time_col: "period"})
    else:
        how, keys = "left", [self.id_col, "trial", "followup"]

    joined = self.DT.join(WDT, on=keys, how=how)

    if self.weight_preexpansion and self.excused:
        at_start = (pl.col("trial") == 0) & (pl.col("period") == 0)
        was_excused = (
            pl.col("isExcused").fill_null(False).cum_sum().over(window) > 0
        )
    else:
        at_start = pl.col("trial") == pl.col("trial").min().over(self.id_col)
        was_excused = pl.lit(False)

    # Rows on which the weight is pinned to 1 instead of the ratio.
    force_unit = (
        at_start |
        was_excused |
        pl.col(self.outcome_col).is_null() |
        (pl.col("denominator") < 1e-15) |
        pl.col("numerator").is_null()
    )

    ratio = pl.col("numerator") / pl.col("denominator")
    row_weight = (
        pl.when(force_unit)
        .then(pl.lit(1.0))
        .otherwise(ratio)
        .alias("wt")
    )
    # Any wt still null after the override is treated as 1.0.
    running_weight = (
        pl.col("wt")
        .fill_null(1.0)
        .cum_prod()
        .over(window)
        .alias("weight")
    )

    with_wt = joined.with_columns(row_weight)
    ordered = with_wt.sort([self.id_col, "trial", "followup"])
    self.DT = ordered.with_columns(running_weight)
10 changes: 6 additions & 4 deletions pySEQ/weighting/_weight_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import polars as pl

def _weight_setup(self):
DT = self.DT
data = self.data
if not self.weight_preexpansion:
baseline_lag = self.data.select([self.treatment_col, self.id_col, self.time_col]) \
baseline_lag = data.select([self.treatment_col, self.id_col, self.time_col]) \
.sort([self.id_col, self.time_col]) \
.with_columns(pl.col(self.treatment_col)
.over(self.id_col)
Expand All @@ -11,14 +13,14 @@ def _weight_setup(self):
.drop(self.treatment_col) \
.rename({self.time_col : "period"})

fup0 = self.DT.filter(pl.col("followup" == 0)) \
fup0 = DT.filter(pl.col("followup") == 0) \
.join(
baseline_lag,
on = [self.id_col, "period"],
how = "inner"
)

fup = self.DT.sort([self.id_col, "trial", "followup"]) \
fup = DT.sort([self.id_col, "trial", "followup"]) \
.with_columns(pl.col(self.treatment_col)
.over([self.id_col, "trial"])
.shift(fill_value=self.treatment_level[0])
Expand All @@ -27,7 +29,7 @@ def _weight_setup(self):

WDT = pl.concat([fup0, fup])
else:
WDT = self.data.with_columns(pl.col(self.treatment_col)
WDT = data.with_columns(pl.col(self.treatment_col)
.over(self.id_col)
.shift(fill_value=self.treatment_level[0])
.alias("tx_lag"),
Expand Down
Loading