Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sequential/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def SEQopts(parallel: bool = False,
covariates: str = None,
numerator: str = None,
denominator: str = None,
cense_numerator: str = None,
cense_denominator: str = None,
censor_numerator: str = None,
censor_denominator: str = None,
weighted: bool = False,
weight_lower: float = -math.inf,
weight_upper: float = math.inf,
Expand Down
19 changes: 14 additions & 5 deletions sequential/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import polars as pl
from .SEQopts import SEQopts
from .helpers import __colString
from .initialization import __outcome, __numerator, __denominator, __censor_numerator, __censor_denominator
from .expansion import __mapper, __binder, __dynamic, __randomSelection
from .weighting import __weight_prepare_data, __weight_model, __weight_predict, __weight_bind, __weight_cumprod
Expand Down Expand Up @@ -35,7 +36,10 @@ def __init__(
self.weighted = parameters['weighted']
self.censor = parameters['censor']
self.random_selection = parameters['random_selection']
self.parameters = parameters
self.baseline_indicator = parameters['indicator_baseline']
self.squared_indicator = parameters['indicator_squared']
self.excused_col0 = parameters['excused_col0']
self.excused_col1 = parameters['excused_col1']

if parameters['covariates'] is None:
self.covariates = __outcome()
Expand All @@ -52,15 +56,19 @@ def __init__(

if self.censor is not None:
if self.parameters['censor_numerator'] is None:
self.cense_numerator = __censor_numerator()
else: self.cense_numerator = self.parameters['censor_numerator']
self.censor_numerator = __censor_numerator()
else: self.censor_numerator = self.parameters['censor_numerator']

if parameters['censor_denominator'] is None:
self.cense_denominator = __censor_denominator()
self.cenor_denominator = __censor_denominator()
else: self.censor_denominator = parameters['censor_denominator']

def expand(self):
self.DT = __binder(__mapper(self.data))
self.DT = __binder(__mapper(self.data), self.data, __colString([
self.covariates, self.numerator, self.denominator, self.censor_numerator, self.censor_denominator
]), self.eligible_col, self.excused_col0, self.excused_col1,
self.baseline_indicator, self.squared_indicator)

if self.method != "ITT":
self.DT = __dynamic(self.DT)
if self.random_selection:
Expand All @@ -79,6 +87,7 @@ def weight(self):


def outcome():
    # Placeholder: the body is a bare `pass`, so nothing is computed yet.
    # NOTE(review): declared with no parameters; if this is meant to be an
    # instance method alongside `weight(self)` it needs `self` — confirm.
    pass

def survival():
    # Placeholder: the body is a bare `pass`, so nothing is computed yet.
    # NOTE(review): declared with no parameters; if this is meant to be an
    # instance method alongside `weight(self)` it needs `self` — confirm.
    pass
Expand Down
7 changes: 7 additions & 0 deletions sequential/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .SEQuential import SEQuential
from .SEQopts import SEQopts

__all__ = [
"SEQuential",
"SEQopts"
]
46 changes: 44 additions & 2 deletions sequential/expansion/__binder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,46 @@
def __binder(DT, time_col, tv_cols, fixed_cols, parameters):
import polars as pl

def __binder(DT, data, id_col, time_col,
             eligible_col, excused_col0, excused_col1, cols,
             baseline_indicator, squared_indicator):
    """
    Internal function to bind data to the map created by __mapper.

    Left-joins the requested columns of ``data`` onto the expanded map
    (matching ``period`` in the map to ``time_col`` in ``data``), derives the
    squared and baseline variants that were requested, and keeps only the
    trials whose first row is eligible.

    Parameters
    ----------
    DT : polars.DataFrame
        Expanded map holding at least ``id_col``, ``'trial'`` and ``'period'``.
    data : polars.DataFrame
        Source data holding ``id_col``, ``time_col`` and the covariates.
    id_col, time_col : str
        Names of the subject identifier and time columns in ``data``.
    eligible_col : str or None
        Eligibility flag column; when ``None`` no eligibility filter is applied.
    excused_col0, excused_col1 : str or None
        Optional columns that are joined alongside the covariates.
    cols : iterable of str
        Column names referenced by the model formulas.
    baseline_indicator, squared_indicator : str
        Markers identifying baseline / squared variants of a column name.

    Returns
    -------
    polars.DataFrame
        The map with covariates bound, restricted to trials eligible at
        their first row, with the eligibility columns dropped.
    """
    # Names listed here are never requested from ``data`` by this function.
    excluded = {
        'dose',
        f'dose{squared_indicator}',
        time_col,
        f'{time_col}{squared_indicator}',
        'tx_lag',
    }
    requested = {
        col
        for col in set(cols).union({eligible_col, excused_col0, excused_col1})
        if col is not None and col not in excluded
    }

    regular = {
        col for col in requested
        if baseline_indicator not in col and squared_indicator not in col
    }
    # Sorted so the resulting column order is deterministic.
    squared = sorted(col for col in requested if squared_indicator in col)
    baseline = sorted(col for col in requested if baseline_indicator in col)

    squared_sources = list(dict.fromkeys(
        col.replace(squared_indicator, '') for col in squared
    ))
    baseline_sources = list(dict.fromkeys(
        col.replace(baseline_indicator, '') for col in baseline
    ))

    # A derived term such as ``x<squared_indicator>`` needs its source ``x``
    # even when ``x`` itself is not part of any formula; pull it from ``data``
    # when it lives there.  Columns the map already carries are not re-joined.
    already_mapped = set(DT.columns)
    in_data = set(data.columns)
    derived_inputs = {
        col for col in squared_sources + baseline_sources if col in in_data
    }
    to_join = ((regular | derived_inputs) - already_mapped) - {time_col}

    DT = DT.join(
        data.select([id_col, time_col] + sorted(to_join)),
        left_on=[id_col, 'period'],
        right_on=[id_col, time_col],
        how='left',
    )

    if squared_sources:
        DT = DT.with_columns([
            (pl.col(src) ** 2).alias(f"{src}{squared_indicator}")
            for src in squared_sources
        ])

    window = [id_col, 'trial']

    if baseline_sources:
        # ``first()`` must come before ``over()``: the reverse order takes the
        # first value of the whole column instead of the first per trial.
        DT = DT.with_columns([
            pl.col(src).first().over(window).alias(f"{src}{baseline_indicator}")
            for src in baseline_sources
        ])

    if eligible_col is not None:
        # Eligibility is judged on the first row of each (id, trial) group and
        # applied whether or not any baseline covariates were requested.
        eligible_at_baseline = f"{eligible_col}{baseline_indicator}"
        DT = (
            DT.with_columns(
                pl.col(eligible_col)
                .first()
                .over(window)
                .alias(eligible_at_baseline)
            )
            .filter(pl.col(eligible_at_baseline) == 1)
            .drop([eligible_at_baseline, eligible_col])
        )

    return DT
32 changes: 30 additions & 2 deletions sequential/expansion/__dynamic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,32 @@
def __dynamic(DT, parameters):
import polars as pl
def __dynamic(DT, id_col, time_col, treatment_col, method,
              excused_col0, excused_col1,
              squared_indicator, baseline_indicator):
    """
    Handles special cases for the data from the __mapper -> __binder pipeline.

    Parameters
    ----------
    DT : polars.DataFrame
        Expanded data holding ``id_col``, ``'trial'`` and ``treatment_col``.
    id_col : str
        Subject identifier column.
    time_col : str
        Time column name (not used by the current branches).
    treatment_col : str
        Treatment column.
    method : str
        ``"dose-response"`` adds ``dose`` and ``dose<squared_indicator>``;
        ``"censoring"`` adds ``temp``, ``tx_lag`` and ``switch``; any other
        value leaves ``DT`` untouched.
    excused_col0, excused_col1 : str or None
        Accepted for interface compatibility; not used by the current branches.
    squared_indicator : str
        Suffix appended to ``dose`` for its squared term.
    baseline_indicator : str
        Accepted for interface compatibility; not used by the current branches.

    Returns
    -------
    polars.DataFrame
        ``DT`` with the method-specific columns added.
    """
    trial_window = [id_col, "trial"]

    if method == "dose-response":
        # NOTE(review): ``cum_count`` is a running count within the trial, not
        # a running sum of the treatment values — confirm this is the intended
        # definition of dose.
        DT = DT.with_columns(
            pl.col(treatment_col).cum_count().over(trial_window).alias("dose")
        ).with_columns([
            (pl.col("dose") ** 2).alias(f"dose{squared_indicator}")
        ])
    elif method == "censoring":
        DT = DT.with_columns([
            # First treatment value of the trial, used to seed the lag.
            pl.col(treatment_col).first().over(trial_window).alias("temp"),
            pl.col(treatment_col).shift(1).over(trial_window).alias("tx_lag"),
        ]).with_columns(
            # The shift leaves the first row of every trial without a lag;
            # fill that gap with the trial's first treatment value.
            pl.when(pl.col("tx_lag").is_null())
            .then(pl.col("temp"))
            .otherwise(pl.col("tx_lag"))
            .alias("tx_lag")
        ).with_columns(
            (pl.col(treatment_col) != pl.col("tx_lag")).alias("switch")
        )

    # Always hand the frame back: callers assign the result to their DT.
    return DT
7 changes: 7 additions & 0 deletions sequential/expansion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,10 @@
from .__dynamic import __dynamic
from .__mapper import __mapper
from .__selection import __randomSelection

__all__ = [
"__binder",
"__dynamic",
"__mapper",
"__randomSelection"
]
28 changes: 16 additions & 12 deletions sequential/expansion/__mapper.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
import polars as pl
import math

def __mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.inf):
    """
    Internal function to create the expanded map to bind data to.

    Every row of ``data`` starts one trial for its subject.  Each trial is
    expanded into one row per period, running from the row's own time value
    up to the subject's largest time value, and the rows are numbered with a
    zero-based ``followup`` counter inside the trial.

    Parameters
    ----------
    data : polars.DataFrame
        Source data holding ``id_col`` and an integer ``time_col``.
    id_col : str
        Subject identifier column.
    time_col : str
        Time column.
    min_followup, max_followup : numeric
        Inclusive bounds on ``followup``; rows outside them are dropped.

    Returns
    -------
    polars.DataFrame
        Columns ``id_col``, ``trial``, ``period`` and ``followup``.
    """
    last_time = pl.col(time_col).max().over(id_col)

    DT = (
        data.select([pl.col(id_col), pl.col(time_col)])
        .with_columns([
            # NOTE(review): the ``sub(1)`` assumes ``cum_count`` starts at 1
            # in the installed polars version — confirm.
            pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")
        ])
        .with_columns([
            # Vectorised equivalent of range(time, max_time + 1) per row.
            pl.int_ranges(pl.col(time_col), last_time + 1).alias("period")
        ])
        .explode("period")
        .drop(time_col)
        .with_columns([
            pl.col(id_col).cum_count().over([id_col, "trial"]).sub(1).alias("followup")
        ])
        .filter(
            (pl.col("followup") >= min_followup) &
            (pl.col("followup") <= max_followup)
        )
    )
    return DT
6 changes: 6 additions & 0 deletions sequential/helpers/__colString.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def __colString(expressions):
    """
    Collect the distinct column names referenced by a list of formula strings.

    Terms inside an expression are separated by ``+``, ``*`` and/or
    whitespace (for example ``"sex + age*dose"``).

    Parameters
    ----------
    expressions : iterable of (str or None)
        Formula strings to scan.  ``None`` entries (a formula that was never
        set) are skipped instead of raising ``AttributeError``.

    Returns
    -------
    set of str
        Every distinct token found across all expressions.
    """
    cols = set()
    for expression in expressions:
        if expression is None:
            continue
        cols.update(expression.replace("+", " ").replace("*", " ").split())

    return cols
1 change: 1 addition & 0 deletions sequential/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Package-relative import of the helper under its real name, `__colString`.
from .__colString import __colString

__all__ = [
    "__colString"
]