Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions .github/workflows/pylint.yml

This file was deleted.

78 changes: 38 additions & 40 deletions pySEQ/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,50 @@

@dataclass
class SEQopts:
bootstrap_nboot: int = 0 #x
bootstrap_sample: float = 0.8 #x
bootstrap_CI: float = 0.95 #x
bootstrap_CI_method: Literal["se", "percentile"] = "se" #x
bootstrap_nboot: int = 0
bootstrap_sample: float = 0.8
bootstrap_CI: float = 0.95
bootstrap_CI_method: Literal["se", "percentile"] = "se"
cense_colname : Optional[str] = None #TODO
cense_denominator: Optional[str] = None #x
cense_numerator: Optional[str] = None #x
cense_denominator: Optional[str] = None
cense_numerator: Optional[str] = None
cense_eligible_colname: Optional[str] = None #TODO
compevent_colname: Optional[str] = None #TODO
covariates: Optional[List[str]] = None #x - need to test
denominator: Optional[List[str]] = None #x - need to test
excused: bool = False # x
excused_colnames: List[str] = field(default_factory=lambda: []) #x
followup_class: bool = False #TODO
followup_include: bool = True #x
followup_max: int = None #x
followup_min: int = 0 #x
covariates: Optional[List[str]] = None # need to test
denominator: Optional[List[str]] = None #need to test
excused: bool = False
excused_colnames: List[str] = field(default_factory=lambda: [])
followup_class: bool = False
followup_include: bool = True
followup_max: int = None
followup_min: int = 0
followup_spline: bool = False #TODO
hazard: bool = False
indicator_baseline: str = "_bas" #x
indicator_squared: str = "_sq" #x
km_curves: bool = False #x
hazard: bool = False # TODO
indicator_baseline: str = "_bas"
indicator_squared: str = "_sq"
km_curves: bool = False
multinomial: bool = False # - this can maybe be removed since statsmodels seems to be handling it?
ncores: int = multiprocessing.cpu_count() #x
numerator: Optional[List[str]] = None # x - need to test
parallel: bool = False # - maybe wrap this into ncores > 1
plot_colors: List[str] = field(default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]) #x
plot_labels: List[str] = field(default_factory=lambda: []) #x
plot_title: str = None #x
plot_type: Literal["risk", "survival", "inc"] = "risk" #x
seed: Optional[int] = None #x
selection_first_trial: bool = False #TODO
selection_probability: float = 0.8 #x
selection_random: bool = False #x
ncores: int = multiprocessing.cpu_count()
numerator: Optional[List[str]] = None # need to test
parallel: bool = False
plot_colors: List[str] = field(default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"])
plot_labels: List[str] = field(default_factory=lambda: [])
plot_title: str = None
plot_type: Literal["risk", "survival", "inc"] = "risk" # add inc (compevent)
seed: Optional[int] = None
selection_first_trial: bool = False
selection_probability: float = 0.8
selection_random: bool = False
subgroup_colname: str = None #TODO
survival_max: int = None #TODO
survival_min: int = 0#TODO
treatment_level: List[int] = field(default_factory=lambda: [0, 1]) #x
trial_include: bool = True #x
weight_eligible_colnames: List[str] = field(default_factory=lambda: []) #TODO
weight_min: float = 0.0 #x
weight_max: float = None #x
weight_lag_condition: bool = False #x
weight_p99: bool = False #x
weight_preexpansion: bool = False #x
weighted: bool = False #x
treatment_level: List[int] = field(default_factory=lambda: [0, 1])
trial_include: bool = True
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
weight_min: float = 0.0
weight_max: float = None
weight_lag_condition: bool = True
weight_p99: bool = False
weight_preexpansion: bool = False
weighted: bool = False

def __post_init__(self):
bools = [
Expand Down
76 changes: 41 additions & 35 deletions pySEQ/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .error import _param_checker
from .helpers import _col_string, bootstrap_loop, _format_time
from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator
from .expansion import _mapper, _binder, _dynamic, _randomSelection
from .expansion import _mapper, _binder, _dynamic, _random_selection, _first_trial
from .weighting import _weight_setup, _fit_LTFU, _fit_numerator, _fit_denominator, _weight_bind, _weight_predict, _weight_stats
from .analysis import _outcome_fit, _calculate_risk, _calculate_survival
from .plot import _survival_plot
Expand Down Expand Up @@ -72,44 +72,50 @@ def expand(self):
*self.weight_eligible_colnames,
*self.excused_colnames]

if not self.selection_first_trial:
self.data = self.data.with_columns([
pl.when(pl.col(self.treatment_col).is_in(self.treatment_level))
.then(self.eligible_col)
.otherwise(0)
.alias(self.eligible_col),
pl.col(self.treatment_col)
.shift(1)
.over([self.id_col])
.alias("tx_lag"),
pl.lit(False).alias("switch")
]).with_columns([
pl.when(pl.col(self.time_col) == 0)
.then(pl.lit(False))
.otherwise(
(pl.col("tx_lag").is_not_null()) &
(pl.col("tx_lag") != pl.col(self.treatment_col))
).cast(pl.Int8)
.alias("switch")
])

self.DT = _binder(_mapper(self.data, self.id_col, self.time_col), self.data,
self.id_col, self.time_col, self.eligible_col, self.outcome_col,
self.treatment_col,
_col_string([self.covariates,
self.numerator, self.denominator,
self.cense_numerator, self.cense_denominator]).union(kept),
self.indicator_baseline, self.indicator_squared) \
.with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))
self.data = self.data.with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))
else:
#only first trial selection here
pass
self.data = self.data.with_columns([
pl.when(pl.col(self.treatment_col).is_in(self.treatment_level))
.then(self.eligible_col)
.otherwise(0)
.alias(self.eligible_col),
pl.col(self.treatment_col)
.shift(1)
.over([self.id_col])
.alias("tx_lag"),
pl.lit(False).alias("switch")
]).with_columns([
pl.when(pl.col(self.time_col) == 0)
.then(pl.lit(False))
.otherwise(
(pl.col("tx_lag").is_not_null()) &
(pl.col("tx_lag") != pl.col(self.treatment_col))
).cast(pl.Int8)
.alias("switch")
])

self.DT = _binder(self,kept_cols= _col_string([self.covariates,
self.numerator,
self.denominator,
self.cense_numerator,
self.cense_denominator]).union(kept)) \
.with_columns(
pl.col(self.id_col)
.cast(pl.Utf8)
.alias(self.id_col)
)

self.data = self.data.with_columns(
pl.col(self.id_col)
.cast(pl.Utf8)
.alias(self.id_col)
)

if self.method != "ITT":
_dynamic(self)
if self.selection_random:
_randomSelection(self)
_random_selection(self)
if self.followup_class:
self.fixed_cols.append(["followup",
f"followup{self.indicator_squared}"])

end = time.perf_counter()
self.expansion_time = _format_time(start, end)
Expand Down
9 changes: 1 addition & 8 deletions pySEQ/expansion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
from ._binder import _binder
from ._dynamic import _dynamic
from ._mapper import _mapper
from ._selection import _randomSelection

__all__ = [
"_binder",
"_dynamic",
"_mapper",
"__randomSelection"
]
from ._selection import _random_selection
67 changes: 38 additions & 29 deletions pySEQ/expansion/_binder.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,71 @@
import polars as pl
from ._mapper import _mapper

def _binder(self, kept_cols):
    """
    Internal function to bind data to the map created by _mapper.

    Parameters
    ----------
    self
        Object carrying the run configuration. This function reads
        ``data``, ``id_col``, ``time_col``, ``eligible_col``,
        ``outcome_col``, ``treatment_col``, ``indicator_baseline``,
        ``indicator_squared`` and ``selection_first_trial`` from it.
        ``data`` is presumably a polars DataFrame (it is sorted, selected
        from and joined with polars calls) -- confirm against the caller.
    kept_cols : set
        Column names requested by the model specifications. ``None``
        entries are ignored. Names containing ``indicator_baseline`` or
        ``indicator_squared`` are treated as derived columns and are
        rebuilt here from their underlying column.

    Returns
    -------
    DT
        The bound frame, restricted to rows whose baseline eligibility
        equals 1, with both the eligibility column and its baseline copy
        dropped.
    """
    # Columns produced by the expansion itself; never pulled from the data.
    excluded = {"dose",
                f"dose{self.indicator_squared}",
                "followup",
                f"followup{self.indicator_squared}",
                "tx_lag",
                "trial",
                f"trial{self.indicator_squared}",
                self.time_col,
                f"{self.time_col}{self.indicator_squared}"}

    cols = kept_cols.union({self.eligible_col, self.outcome_col, self.treatment_col})
    cols = {col for col in cols if col is not None}

    # Plain columns: carry neither indicator and are not expansion-generated.
    regular = {col for col in cols
               if not (self.indicator_baseline in col or self.indicator_squared in col)
               and col not in excluded}

    # Derived columns are reduced to the underlying column they are built from.
    baseline = {col for col in cols if self.indicator_baseline in col and col not in excluded}
    bas_kept = {col.replace(self.indicator_baseline, "") for col in baseline}

    squared = {col for col in cols if self.indicator_squared in col and col not in excluded}
    sq_kept = {col.replace(self.indicator_squared, "") for col in squared}

    # Sorted so the resulting column order does not depend on set iteration
    # order (which varies between interpreter runs for strings).
    kept = sorted(regular.union(bas_kept).union(sq_kept))

    if self.selection_first_trial:
        # Single trial (trial 0): follow-up time is the observed time itself.
        DT = self.data.sort([self.id_col, self.time_col]) \
            .with_columns([
                pl.col(self.time_col).alias("period"),
                pl.col(self.time_col).alias("followup"),
                pl.lit(0).alias("trial")
            ]).drop(self.time_col)
    else:
        DT = _mapper(self.data, self.id_col, self.time_col)
        # NOTE(review): the join is assumed to apply to the mapped branch
        # only (the first-trial frame already carries the data columns) --
        # confirm against the upstream file, whose indentation was lost.
        DT = DT.join(
            self.data.select([self.id_col, self.time_col] + kept),
            left_on=[self.id_col, 'period'],
            right_on=[self.id_col, self.time_col],
            how='left'
        )
    DT = DT.sort([self.id_col, "trial", "followup"]) \
        .with_columns([
            (pl.col("trial") ** 2).alias(f"trial{self.indicator_squared}"),
            (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")
        ])

    if squared:
        squares = []
        for sq in sorted(squared):
            col = sq.replace(self.indicator_squared, '')
            squares.append((pl.col(col) ** 2).alias(f"{col}{self.indicator_squared}"))
        DT = DT.with_columns(squares)

    # Eligibility and treatment always get a baseline copy: the first is
    # needed for the filter below, the second is kept in the output.
    baseline_cols = {bas.replace(self.indicator_baseline, '') for bas in baseline}
    needed = {self.eligible_col, self.treatment_col}
    baseline_cols.update(needed)

    # Baseline value = first row of each (id, trial) group in the current sort order.
    bas = [
        pl.col(c).first().over([self.id_col, 'trial']).alias(f"{c}{self.indicator_baseline}")
        for c in sorted(baseline_cols)
    ]

    DT = DT.with_columns(bas).filter(pl.col(f"{self.eligible_col}{self.indicator_baseline}") == 1) \
        .drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col])

    return DT
4 changes: 1 addition & 3 deletions pySEQ/expansion/_selection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import polars as pl
import numpy as np

def _randomSelection(self):
def _random_selection(self):
"""
Handles the case where random selection is applied for data from
the __mapper -> __binder -> optionally __dynamic pipeline
Expand Down
13 changes: 7 additions & 6 deletions pySEQ/weighting/_weight_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _fit_numerator(self, WDT):
return
if self.method == "ITT":
return
predictor = "switch" if self.excused and not self.weight_preexpansion else self.treatment_col
predictor = "switch" if self.excused else self.treatment_col
formula = f"{predictor}~{self.numerator}"
tx_bas = f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag"
fits = []
Expand All @@ -33,9 +33,10 @@ def _fit_numerator(self, WDT):
DT_subset = WDT[WDT[self.excused_colnames[i]] == 0]
else:
DT_subset = WDT

if self.weight_lag_condition:
DT_subset = DT_subset[DT_subset[tx_bas] == i]
if self.weight_eligible_colnames[i] is not None:
DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1]

model = smf.mnlogit(
formula,
Expand All @@ -51,18 +52,18 @@ def _fit_denominator(self, WDT):
return
predictor = "switch" if self.excused and not self.weight_preexpansion else self.treatment_col
formula = f"{predictor}~{self.denominator}"
tx_bas = "tx_lag"
fits = []
for i in self.treatment_level:
if self.excused and self.excused_colnames[i] is not None:
DT_subset = WDT[WDT[self.excused_colnames[i]] == 0]
else:
DT_subset = WDT
if self.weight_lag_condition:
DT_subset = DT_subset[DT_subset[tx_bas] == i]

if not self.weight_preexpansion:
DT_subset = DT_subset[DT_subset["tx_lag"] == i]
if not self.weight_preexpansion and not self.excused:
DT_subset = DT_subset[DT_subset['followup'] != 0]
if self.weight_eligible_colnames[i] is not None:
DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1]

model = smf.mnlogit(
formula,
Expand Down
Loading