6 changes: 3 additions & 3 deletions pySEQ/SEQopts.py
@@ -17,7 +17,7 @@ class SEQopts:
data_return: bool = False
denominator: Optional[List[str]] = None
excused: bool = False
excused_colnames: Optional[List[str]] = None
excused_colnames: List[str] = field(default_factory=lambda: [])
followup_class: bool = False
followup_include: bool = True
followup_max: int = None
@@ -32,7 +32,7 @@ class SEQopts:
numerator: Optional[List[str]] = None
parallel: bool = False
plot_colors: List[str] = field(default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"])
plot_labels: Optional[List[str]] = None
plot_labels: List[str] = field(default_factory=lambda: [])
plot_subtitle: str = None
plot_title: str = None
plot_type: Literal["risk", "survival", "inc"] = "risk"
@@ -45,7 +45,7 @@ class SEQopts:
survival_min: int = 0
treatment_level: List[int] = field(default_factory=lambda: [0, 1])
trial_include: bool = True
weight_eligible_colnames: Optional[List[str]] = None
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
weight_min: float = 0.0
weight_max: float = None
weight_lag_condition: bool = False
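The pattern above swaps `Optional[List[str]] = None` defaults for `field(default_factory=...)`, so downstream code can call `len()` or iterate without a `None` guard. A minimal sketch of that dataclass pattern, using a hypothetical `Opts` class rather than the real `SEQopts`:

```python
# Minimal sketch of the default_factory pattern from this file (hypothetical
# Opts class, not the real SEQopts). A bare mutable default like
# `excused_colnames: List[str] = []` raises ValueError at class definition
# time, so dataclasses require a factory that builds a fresh list per instance.
from dataclasses import dataclass, field
from typing import List

@dataclass
class Opts:
    excused_colnames: List[str] = field(default_factory=list)

a, b = Opts(), Opts()
a.excused_colnames.append("excused_1")
print(b.excused_colnames)            # [] -- instances do not share state
print(len(a.excused_colnames) > 0)   # True -- no None check needed
```

`field(default_factory=lambda: [])` in the diff and `field(default_factory=list)` here are equivalent.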
14 changes: 7 additions & 7 deletions pySEQ/SEQuential.py
@@ -68,9 +68,9 @@ def __init__(
def expand(self):
start = time.perf_counter()
kept = [self.cense_colname, self.cense_eligible_colname,
self.compevent_colname,
self.subgroup_colname,
self.weight_eligible_colnames]
self.compevent_colname,
*self.weight_eligible_colnames,
*self.excused_colnames]

self.data = self.data.with_columns([
pl.when(pl.col(self.treatment_col).is_in(self.treatment_level))
@@ -105,7 +105,8 @@ def expand(self):
if self.method != "ITT":
_dynamic(self)
if self.selection_random:
_randomSelection(self.DT)
_randomSelection(self)

end = time.perf_counter()
self.expansion_time = _format_time(start, end)

@@ -118,13 +119,13 @@ def bootstrap(self, **kwargs):
else:
raise ValueError(f"Unknown argument: {key}")

rng = np.random.RandomState(self.seed) if self.seed is not None else np.random
self._rng = np.random.RandomState(self.seed) if self.seed is not None else np.random
UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list()
NIDs = len(UIDs)

self._boot_samples = []
for _ in range(self.bootstrap_nboot):
sampled_IDs = rng.choice(UIDs, size=int(self.bootstrap_sample * NIDs), replace=True)
sampled_IDs = self._rng.choice(UIDs, size=int(self.bootstrap_sample * NIDs), replace=True)
id_counts = Counter(sampled_IDs)
self._boot_samples.append(id_counts)
return self
@@ -153,7 +154,6 @@ def fit(self):
WDT = _weight_predict(self, WDT)
_weight_bind(self, WDT)
self.weight_stats = _weight_stats(self)
self.DT.write_csv("test.csv")

end = time.perf_counter()
self.model_time = _format_time(start, end)
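The bootstrap change keeps the seeded generator on `self._rng`, so later steps (such as `_randomSelection(self)`) can reuse the same random state. A sketch of the seeded resample-by-ID pattern on toy data; the names `ids`, `nboot`, and `sample_frac` are illustrative, not the package's API:

```python
# Sketch of seeded bootstrap-by-ID resampling as in the diff, on toy data.
from collections import Counter
import numpy as np

seed, nboot, sample_frac = 42, 3, 1.0
ids = ["A", "B", "C", "D"]

rng = np.random.RandomState(seed) if seed is not None else np.random
boot_samples = []
for _ in range(nboot):
    sampled = rng.choice(ids, size=int(sample_frac * len(ids)), replace=True)
    # Counter records how many times each subject appears in this replicate,
    # which is enough to reweight rows later without copying the data.
    boot_samples.append(Counter(sampled))

print(boot_samples[0])
```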
4 changes: 2 additions & 2 deletions pySEQ/error/_param_checker.py
@@ -7,11 +7,11 @@ def _param_checker(self):
if self.followup_max is None:
self.followup_max = self.data.select(self.time_col).to_series().max()

if self.excused_colnames is None and self.excused:
if len(self.excused_colnames) == 0 and self.excused:
self.excused = False
raise Warning("Excused column names not provided but excused is set to True. Automatically set excused to False")

if self.excused_colnames is not None and not self.excused:
if len(self.excused_colnames) > 0 and not self.excused:
self.excused = True
raise Warning("Excused column names provided but excused is set to False. Automatically set excused to True")

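The checks now key off `len(...) == 0`, matching the new empty-list defaults. Note that the unchanged `raise Warning(...)` lines abort execution, so the "automatically set" assignment never reaches the caller; a warn-and-continue variant would use `warnings.warn`. A sketch under that assumption (hypothetical helper, not the module's code):

```python
# Sketch of a warn-and-continue variant of the consistency check.
# Assumption: the intent is to coerce `excused` and keep running; the code
# shown in the diff raises Warning, which stops execution instead.
import warnings

def check_excused(excused: bool, excused_colnames: list) -> bool:
    if len(excused_colnames) == 0 and excused:
        warnings.warn("Excused column names not provided but excused=True; "
                      "setting excused=False.")
        return False
    if len(excused_colnames) > 0 and not excused:
        warnings.warn("Excused column names provided but excused=False; "
                      "setting excused=True.")
        return True
    return excused

print(check_excused(True, []))        # False, with a warning
print(check_excused(False, ["ex1"]))  # True, with a warning
```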
32 changes: 22 additions & 10 deletions pySEQ/expansion/_dynamic.py
@@ -1,4 +1,5 @@
import polars as pl

def _dynamic(self):
"""
Handles special cases for the data from the __mapper -> __binder pipeline
@@ -17,7 +18,10 @@ def _dynamic(self):

elif self.method == "censoring":
DT = self.DT.sort([self.id_col, "trial", "followup"]).with_columns(
pl.col(self.treatment_col).shift(1).over([self.id_col, "trial"]).alias("tx_lag")
pl.col(self.treatment_col)
.shift(1)
.over([self.id_col, "trial"])
.alias("tx_lag")
)

switch = (
@@ -28,7 +32,7 @@ def _dynamic(self):
(pl.col("tx_lag") != pl.col(self.treatment_col))
)
)

is_excused = pl.lit(False)
if self.excused:
conditions = []
for i in range(len(self.treatment_level)):
@@ -41,21 +45,29 @@ def _dynamic(self):

if conditions:
excused = pl.any_horizontal(conditions)
is_excused = switch & excused
switch = (
pl.when(pl.any_horizontal(conditions))
pl.when(excused)
.then(pl.lit(False))
.otherwise(switch)
)

DT = DT.with_columns([
switch.alias("switch"),
excused.alias("isExcused") if self.excused else pl.lit(False).alias("isExcused")
]).filter(
pl.col("switch").cum_max().over([self.id_col, "trial"])
.shift(1, fill_value=False)
== 0
is_excused.alias("isExcused")
]).sort([self.id_col, "trial", "followup"]) \
.filter(
(
pl.col("switch")
.cum_max()
.shift(1, fill_value=False)
)
.over([self.id_col, "trial"])
== 0
).with_columns(
pl.col("switch").cast(pl.Int8).alias("switch")
pl.col("switch")
.cast(pl.Int8)
.alias("switch")
)

self.DT = DT.drop(["tx_lag"])
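The censoring branch lags treatment within each person-trial, flags a switch where the current value differs from the lag (unless excused), and then censors at the first switch by keeping rows whose shifted running maximum of `switch` is still 0. A toy Polars illustration of that keep-until-first-switch filter; column names are illustrative:

```python
# Toy illustration of the censor-at-first-switch filter used in the diff.
import polars as pl

df = pl.DataFrame({
    "id":       ["p1"] * 5,
    "trial":    [0] * 5,
    "followup": [0, 1, 2, 3, 4],
    "tx":       [1, 1, 0, 0, 1],   # switches at followup 2
})

df = (
    df.sort(["id", "trial", "followup"])
      .with_columns(
          pl.col("tx").shift(1).over(["id", "trial"]).alias("tx_lag")
      )
      .with_columns(
          # switch: treatment differs from its lag (first row has a null lag -> 0)
          (pl.col("tx_lag").is_not_null() & (pl.col("tx_lag") != pl.col("tx")))
          .cast(pl.Int8)
          .alias("switch")
      )
      .filter(
          # keep rows up to and including the first switch: the shifted running
          # maximum of `switch` is still 0 on every earlier row of the trial
          pl.col("switch").cum_max().shift(1, fill_value=0).over(["id", "trial"]) == 0
      )
)

print(df)   # followups 0, 1, 2 survive; 3 and 4 are censored
```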
30 changes: 28 additions & 2 deletions pySEQ/expansion/_selection.py
@@ -1,5 +1,31 @@
def _randomSelection(DT, parameters):
import polars as pl
import numpy as np

def _randomSelection(self):
"""
Handles the case where random selection is applied for data from
the __mapper -> __binder -> optionally __dynamic pipeline
"""
"""
UIDs = self.DT.select([
self.id_col,
"trial",
f"{self.treatment_col}{self.indicator_baseline}"]) \
.with_columns(
(pl.col(self.id_col) + "_" + pl.col("trial")).alias("trialID")) \
.filter(
pl.col(f"{self.treatment_col}{self.indicator_baseline}") == 0) \
.unique("trialID").to_series().to_list()

NIDs = len(UIDs)
sample = self._rng.choice(
UIDs,
size=int(self.selection_probability * NIDs),
replace=False
)

self.DT = self.DT.with_columns(
(pl.col(self.id_col) + "_" + pl.col("trial")).alias("trialID")
).filter(
pl.col("trialID").is_in(sample)
).drop("trialID")
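The rewritten `_randomSelection` builds a person-trial key, samples a fraction of eligible trials without replacement from `self._rng`, and keeps only the sampled trials. A toy sketch of that select-a-fraction-of-trials idea; the names and the explicit cast of `trial` to `Utf8` are illustrative (the diff concatenates the columns directly, which assumes they are already strings):

```python
# Toy sketch of random trial selection: build a person-trial key, sample a
# fraction of keys without replacement, keep only the matching rows.
import numpy as np
import polars as pl

rng = np.random.RandomState(7)
df = pl.DataFrame({
    "id":    ["p1", "p1", "p2", "p2"],
    "trial": [0, 1, 0, 1],
    "y":     [0.1, 0.2, 0.3, 0.4],
})

df = df.with_columns(
    (pl.col("id") + "_" + pl.col("trial").cast(pl.Utf8)).alias("trialID")
)
keys = df.get_column("trialID").unique().to_list()
keep = rng.choice(keys, size=int(0.5 * len(keys)), replace=False)

selected = df.filter(pl.col("trialID").is_in(list(keep))).drop("trialID")
print(selected)
```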

2 changes: 1 addition & 1 deletion pySEQ/helpers/_predict_model.py
@@ -5,4 +5,4 @@ def _predict_model(self, model, newdata):
for col in self.fixed_cols:
if col in newdata.columns:
newdata[col] = newdata[col].astype("category")
return np.array(model.predict(newdata)).flatten()
return np.array(model.predict(newdata))
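Dropping `.flatten()` preserves the model's output shape: a multinomial fit such as `smf.mnlogit` predicts an (n_rows, n_classes) probability matrix, and flattening it interleaves the class columns so per-class indexing downstream breaks. A toy sketch (illustrative data, not the package's fitted models):

```python
# Sketch showing why the (n_rows, n_classes) prediction shape matters.
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.RandomState(0)
df = pd.DataFrame({"x": rng.normal(size=200)})
df["tx"] = (rng.uniform(size=200) < 1 / (1 + np.exp(-df["x"]))).astype(int)

model = smf.mnlogit("tx ~ x", df).fit(disp=False)

p = np.array(model.predict(df))   # shape (200, 2): one column per class
print(p.shape)
print(p[:, 1][:5])                # P(tx == 1) for the first five rows
# p.flatten() would interleave the class columns into one 1-D array,
# so the p[:, i] indexing used in the weighting step would break.
```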
58 changes: 36 additions & 22 deletions pySEQ/weighting/_weight_bind.py
@@ -8,37 +8,51 @@ def _weight_bind(self, WDT):
else:
join = "left"
on = [self.id_col, "trial", "followup"]

WDT = self.DT.join(WDT,
on=on,
how=join)


WDT = self.DT.join(WDT, on=on, how=join)

if self.weight_preexpansion and self.excused:
trial = (pl.col("trial") == 0) & (pl.col("period") == 0)
excused = pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) > 0
override = (
trial |
excused |
pl.col(self.outcome_col).is_null() |
(pl.col("denominator") < 1e-7)
)
elif not self.weight_preexpansion and self.excused:
trial = pl.col("followup") == 0
excused = pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) > 0
override = (
trial |
excused |
pl.col(self.outcome_col).is_null() |
(pl.col("denominator") < 1e-7) |
(pl.col("numerator") < 1e-7)
)
else:
trial = (pl.col("trial") == pl.col("trial").min().over(self.id_col)) & (pl.col("followup") == 0)
excused = pl.lit(False)

override = (
trial |
excused |
pl.col(self.outcome_col).is_null() |
(pl.col("denominator") < 1e-15) |
pl.col("numerator").is_null()
)

override = (
trial |
excused |
pl.col(self.outcome_col).is_null() |
(pl.col("denominator") < 1e-15) |
pl.col("numerator").is_null()
)

self.DT = WDT.with_columns(
pl.when(override)
.then(pl.lit(1.0))
.otherwise(pl.col("numerator") / pl.col("denominator"))
.alias("wt")
.then(pl.lit(1.0))
.otherwise(pl.col("numerator") / pl.col("denominator"))
.alias("wt")
).sort(
[self.id_col, "trial", "followup"]
).with_columns(
pl.col("wt")
.fill_null(1.0)
.cum_prod()
.over([self.id_col, "trial"])
.alias("weight")
)
.fill_null(1.0)
.cum_prod()
.over([self.id_col, "trial"])
.alias("weight")
)
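After the join, the row-level weight is `numerator / denominator`, overridden to 1.0 for baseline, excused, or numerically degenerate rows, and the final weight is the cumulative product of those row weights within each person-trial. A compact toy sketch of that last step; the `override` column stands in for the combined conditions above:

```python
# Toy sketch of the weight construction in _weight_bind: per-row ratio with
# an override to 1.0, then a cumulative product within each person-trial.
import polars as pl

df = pl.DataFrame({
    "id":          ["p1"] * 3,
    "trial":       [0] * 3,
    "followup":    [0, 1, 2],
    "numerator":   [None, 0.8, 0.9],
    "denominator": [None, 0.5, 0.6],
    "override":    [True, False, False],   # e.g. the baseline row
})

df = df.with_columns(
    pl.when(pl.col("override"))
      .then(pl.lit(1.0))
      .otherwise(pl.col("numerator") / pl.col("denominator"))
      .alias("wt")
).sort(["id", "trial", "followup"]).with_columns(
    pl.col("wt").fill_null(1.0).cum_prod().over(["id", "trial"]).alias("weight")
)

print(df.select(["followup", "wt", "weight"]))
# weight at followup k is the product of wt over followups 0..k
```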

20 changes: 16 additions & 4 deletions pySEQ/weighting/_weight_fit.py
@@ -26,10 +26,17 @@ def _fit_numerator(self, WDT):
return
predictor = "switch" if self.excused and not self.weight_preexpansion else self.treatment_col
formula = f"{predictor}~{self.numerator}"
tx_bas = f"{self.treatment_col}{self.indicator_baseline}" if not self.weight_preexpansion else "tx_lag"
tx_bas = f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag"
fits = []
for i in self.treatment_level:
DT_subset = WDT[WDT[tx_bas] == i]
if self.excused and self.excused_colnames[i] is not None:
DT_subset = WDT[WDT[self.excused_colnames[i]] == 0]
else:
DT_subset = WDT

if self.weight_lag_condition:
DT_subset = DT_subset[DT_subset[tx_bas] == i]

model = smf.mnlogit(
formula,
DT_subset
@@ -44,10 +51,15 @@ def _fit_denominator(self, WDT):
return
predictor = "switch" if self.excused and not self.weight_preexpansion else self.treatment_col
formula = f"{predictor}~{self.denominator}"
tx_bas = f"{self.treatment_col}{self.indicator_baseline}" if not self.weight_preexpansion else "tx_lag"
tx_bas = "tx_lag"
fits = []
for i in self.treatment_level:
DT_subset = WDT[WDT[tx_bas] == i]
if self.excused and self.excused_colnames[i] is not None:
DT_subset = WDT[WDT[self.excused_colnames[i]] == 0]
else:
DT_subset = WDT
if self.weight_lag_condition:
DT_subset = DT_subset[DT_subset[tx_bas] == i]

if not self.weight_preexpansion:
DT_subset = DT_subset[DT_subset['followup'] != 0]
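Both fitting loops now subset per treatment arm before fitting: rows flagged by that arm's excuse column are dropped when excuses are modeled, and `weight_lag_condition` further restricts the fit to rows whose lagged treatment equals the arm. A hedged sketch of that loop on synthetic data; the column names and the `tx ~ x` formula are illustrative, not the package's defaults:

```python
# Hedged sketch of the per-arm fitting loop (synthetic data; the excuse
# columns, lag column, and formula are illustrative).
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.RandomState(1)
n = 400
WDT = pd.DataFrame({
    "x":        rng.normal(size=n),
    "tx":       rng.randint(0, 2, size=n),
    "tx_lag":   rng.randint(0, 2, size=n),
    "excused0": rng.randint(0, 2, size=n),   # excuse indicator for arm 0
    "excused1": rng.randint(0, 2, size=n),   # excuse indicator for arm 1
})

treatment_level = [0, 1]
excused_colnames = ["excused0", "excused1"]
weight_lag_condition = True

fits = []
for i in treatment_level:
    # drop rows that are excused for this arm, as in the diff
    DT_subset = WDT[WDT[excused_colnames[i]] == 0]
    # optionally condition on lagged treatment equal to this arm
    if weight_lag_condition:
        DT_subset = DT_subset[DT_subset["tx_lag"] == i]
    fits.append(smf.mnlogit("tx ~ x", DT_subset).fit(disp=False))

print([f.params.shape for f in fits])
```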
63 changes: 36 additions & 27 deletions pySEQ/weighting/_weight_pred.py
@@ -6,11 +6,6 @@ def _weight_predict(self, WDT):
grouping = [self.id_col]
grouping += ["trial"] if not self.weight_preexpansion else []
time = self.time_col if self.weight_preexpansion else "followup"
classes = len(self.treatment_level)

if self.excused:
# TODO
pass

if self.method == "ITT":
WDT = WDT.with_columns([
@@ -25,22 +20,36 @@

for i, level in enumerate(self.treatment_level):
mask = pl.col("tx_lag") == level
tx_lag_mask = (WDT["tx_lag"] == level).to_numpy()

if self.denominator_model[i] is not None:
p = _predict_model(self, self.denominator_model[i], WDT) \
.reshape(WDT.height, classes)[:, i]
if not self.weight_preexpansion:
pred_mask = tx_lag_mask & (WDT["followup"] != 0).to_numpy()
else:
pred_mask = tx_lag_mask

switched_treatment = (WDT[self.treatment_col] != WDT["tx_lag"]).to_numpy()
pred_denom = np.where(switched_treatment, 1. - p, p)
pred_denom = np.ones(WDT.height)
if pred_mask.sum() > 0:
subset = WDT.filter(pl.Series(pred_mask))
p = _predict_model(self, self.denominator_model[i], subset)
if p.ndim == 1:
p = p.reshape(-1, 1)
p = p[:, i]
switched_treatment = (subset[self.treatment_col] != subset["tx_lag"]).to_numpy()
pred_denom[pred_mask] = np.where(switched_treatment, 1. - p, p)
else:
pred_denom = np.ones(WDT.height)

if self.numerator_model[i] is not None:
p = _predict_model(self, self.numerator_model[i], WDT) \
.reshape(WDT.height, classes)[:, i]

switched_treatment = (WDT[self.treatment_col] != WDT["tx_lag"]).to_numpy()
pred_num = np.where(switched_treatment, 1. - p, p)
if hasattr(self, "numerator_model") and self.numerator_model[i] is not None:
pred_num = np.ones(WDT.height)
if tx_lag_mask.sum() > 0:
subset = WDT.filter(pl.Series(tx_lag_mask))
p = _predict_model(self, self.numerator_model[i], subset)
if p.ndim == 1:
p = p.reshape(-1, 1)
p = p[:, i]
switched_treatment = (subset[self.treatment_col] != subset["tx_lag"]).to_numpy()
pred_num[tx_lag_mask] = np.where(switched_treatment, 1. - p, p)
else:
pred_num = np.ones(WDT.height)

@@ -54,18 +63,18 @@
.otherwise(pl.col("denominator"))
.alias("denominator")
])

if self.cense_colname is not None:
p_num = _predict_model(self, self.cense_numerator, WDT)
p_denom = _predict_model(self, self.cense_denominator, WDT)
WDT = WDT.with_columns([
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom)
]).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias("cense")
)
else:
WDT = WDT.with_columns(pl.lit(1.).alias("cense"))
if self.cense_colname is not None:
p_num = _predict_model(self, self.cense_numerator, WDT).flatten()
p_denom = _predict_model(self, self.cense_denominator, WDT).flatten()
WDT = WDT.with_columns([
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom)
]).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias("cense")
)
else:
WDT = WDT.with_columns(pl.lit(1.).alias("cense"))

kept = ["numerator", "denominator", "cense", self.id_col, "trial", time, "tx_lag"]
exists = [col for col in kept if col in WDT.columns]
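The prediction step now scores each arm's model only on rows whose lagged treatment matches that arm (and, post-expansion, excludes `followup == 0` for the denominator), writing results back through a boolean mask and leaving every other row at 1.0. A condensed sketch of that masked write-back; `fake_predict` stands in for `_predict_model` and the fitted models:

```python
# Condensed sketch of the masked predict-and-write-back pattern in the diff.
import numpy as np
import polars as pl

WDT = pl.DataFrame({
    "tx":       [1, 0, 1, 0, 1],
    "tx_lag":   [1, 1, 0, 0, 1],
    "followup": [0, 1, 1, 2, 2],
})

def fake_predict(subset: pl.DataFrame) -> np.ndarray:
    # pretend probability of staying on the same treatment
    return np.full(subset.height, 0.8)

level = 1
tx_lag_mask = (WDT["tx_lag"] == level).to_numpy()
pred_mask = tx_lag_mask & (WDT["followup"] != 0).to_numpy()

pred_denom = np.ones(WDT.height)
if pred_mask.sum() > 0:
    subset = WDT.filter(pl.Series(pred_mask))
    p = fake_predict(subset)
    switched = (subset["tx"] != subset["tx_lag"]).to_numpy()
    # probability of the treatment actually received: p if stayed, 1 - p if switched
    pred_denom[pred_mask] = np.where(switched, 1.0 - p, p)

print(pred_denom)   # rows outside the mask stay at 1.0
```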