Skip to content

Commit

Permalink
includes RMST, difference in RMST and confidence intervals
Browse files Browse the repository at this point in the history
Implements RMST, difference in RMST from the R package:
https://cran.r-project.org/package=survRM2

Includes RMST of one population (point estimate and confidence interval), difference in RMST of 2 populations (point estimate, p-value, and confidence intervals). Addresses CamDavidsonPilon#821
  • Loading branch information
bayesfactor committed May 22, 2023
1 parent 258fa4b commit e291356
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 7 deletions.
Binary file added lifelines/filename.joblib
Binary file not shown.
137 changes: 130 additions & 7 deletions lifelines/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,129 @@ def z(p):
)


def restricted_mean_survival_time(point_in_time, fitterA) -> pd.Series:
"""
Implements Restricted Mean Survival Time analysis on the population described in fitterA
Returns the RMST value for the population described by fitterA
https://cran.r-project.org/package=survRM2
Parameters
----------
point_in_time: float,
the point in time to analyze the survival curves at.
fitterA:
A lifelines univariate model fitted to the data. This can be a ``KaplanMeierFitter``, ``WeibullFitter``, etc.
Returns
-------
pd.Series
a pandas Series with the properties 'RMST', 'RMST_SE', 'RMST_VAR',
'RMST_LCI', 'RMST_UCI'
Examples
--------
.. code:: python
T1 = [4, 5, 7, 11, 14, 20, 8, 8]
E1 = [1, 1, 1, 1, 1, 1, 1, 1]
kmf1 = KaplanMeierFitter().fit(T1, E1)
from lifelines.statistics import restricted_mean_survival_time
results = restricted_mean_survival_time(12.0, kmf1)
results
"""
ft = pd.DataFrame({"time": fitterA.timeline})
data = pd.DataFrame({"time": np.array(fitterA.durations), "event": np.array(fitterA.event_observed)})
# print(data)
ft.index = ft.time
ft["n_risk"] = [sum(data.time >= i) for i in ft.index]
ft["surv"] = fitterA.survival_function_

n_event = data[data.event == 1].groupby("time").count()
n_event = n_event.join(ft["time"], how="right").drop("time", axis=1)
n_event.event.fillna(0, inplace=True)
# print(n_event)
idx = ft.time <= point_in_time

wk_time = sorted(ft.time[idx].index.tolist() + [point_in_time])
wk_surv = ft.surv[idx]
wk_n_risk = ft.n_risk[idx]
wk_n_event = n_event[ft.time <= point_in_time]
time_diff = np.diff([0] + wk_time)
areas = time_diff * ([1] + wk_surv.tolist())
rmst = sum(areas)

wk_var = wk_n_event.event / (wk_n_risk * (wk_n_risk - wk_n_event.event))
wk_var = wk_var.tolist() + [0]
rmst_var = sum((np.flip(areas[1:])).cumsum() ** 2 * np.flip(wk_var)[1:])
rmst_se = np.sqrt(rmst_var)
z = stats.norm.ppf(1 - fitterA.alpha / 2)
out = pd.Series(
{"RMST": rmst, "RMST_SE": rmst_se, "RMST_VAR": rmst_var, "RMST_LCI": rmst - z * rmst_se, "RMST_UCI": rmst + z * rmst_se}
)
return out


def difference_in_restricted_mean_survival_time(point_in_time, fitterA, fitterB) -> pd.Series:
"""
Returns difference in Restricted Mean Survival Time analysis on the populations described in fitterA and fitterB
https://cran.r-project.org/package=survRM2
Parameters
----------
point_in_time: float,
the point in time to analyze the survival curves at.
fitterA:
A lifelines univariate model fitted to the data from one population. This can be a ``KaplanMeierFitter``, ``WeibullFitter``, etc.
fitterB:
A lifelines univariate model fitted to the data from a comparison population. This can be a ``KaplanMeierFitter``, ``WeibullFitter``, etc.
Returns
-------
pd.Series
a pandas Series with the properties 'RMST_DIFF_A_B', 'RMST_DIFF_A_B_LCI',
'RMST_DIFF_A_B_UCI', 'RMST_DIFF_pval'
Examples
--------
.. code:: python
df = load_waltons()
ix = df["group"] == "miR-137"
kmf1 = KaplanMeierFitter().fit(df.loc[ix]["T"], df.loc[ix]["E"])
kmf2 = KaplanMeierFitter().fit(df.loc[~ix]["T"], df.loc[~ix]["E"])
from lifelines.statistics import difference_in_restricted_mean_survival_time
results = difference_in_restricted_mean_survival_time(12.0, kmf1, kmf2)
results
"""
wk0 = restricted_mean_survival_time(point_in_time, fitterA)
wk1 = restricted_mean_survival_time(point_in_time, fitterB)
alpha = fitterA.alpha

z = stats.norm.ppf(1 - alpha / 2)
rmst_diff_10 = wk1.RMST - wk0.RMST
rmst_diff_10_se = np.sqrt(wk1.RMST_VAR + wk0.RMST_VAR)
rmst_diff_10_lci = rmst_diff_10 - z * rmst_diff_10_se
rmst_diff_10_uci = rmst_diff_10 + z * rmst_diff_10_se
rmst_diff_pval = stats.norm.cdf(-np.abs(rmst_diff_10) / rmst_diff_10_se) * 2
string = "RMST_DIFF_A_B"
rmst_diff_result = pd.Series(
{
string: rmst_diff_10,
f"{string}_LCI": rmst_diff_10_lci,
f"{string}_UCI": rmst_diff_10_uci,
"RMST_DIFF_pval": rmst_diff_pval,
}
)
return rmst_diff_result


def survival_difference_at_fixed_point_in_time_test(point_in_time, fitterA, fitterB, **result_kwargs) -> StatisticalResult:
"""
Often analysts want to compare the survival-ness of groups at specific times, rather than comparing the entire survival curves against each other.
Expand Down Expand Up @@ -438,7 +561,7 @@ def survival_difference_at_fixed_point_in_time_test(point_in_time, fitterA, fitt
test_name="survival_difference_at_fixed_point_in_time_test",
fitterA=fitterA,
fitterB=fitterB,
**result_kwargs
**result_kwargs,
)


Expand All @@ -451,7 +574,7 @@ def logrank_test(
weights_A=None,
weights_B=None,
weightings=None,
**kwargs
**kwargs,
) -> StatisticalResult:
r"""
Measures and reports on whether two intensity processes are different. That is, given two
Expand Down Expand Up @@ -667,7 +790,7 @@ def pairwise_logrank_test(
t_0=t_0,
name=[(g1, g2)],
weightings=weightings,
**kwargs
**kwargs,
)

return result
Expand Down Expand Up @@ -835,7 +958,7 @@ def multivariate_logrank_test(
assert abs(Z_j.sum()) < 10e-8, "Sum is not zero." # this should move to a test eventually.

# compute covariance matrix
factor = (((n_i - d_i) / (n_i - 1)).replace([np.inf, np.nan], 1)) * d_i / n_i ** 2
factor = (((n_i - d_i) / (n_i - 1)).replace([np.inf, np.nan], 1)) * d_i / n_i**2
n_ij["_"] = n_i.values
V_ = (n_ij.mul(w_i, axis=0)).mul(np.sqrt(factor), axis="index").fillna(0) # weighted V_
V = -np.dot(V_.T, V_)
Expand Down Expand Up @@ -923,7 +1046,7 @@ def proportional_hazard_test(
def compute_statistic(times, resids, n_deaths):
demeaned_times = times - times.mean()
T = (demeaned_times.values[:, None] * resids.values).sum(0) ** 2 / (
n_deaths * (fitted_cox_model.standard_errors_ ** 2) * (demeaned_times ** 2).sum()
n_deaths * (fitted_cox_model.standard_errors_**2) * (demeaned_times**2).sum()
)
return T

Expand All @@ -947,7 +1070,7 @@ def compute_statistic(times, resids, n_deaths):
null_distribution="chi squared",
degrees_of_freedom=1,
model=str(fitted_cox_model),
**kwargs
**kwargs,
)

else:
Expand All @@ -970,6 +1093,6 @@ def compute_statistic(times, resids, n_deaths):
null_distribution="chi squared",
degrees_of_freedom=1,
model=str(fitted_cox_model),
**kwargs
**kwargs,
)
return result
25 changes: 25 additions & 0 deletions lifelines/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,31 @@ def test_proportional_hazard_test_with_list():
assert results.summary.shape[0] == 2 * 2


def test_restricted_mean_survival_time_nonparametric():
print("testing RMST")
df = load_waltons()
ix = df["group"] == "miR-137"
kmf1 = KaplanMeierFitter().fit(df.loc[ix]["T"], df.loc[ix]["E"])
result = stats.restricted_mean_survival_time(10, kmf1)
assert np.isclose(result.RMST, 9.794, rtol=1e-2, atol=1e-3)
assert np.isclose(result.RMST_SE, 0.123, rtol=1e-2, atol=1e-3)
assert np.isclose(result.RMST_LCI, 9.553, rtol=1e-2, atol=1e-3)
assert np.isclose(result.RMST_UCI, 10.036, rtol=1e-2, atol=1e-3)


def test_difference_in_restricted_mean_survival_time_nonparametric():
print("testing diff in RMST")
df = load_waltons()
ix = df["group"] == "miR-137"
kmf1 = KaplanMeierFitter().fit(df.loc[ix]["T"], df.loc[ix]["E"])
kmf2 = KaplanMeierFitter().fit(df.loc[~ix]["T"], df.loc[~ix]["E"])
result = stats.difference_in_restricted_mean_survival_time(10, kmf1, kmf2)
assert np.isclose(result.RMST_DIFF_A_B, 0.183, rtol=1e-2, atol=1e-3)
assert np.isclose(result.RMST_DIFF_A_B_LCI, -0.063, rtol=1e-2, atol=1e-3)
assert np.isclose(result.RMST_DIFF_A_B_UCI, 0.428, rtol=1e-2, atol=1e-3)
assert np.isclose(result.RMST_DIFF_pval, 0.145, rtol=1e-2, atol=1e-3)


def test_survival_difference_at_fixed_point_in_time_test_nonparametric():
df = load_waltons()
ix = df["group"] == "miR-137"
Expand Down

0 comments on commit e291356

Please sign in to comment.