56 changes: 55 additions & 1 deletion README.md
@@ -1 +1,55 @@
# pySEQ
# pySEQ - Sequentially Nested Target Trial Emulation

Implementation of sequential trial emulation for the analysis of
observational databases. The ‘SEQTaRget’ software accommodates
time-varying treatments and confounders, as well as binary and
failure-time outcomes. ‘SEQTaRget’ allows comparison of both static
and dynamic strategies, can be used to estimate observational analogs
of intention-to-treat and per-protocol effects, and can adjust for
potential selection bias.

## Installation
You can install the development version of pySEQ from GitHub with:
```shell
pip install git+https://github.com/CausalInference/pySEQ
```
Or from PyPI with:
```shell
pip install pySEQ
```

## Setting up your Analysis
The primary API, `SEQuential`, uses a dataclass system to handle function input. You can then recover elements as they are built by interacting with the `SEQuential` object you create.

From the user side, this amounts to creating a dataclass, `SEQopts`, and then feeding it into `SEQuential`. If you forgot to add something at class instantiation, you can, in some cases, add it when you call the respective class method.

```python
from pySEQ import SEQuential, SEQopts
from pySEQ.data import load_data

data = load_data("SEQdata")          # bundled example dataset, returned as a polars DataFrame
options = SEQopts(km_curves = True)

# Initiate the class
model = SEQuential(data,
                   id_col = "ID",
                   time_col = "time",
                   eligible_col = "eligible",
                   time_varying_cols = ["N", "L", "P"],
                   fixed_cols = ["sex"],
                   method = "ITT",
                   parameters = options)
model.expand()                        # Construct the nested structure
model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples
model.fit()                           # Fit the model
model.survival()                      # Create survival curves
model.plot()                          # Create and show a plot of the survival curves
```

## Assumptions
There are several key assumptions in this package:
1. The user-provided `time_col` begins at 0 for each unique `id_col`. We also assume this column contains only integers and increases by 1 at every time step, e.g. (0, 1, 2, 3, 4, ...) is allowed, while (0, 1, 2, 2.5, ...) and (0, 1, 4, 5) are not. A minimal validation sketch follows this list.
2. `time_col` entries may be out of order at intake, as a sort is enforced at expansion.
3. `eligible_col`, `excused_column_names`, and [TODO] are flag variables that, once they become 1, remain 1 for all subsequent values of `time_col`.
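The snippet below is a minimal sketch of one way to check the first assumption before running an analysis. It assumes the `ID` and `time` column names used in the example above; adjust them to match your own data.

```python
import polars as pl
from pySEQ.data import load_data

data = load_data("SEQdata")

# For each ID: time should start at 0 and increase by exactly 1 at every step.
check = data.group_by("ID").agg([
    pl.col("time").min().alias("t_min"),
    (pl.col("time").sort().diff().drop_nulls() == 1).all().alias("unit_steps"),
])
bad = check.filter((pl.col("t_min") != 0) | (~pl.col("unit_steps")))
if bad.height > 0:
    raise ValueError(f"{bad.height} IDs violate the time_col assumptions")
```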

11 changes: 4 additions & 7 deletions pySEQ/SEQopts.py
@@ -1,14 +1,13 @@
import multiprocessing
from dataclasses import dataclass, field
from typing import List, Optional
import math
from typing import List, Optional, Literal

@dataclass
class SEQopts:
bootstrap_nboot: int = 0
bootstrap_sample: float = 0.8
bootstrap_CI: float = 0.95
bootstrap_CI_method: str = "se"
bootstrap_CI_method: Literal["se", "percentile"] = "se"
cense_colname : Optional[str] = None
cense_denominator: Optional[str] = None
cense_numerator: Optional[str] = None
@@ -36,7 +35,7 @@ class SEQopts:
plot_labels: Optional[List[str]] = None
plot_subtitle: str = None
plot_title: str = None
plot_type: str = "risk"
plot_type: Literal["risk", "survival", "inc"] = "risk"
seed: Optional[int] = None
selection_first_trial: bool = False
selection_probability: float = 0.8
@@ -88,9 +87,7 @@ def __post_init__(self):
lists = [

]
# veryify some lists here
# merge param checker here into this class ? might be better


for i in ('covariates', 'numerator', 'denominator',
'cense_numerator', 'cense_denominator'):
attr = getattr(self, i)
28 changes: 18 additions & 10 deletions pySEQ/SEQuential.py
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Optional, List, Literal
import sys
import time
from dataclasses import asdict
@@ -7,7 +7,8 @@
import numpy as np

from .SEQopts import SEQopts
from .helpers import _col_string, bootstrap_loop, _format_time, _prepare_data
from .error import _param_checker
from .helpers import _col_string, bootstrap_loop, _format_time
from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator
from .expansion import _mapper, _binder, _dynamic, _randomSelection
from .weighting import _weight_setup, _fit_LTFU, _fit_numerator, _fit_denominator, _weight_bind, _weight_predict, _weight_stats
@@ -26,9 +27,9 @@ def __init__(
outcome_col: str,
time_varying_cols: Optional[List[str]] = None,
fixed_cols: Optional[List[str]] = None,
method: str = "ITT",
method: Literal["ITT", "dose-response", "censoring"] = "ITT",
parameters: Optional[SEQopts] = None
):
) -> None:
self.data = data
self.id_col = id_col
self.time_col = time_col
@@ -61,6 +62,8 @@ def __init__(

if self.cense_denominator is None:
self.cense_denominator = _cense_denominator()

_param_checker(self)

def expand(self):
start = time.perf_counter()
@@ -91,16 +94,14 @@ def expand(self):

self.DT = _binder(_mapper(self.data, self.id_col, self.time_col), self.data,
self.id_col, self.time_col, self.eligible_col, self.outcome_col,
self.treatment_col,
_col_string([self.covariates,
self.numerator, self.denominator,
self.cense_numerator, self.cense_denominator]).union(kept),
self.indicator_baseline, self.indicator_squared) \
.with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))
self.data = self.data.with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))

#self.data = _prepare_data(self, self.data)
#self.DT = _prepare_data(self, self.DT)

if self.method != "ITT":
_dynamic(self)
if self.selection_random:
@@ -130,6 +131,9 @@ def bootstrap(self, **kwargs):

@bootstrap_loop
def fit(self):
if self.bootstrap_nboot > 0 and not hasattr(self, "_boot_samples"):
raise ValueError("Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping.")

start = time.perf_counter()
if self.weighted:
WDT = _weight_setup(self)
@@ -148,8 +152,8 @@ def fit(self):
WDT = pl.from_pandas(WDT)
WDT = _weight_predict(self, WDT)
_weight_bind(self, WDT)

self.weight_stats = _weight_stats(self)
self.DT.write_csv("test.csv")

end = time.perf_counter()
self.model_time = _format_time(start, end)
@@ -161,6 +165,9 @@ def fit(self):
"weight")

def survival(self):
if not hasattr(self, "outcome_model") or not self.outcome_model:
raise ValueError("Outcome model not found. Please run the 'fit' method before calculating survival.")

start = time.perf_counter()

risk_data = _calculate_risk(self)
@@ -171,8 +178,9 @@ def survival(self):
self.survival_time = _format_time(start, end)

def hazard(self):
pass

if not hasattr(self, "outcome_model") or not self.outcome_model:
raise ValueError("Outcome model not found. Please run the 'fit' method before calculating hazard ratio.")

def plot(self):
self.km_graph = _survival_plot(self)
print(self.km_graph)
4 changes: 2 additions & 2 deletions pySEQ/analysis/_outcome_fit.py
@@ -10,7 +10,7 @@ def _outcome_fit(self,
weight_col: str = "weight",
):
if weighted:
df = self.DT.with_columns(
df = df.with_columns(
pl.col(weight_col).clip(
lower_bound=self.weight_min,
upper_bound=self.weight_max
@@ -27,7 +27,7 @@
formula=formula,
data=df_pd,
family=sm.families.Binomial(),
freq_weights=df_pd[weight_col])
var_weights=df_pd[weight_col])
else:
model = smf.glm(
formula=formula,
File renamed without changes.
11 changes: 11 additions & 0 deletions pySEQ/data/__init__.py
@@ -0,0 +1,11 @@
from importlib.resources import files
import polars as pl

def load_data(name: str = "SEQdata") -> pl.DataFrame:
"""Load a dataset bundled with pySEQ as a polars DataFrame."""
if name == "SEQdata":
data_path = files("pySEQ.data").joinpath("SEQdata.csv")
return pl.read_csv(data_path)

else:
raise ValueError(f"Dataset '{name}' not available. Options: ['SEQdata']")
1 change: 1 addition & 0 deletions pySEQ/error/__init__.py
@@ -0,0 +1 @@
from ._param_checker import _param_checker
28 changes: 28 additions & 0 deletions pySEQ/error/_param_checker.py
@@ -0,0 +1,28 @@
import warnings

def _param_checker(self):
if self.subgroup_colname is not None and self.subgroup_colname not in self.fixed_cols:
raise ValueError("subgroup_colname must be included in fixed_cols.")

if self.survival_max is None:
self.survival_max = self.data.select(self.time_col).to_series().max()
if self.followup_max is None:
self.followup_max = self.data.select(self.time_col).to_series().max()

if self.excused_colnames is None and self.excused:
self.excused = False
warnings.warn("Excused column names not provided but excused is set to True; automatically setting excused to False.")

if self.excused_colnames is not None and not self.excused:
self.excused = True
warnings.warn("Excused column names provided but excused is set to False; automatically setting excused to True.")

if self.km_curves and self.hazard:
raise ValueError("km_curves and hazard cannot both be set to True.")

if sum([self.followup_class, self.followup_include, self.followup_spline]) > 1:
raise ValueError("Only one of followup_class or followup_include can be set to True.")

if self.weighted and self.method == "ITT" and self.cense_colname is None:
raise ValueError("For weighted ITT analyses, cense_colname must be provided.")

return

44 changes: 23 additions & 21 deletions pySEQ/expansion/_binder.py
@@ -1,6 +1,6 @@
import polars as pl

def _binder(DT, data, id_col, time_col, eligible_col, outcome_col, kept_cols,
def _binder(DT, data, id_col, time_col, eligible_col, outcome_col, treatment_col, kept_cols,
baseline_indicator, squared_indicator):
"""
Internal function to bind data to the map created by __mapper
@@ -15,7 +15,7 @@ def _binder(DT, data, id_col, time_col, eligible_col, outcome_col, kept_cols,
time_col,
f"{time_col}{squared_indicator}"}

cols = kept_cols.union({eligible_col, outcome_col})
cols = kept_cols.union({eligible_col, outcome_col, treatment_col})
cols = {col for col in cols if col is not None}

regular = {col for col in cols if not (baseline_indicator in col or squared_indicator in col) and col not in excluded}
@@ -33,28 +33,30 @@ def _binder(DT, data, id_col, time_col, eligible_col, outcome_col, kept_cols,
left_on=[id_col, 'period'],
right_on=[id_col, time_col],
how='left'
)
DT.sort([id_col, "trial", "followup"])

for i in ["trial", "followup"]:
colname = f"{i}{squared_indicator}"
DT = DT.with_columns((pl.col(i) ** 2).alias(colname))
)
DT = DT.sort([id_col, "trial", "followup"]) \
.with_columns([
(pl.col("trial") ** 2).alias(f"trial{squared_indicator}"),
(pl.col("followup") ** 2).alias(f"followup{squared_indicator}")
])

if squared:
squares = []
for sq in squared:
col = sq.replace(squared_indicator, '')
DT = DT.with_columns(
(pl.col(col) ** 2).alias(f"{col}{squared_indicator}")
)

if baseline:
base = [bas.replace(baseline_indicator, '') for bas in baseline] + [eligible_col]
for col in base:
DT = DT.with_columns(
pl.col(col).first().over([id_col, 'trial']).alias(f"{col}{baseline_indicator}")
)

DT = DT.filter(pl.col(f"{eligible_col}{baseline_indicator}") == 1) \
.drop([f"{eligible_col}{baseline_indicator}", eligible_col])
squares.append((pl.col(col) ** 2).alias(f"{col}{squared_indicator}"))
DT = DT.with_columns(squares)

baseline_cols = {bas.replace(baseline_indicator, '') for bas in baseline}
needed = {eligible_col, treatment_col}
baseline_cols.update({c for c in needed})

bas = [
pl.col(c).first().over([id_col, 'trial']).alias(f"{c}{baseline_indicator}")
for c in baseline_cols
]

DT = DT.with_columns(bas).filter(pl.col(f"{eligible_col}{baseline_indicator}") == 1) \
.drop([f"{eligible_col}{baseline_indicator}", eligible_col])

return DT
4 changes: 2 additions & 2 deletions pySEQ/expansion/_dynamic.py
@@ -6,12 +6,12 @@ def _dynamic(self):
if self.method == "dose-response":
DT = self.DT.with_columns(
pl.col(self.treatment_col)
.cum_count()
.cum_sum()
.over([self.id_col, "trial"])
.alias("dose")
).with_columns([
(pl.col("dose") ** 2)
.alias(f"dose{self.squared_indicator}")
.alias(f"dose{self.indicator_squared}")
])
self.DT = DT

2 changes: 1 addition & 1 deletion pySEQ/initialization/_outcome.py
@@ -5,7 +5,7 @@ def _outcome(self) -> str:
interaction = f"{tx_bas}*followup"
interaction_dose = "+".join(["followup*dose", f"followup*dose{self.indicator_squared}"])

if self.hazard or self.km_curves:
if self.hazard or not self.km_curves:
interaction = interaction_dose = None

tv_bas = (
10 changes: 5 additions & 5 deletions pySEQ/weighting/_weight_bind.py
@@ -9,15 +9,15 @@ def _weight_bind(self, WDT):
join = "left"
on = [self.id_col, "trial", "followup"]

DT = self.DT.join(WDT,
on=on,
how=join)
WDT = self.DT.join(WDT,
on=on,
how=join)

if self.weight_preexpansion and self.excused:
trial = (pl.col("trial") == 0) & (pl.col("period") == 0)
excused = pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) > 0
else:
trial = (pl.col("trial") == pl.col("trial").min().over(self.id_col))
trial = (pl.col("trial") == pl.col("trial").min().over(self.id_col)) & (pl.col("followup") == 0)
excused = pl.lit(False)

override = (
Expand All @@ -28,7 +28,7 @@ def _weight_bind(self, WDT):
pl.col("numerator").is_null()
)

self.DT = DT.with_columns(
self.DT = WDT.with_columns(
pl.when(override)
.then(pl.lit(1.0))
.otherwise(pl.col("numerator") / pl.col("denominator"))