Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples
model.fit() # Fit the model
model.survival() # Create survival curves
model.plot() # Create and show a plot of the survival curves
model.collect() # Collection of important information

```

Expand Down
12 changes: 6 additions & 6 deletions pySEQ/SEQoutput.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from dataclasses import dataclass
from typing import List, Optional, Literal
from .SEQopts import SEQopts
import statsmodels.formula.api as smf
from statsmodels.base.wrapper import ResultsWrapper
import polars as pl
import matplotlib.figure

@dataclass
class SEQoutput:
options: SEQopts = None
method: str = None
numerator_models: List[smf.MNLogit] = None
denominator_models: List[smf.MNLogit] = None
outcome_models: List[List[smf.glm]] = None
compevent_models: List[List[smf.glm]] = None
numerator_models: List[ResultsWrapper] = None
denominator_models: List[ResultsWrapper] = None
outcome_models: List[List[ResultsWrapper]] = None
compevent_models: List[List[ResultsWrapper]] = None
weight_statistics: dict = None
hazard: pl.DataFrame = None
km_data: pl.DataFrame = None
Expand Down Expand Up @@ -78,7 +78,7 @@ def retrieve_data(self,
case _:
data = self.km_data
if data is None:
ValueError("Data {type} was not created in the SEQuential process")
raise ValueError("Data {type} was not created in the SEQuential process")
return data


5 changes: 3 additions & 2 deletions pySEQ/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .SEQopts import SEQopts
from .SEQoutput import SEQoutput
from .error import _param_checker
from .error import _param_checker, _datachecker
from .helpers import _col_string, bootstrap_loop, _format_time
from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator
from .expansion import _binder, _dynamic, _random_selection, _diagnostics
Expand Down Expand Up @@ -69,6 +69,7 @@ def __init__(
self.cense_denominator = _cense_denominator(self)

_param_checker(self)
_datachecker(self)

def expand(self):
start = time.perf_counter()
Expand Down Expand Up @@ -216,7 +217,7 @@ def collect(self):
"numerator_model", "denominator_model",
"outcome_model",
"hazard_ratio", "risk_estimates",
"km_data", "diagnostics",
"km_data", "km_graph", "diagnostics",
"_survival_time", "_hazard_time",
"_model_time", "_expansion_time",
"weight_stats"
Expand Down
3 changes: 2 additions & 1 deletion pySEQ/error/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from ._param_checker import _param_checker
from ._param_checker import _param_checker
from ._datachecker import _datachecker
29 changes: 29 additions & 0 deletions pySEQ/error/_datachecker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import polars as pl

def _datachecker(self):
check = self.data.group_by(self.id_col).agg([
pl.len().alias("row_count"),
pl.col(self.time_col).max().alias("max_time")
])

invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1)
if len(invalid) > 0:
raise ValueError(
f"Data validation failed: {len(invalid)} ID(s) have mismatched "
f"This suggests invalid times"
f"Invalid IDs:\n{invalid}"
)

for col in self.excused_colnames:
violations = self.data.sort([self.id_col, self.time_col]).group_by(self.id_col).agg([
((pl.col(col).cum_sum().shift(1, fill_value=0) > 0) & (pl.col(col) == 0))
.any()
.alias("has_violation")
]).filter(pl.col("has_violation"))

if len(violations) > 0:
raise ValueError(
f"Column '{col}' violates 'once one, always one' rule for excusing treatment "
f"{len(violations)} ID(s) have zeros after ones."
)

31 changes: 25 additions & 6 deletions pySEQ/helpers/_bootstrap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from concurrent.futures import ProcessPoolExecutor, as_completed
import polars as pl
import numpy as np
from tqdm import tqdm
import copy
import time
Expand All @@ -23,6 +24,19 @@ def _prepare_boot_data(self, data, boot_id):

return bootstrapped

def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs):
obj = copy.deepcopy(obj)
obj._rng = np.random.RandomState(seed + i) if seed is not None else np.random.RandomState()
obj.DT = _prepare_boot_data(obj, original_DT, i)

# Disable bootstrapping to prevent recursion
obj.bootstrap_nboot = 0

method = getattr(obj, method_name)
result = method(*args, **kwargs)
obj._rng = None
return result

def bootstrap_loop(method):
@wraps(method)
def wrapper(self, *args, **kwargs):
Expand All @@ -38,17 +52,22 @@ def wrapper(self, *args, **kwargs):
original_DT = self.DT
nboot = self.bootstrap_nboot
ncores = self.ncores
seed = getattr(self, "seed", None)
method_name = method.__name__

def _worker(i):
obj = copy.deepcopy(self)
obj.DT = _prepare_boot_data(obj, original_DT, i)
return method(obj, *args, **kwargs)

if getattr(self, "parallel", False):
original_rng = getattr(self, "_rng", None)
self._rng = None

with ProcessPoolExecutor(max_workers=ncores) as executor:
futures = [executor.submit(_worker, i) for i in range(nboot)]
futures = [
executor.submit(_bootstrap_worker, self, method_name, original_DT, i, seed, args, kwargs)
for i in range(nboot)
]
for j in tqdm(as_completed(futures), total=nboot, desc="Bootstrapping..."):
results.append(j.result())

self._rng = original_rng
else:
for i in tqdm(range(nboot), desc="Bootstrapping..."):
self.DT = _prepare_boot_data(self, original_DT, i)
Expand Down
40 changes: 37 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,30 @@ build-backend = "setuptools.build_meta"
name = "pySEQ"
version = "0.9.0"
description = "Sequentially Nested Target Trial Emulation"
authors = [{name = "Ryan ODea", email = "ryan.odea@psi.ch"}]
readme = "README.md"
license = {text = "MIT"}
keywords = ["causal inference", "sequential trial emulation", "target trial", "observational studies"]
requires-python = ">=3.10"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
]

authors = [
{name = "Ryan O'Dea", email = "ryan.odea@psi.ch"},
{name = "Alejandro Szmulewicz", email = "aszmulewicz@hsph.harvard.edu"},
{name = "Tom Palmer", email = "tom.palmer@bristol.ac.uk"},
{name = "Miguel Hernan", email = "mhernan@hsph.harvard.edu"},
]

maintainers = [
{name = "Ryan O'Dea", email = "ryan.odea@psi.ch"},
]

dependencies = [
"numpy",
"polars",
Expand All @@ -18,10 +40,22 @@ dependencies = [
"lifelines"
]

[tools.setuptools]
[project.urls]
Homepage = "https://github.com/CausalInference/pySEQ"
Repository = "https://github.com/CausalInference/pySEQ"
"Bug Tracker" = "https://github.com/CausalInference/pySEQ/issues"

"Ryan O'Dea (ORCID)" = "https://orcid.org/0009-0000-0103-9546"
"Alejandro Szmulewicz (ORCID)" = "https://orcid.org/0000-0002-2664-802X"
"Tom Palmer (ORCID)" = "https://orcid.org/0000-0003-4655-4511"
"Miguel Hernan (ORCID)" = "https://orcid.org/0000-0003-1619-8456"
"University of Bristol (ROR)" = "https://ror.org/0524sp257"
"Harvard University (ROR)" = "https://ror.org/03vek6s52"

[tool.setuptools]
packages = ["pySEQ", "pySEQ.data"]

[tools.setuptools.package-data]
[tool.setuptools.package-data]
SEQdata = ["data/*.csv"]

[tool.pytest.ini_options]
Expand Down
25 changes: 25 additions & 0 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pySEQ import SEQuential, SEQopts
from pySEQ.data import load_data
import pytest

def test_ITT_collector():
data = load_data("SEQdata")

s = SEQuential(
data,
id_col="ID",
time_col="time",
eligible_col="eligible",
treatment_col="tx_init",
outcome_col="outcome",
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method = "ITT",
parameters=SEQopts()
)
s.expand()
s.fit()
collector = s.collect()
outcomes = collector.retrieve_data("unique_outcomes")
with pytest.raises(ValueError):
collector.retrieve_data("km_data")
Empty file removed tests/test_bootstrap.py
Empty file.
4 changes: 2 additions & 2 deletions tests/test_covariates.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_PreE_censoring_excused_covariates():
parameters=SEQopts(weighted=True,
weight_preexpansion=True,
excused=True,
excused_colnames=["ExcusedZero", "ExcusedOne"])
excused_colnames=["excusedZero", "excusedOne"])
)
assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq"
assert s.numerator is None
Expand All @@ -144,7 +144,7 @@ def test_PostE_censoring_excused_covariates():
method = "censoring",
parameters=SEQopts(weighted=True,
excused=True,
excused_colnames=["ExcusedZero", "ExcusedOne"])
excused_colnames=["excusedZero", "excusedOne"])
)
assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas"
assert s.numerator == "sex+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq"
Expand Down
27 changes: 27 additions & 0 deletions tests/test_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pySEQ import SEQuential, SEQopts
from pySEQ.data import load_data

def test_parallel_ITT():
data = load_data("SEQdata")

s = SEQuential(
data,
id_col="ID",
time_col="time",
eligible_col="eligible",
treatment_col="tx_init",
outcome_col="outcome",
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method = "ITT",
parameters=SEQopts(parallel=True,
bootstrap_nboot=2)
)
s.expand()
s.bootstrap()
s.fit()
matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list()
assert matrix == [-6.828506035553407, 0.18935003090041902, 0.12717241010542563,
0.033715156987629266, -0.00014691202235029346, 0.044566165558944326,
0.0005787770439053261, 0.0032906669395295026, -0.01339242049205771,
0.20072409918428052]
Loading