Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

includes RMST, difference in RMST and confidence intervals #1526

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added lifelines/filename.joblib
Binary file not shown.
174 changes: 163 additions & 11 deletions lifelines/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,162 @@ def z(p):
)


def restricted_mean_survival_time(point_in_time, fitterA) -> pd.Series:
    """
    Compute the Restricted Mean Survival Time (RMST) of the population described by ``fitterA``.

    The RMST is the area under the fitted survival curve between the start of the timeline and
    ``point_in_time``, with the curve treated as a step function over ``fitterA.timeline``.
    The variance is built from the per-time at-risk and observed-event counts in
    ``fitterA.event_table``, following the R package survRM2
    (https://cran.r-project.org/package=survRM2).

    Parameters
    ----------
    point_in_time: float,
        the point in time up to which the survival curve is integrated.

    fitterA:
        A fitted lifelines univariate model. It must expose ``timeline``, ``survival_function_``,
        ``alpha`` and an ``event_table`` with ``at_risk`` and ``observed`` columns
        (for example a ``KaplanMeierFitter``).

    Returns
    -------

    pd.Series
        a pandas Series with the properties 'RMST', 'RMST_SE', 'RMST_VAR',
        'RMST_LCI', 'RMST_UCI'. The interval bounds use the normal quantile for
        ``fitterA.alpha``.

    Examples
    --------
    .. code:: python

        T1 = [4, 5, 7, 11, 14, 20, 8, 8]
        E1 = [1, 1, 1, 1, 1, 1, 1, 1]
        kmf1 = KaplanMeierFitter().fit(T1, E1)

        from lifelines.statistics import restricted_mean_survival_time
        results = restricted_mean_survival_time(12.0, kmf1)

        results
    """
    # Align the counts and the survival estimate on the fitter's timeline, then keep
    # only the rows at or before the restriction time.
    table = pd.DataFrame(index=pd.Index(fitterA.timeline, name="time"))
    table["n_risk"] = fitterA.event_table["at_risk"]
    table["n_event"] = fitterA.event_table["observed"]
    table["surv"] = fitterA.survival_function_.iloc[:, 0]
    table = table.loc[table.index <= point_in_time]

    n_risk = table["n_risk"].to_numpy(dtype=float)
    n_event = table["n_event"].to_numpy(dtype=float)
    surv = table["surv"].to_numpy(dtype=float)

    # Each survival value holds from its own time until the next one; the last one
    # holds until point_in_time.
    grid = np.sort(np.append(table.index.to_numpy(dtype=float), point_in_time))
    areas = np.diff(grid) * surv
    rmst = float(areas.sum())

    # Per-time variance term d / (n * (n - d)). When nobody survives a time point the
    # ratio is infinite (or 0/0); it is set to zero, as survRM2 does, because otherwise
    # the standard error and the confidence interval come out as NaN.
    survivors = n_risk - n_event
    with np.errstate(divide="ignore", invalid="ignore"):
        step_var = np.where(survivors == 0, 0.0, n_event / (n_risk * survivors))

    # The term at time t_j is weighted by the squared area under the curve from t_j onward.
    # The first row is skipped, exactly as in the survRM2 formula.
    area_from_t = np.cumsum(areas[:0:-1])
    rmst_var = float(np.sum(area_from_t**2 * step_var[:0:-1]))
    rmst_se = float(np.sqrt(rmst_var))

    z = stats.norm.ppf(1 - fitterA.alpha / 2)
    return pd.Series(
        {
            "RMST": rmst,
            "RMST_SE": rmst_se,
            "RMST_VAR": rmst_var,
            "RMST_LCI": rmst - z * rmst_se,
            "RMST_UCI": rmst + z * rmst_se,
        }
    )


def difference_in_restricted_mean_survival_time(point_in_time, fitterA, fitterB) -> pd.Series:
    """
    Returns the difference in Restricted Mean Survival Time between the populations described in
    fitterA and fitterB, following https://cran.r-project.org/package=survRM2

    Parameters
    ----------
    point_in_time: float or None,
        the point in time to analyze the survival curves at. If None, the largest admissible
        time is used (see Raises for how that limit is chosen).

    fitterA:
        A lifelines univariate model fitted to the data from one population. It must expose
        ``durations``, ``alpha`` and ``event_table``, and be accepted by
        ``restricted_mean_survival_time`` (for example a ``KaplanMeierFitter``).

    fitterB:
        A lifelines univariate model fitted to the data from a comparison population, with the
        same requirements as ``fitterA``.

    Returns
    -------

    pd.Series
        a pandas Series with the properties 'RMST_DIFF_A_B', 'RMST_DIFF_A_B_LCI',
        'RMST_DIFF_A_B_UCI', 'RMST_DIFF_pval'.
        Note that the difference is computed as RMST(fitterB) - RMST(fitterA), and that the
        confidence level is taken from ``fitterA.alpha`` only.

    Raises
    ------
    ValueError
        if ``point_in_time`` exceeds the admissible limit. The limit is the larger of the two
        groups' maximum durations when the last row of the longer group's event table has
        censored observations or when neither group's last row has any; otherwise it is the
        smaller of the two maximum durations.

    Examples
    --------
    .. code:: python

        df = load_waltons()
        ix = df["group"] == "miR-137"
        kmf1 = KaplanMeierFitter().fit(df.loc[ix]["T"], df.loc[ix]["E"])
        kmf2 = KaplanMeierFitter().fit(df.loc[~ix]["T"], df.loc[~ix]["E"])

        from lifelines.statistics import difference_in_restricted_mean_survival_time
        results = difference_in_restricted_mean_survival_time(12.0, kmf1, kmf2)

        results
    """
    longest_a = fitterA.durations.max()
    longest_b = fitterB.durations.max()
    time_max = max(longest_a, longest_b)
    time_lesser_max = min(longest_a, longest_b)

    # A group's final time counts as "observed" when the last event-table row has no censoring.
    last_a_observed = bool(fitterA.event_table.censored.iloc[-1] == 0)
    last_b_observed = bool(fitterB.event_table.censored.iloc[-1] == 0)

    if last_a_observed and last_b_observed:
        # case 1: last event in both groups is not censored
        limit = time_max
        message = f"the point in time needs to be shorter than or equal to the largest observed time on each of the two groups: {time_max}"
    elif not last_a_observed and not last_b_observed:
        # case 4: the last event in both groups is censored
        limit = time_lesser_max
        message = f"the point in time needs to be shorter than or equal to the minimum of the largest observed time on each of the two groups: {time_lesser_max}"
    else:
        # Exactly one group ends on an observed event; ties count group A as the longer arm.
        longer_arm_observed = last_a_observed if longest_a >= longest_b else last_b_observed
        if longer_arm_observed:
            # case 3: shorter arm ends censored, longer arm ends observed
            limit = time_lesser_max
            message = f"The point in time needs to be shorter than or equal to the minimum of the largest observed time on each of the two groups: {time_lesser_max}"
        else:
            # case 2: shorter arm ends observed, longer arm ends censored
            limit = time_max
            message = f"The point_in_time needs to be shorter than or equal to the largest observed time on each of the two groups: {time_max}"

    if point_in_time is None:
        point_in_time = limit
    elif point_in_time > limit:
        raise ValueError(message)

    rmst_a = restricted_mean_survival_time(point_in_time, fitterA)
    rmst_b = restricted_mean_survival_time(point_in_time, fitterB)

    z = stats.norm.ppf(1 - fitterA.alpha / 2)
    diff = rmst_b.RMST - rmst_a.RMST
    # The two groups are treated as independent, so their variances add.
    diff_se = np.sqrt(rmst_b.RMST_VAR + rmst_a.RMST_VAR)
    # Two-sided p-value from the normal approximation.
    p_value = stats.norm.cdf(-np.abs(diff) / diff_se) * 2

    label = "RMST_DIFF_A_B"
    return pd.Series(
        {
            label: diff,
            f"{label}_LCI": diff - z * diff_se,
            f"{label}_UCI": diff + z * diff_se,
            "RMST_DIFF_pval": p_value,
        }
    )


def survival_difference_at_fixed_point_in_time_test(point_in_time, fitterA, fitterB, **result_kwargs) -> StatisticalResult:
"""
Often analysts want to compare the survival-ness of groups at specific times, rather than comparing the entire survival curves against each other.
Expand Down Expand Up @@ -438,7 +594,7 @@ def survival_difference_at_fixed_point_in_time_test(point_in_time, fitterA, fitt
test_name="survival_difference_at_fixed_point_in_time_test",
fitterA=fitterA,
fitterB=fitterB,
**result_kwargs
**result_kwargs,
)


Expand All @@ -451,7 +607,7 @@ def logrank_test(
weights_A=None,
weights_B=None,
weightings=None,
**kwargs
**kwargs,
) -> StatisticalResult:
r"""
Measures and reports on whether two intensity processes are different. That is, given two
Expand Down Expand Up @@ -667,16 +823,12 @@ def pairwise_logrank_test(
t_0=t_0,
name=[(g1, g2)],
weightings=weightings,
**kwargs
**kwargs,
)

return result


def difference_of_restricted_mean_survival_time_test(model1, model2, t):
    # Unimplemented placeholder: the body is only `pass`, so a call ignores its
    # arguments and returns None.
    pass


def multivariate_logrank_test(
event_durations, groups, event_observed=None, weights=None, t_0=-1, weightings=None, **kwargs
) -> StatisticalResult: # pylint: disable=too-many-locals
Expand Down Expand Up @@ -835,7 +987,7 @@ def multivariate_logrank_test(
assert abs(Z_j.sum()) < 10e-8, "Sum is not zero." # this should move to a test eventually.

# compute covariance matrix
factor = (((n_i - d_i) / (n_i - 1)).replace([np.inf, np.nan], 1)) * d_i / n_i ** 2
factor = (((n_i - d_i) / (n_i - 1)).replace([np.inf, np.nan], 1)) * d_i / n_i**2
n_ij["_"] = n_i.values
V_ = (n_ij.mul(w_i, axis=0)).mul(np.sqrt(factor), axis="index").fillna(0) # weighted V_
V = -np.dot(V_.T, V_)
Expand Down Expand Up @@ -923,7 +1075,7 @@ def proportional_hazard_test(
def compute_statistic(times, resids, n_deaths):
demeaned_times = times - times.mean()
T = (demeaned_times.values[:, None] * resids.values).sum(0) ** 2 / (
n_deaths * (fitted_cox_model.standard_errors_ ** 2) * (demeaned_times ** 2).sum()
n_deaths * (fitted_cox_model.standard_errors_**2) * (demeaned_times**2).sum()
)
return T

Expand All @@ -947,7 +1099,7 @@ def compute_statistic(times, resids, n_deaths):
null_distribution="chi squared",
degrees_of_freedom=1,
model=str(fitted_cox_model),
**kwargs
**kwargs,
)

else:
Expand All @@ -970,6 +1122,6 @@ def compute_statistic(times, resids, n_deaths):
null_distribution="chi squared",
degrees_of_freedom=1,
model=str(fitted_cox_model),
**kwargs
**kwargs,
)
return result
25 changes: 25 additions & 0 deletions lifelines/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,31 @@ def test_proportional_hazard_test_with_list():
assert results.summary.shape[0] == 2 * 2


def test_restricted_mean_survival_time_nonparametric():
    """RMST at t=10 for the waltons "miR-137" group matches the hard-coded expected values."""
    df = load_waltons()
    ix = df["group"] == "miR-137"
    kmf1 = KaplanMeierFitter().fit(df.loc[ix]["T"], df.loc[ix]["E"])
    result = stats.restricted_mean_survival_time(10, kmf1)

    expected = {"RMST": 9.794, "RMST_SE": 0.123, "RMST_LCI": 9.553, "RMST_UCI": 10.036}
    for key, value in expected.items():
        # The assertion message names the statistic so a failure points at the offending value.
        assert np.isclose(result[key], value, rtol=1e-2, atol=1e-3), key


def test_difference_in_restricted_mean_survival_time_nonparametric():
    """RMST difference at t=10 between the two waltons groups matches the hard-coded expected values."""
    df = load_waltons()
    ix = df["group"] == "miR-137"
    kmf1 = KaplanMeierFitter().fit(df.loc[ix]["T"], df.loc[ix]["E"])
    kmf2 = KaplanMeierFitter().fit(df.loc[~ix]["T"], df.loc[~ix]["E"])
    result = stats.difference_in_restricted_mean_survival_time(10, kmf1, kmf2)

    expected = {
        "RMST_DIFF_A_B": 0.183,
        "RMST_DIFF_A_B_LCI": -0.063,
        "RMST_DIFF_A_B_UCI": 0.428,
        "RMST_DIFF_pval": 0.145,
    }
    for key, value in expected.items():
        # The assertion message names the statistic so a failure points at the offending value.
        assert np.isclose(result[key], value, rtol=1e-2, atol=1e-3), key


def test_survival_difference_at_fixed_point_in_time_test_nonparametric():
df = load_waltons()
ix = df["group"] == "miR-137"
Expand Down