Monod, Micheli & Bhatt (2025). NeuralSurv: Deep Survival Analysis with Bayesian Uncertainty Quantification. arXiv:2505.11054.
Imperial makes no representation or warranty about the accuracy or completeness of the data, nor that the results will not constitute an infringement of third-party rights. Imperial accepts no liability or responsibility for any use which may be made of any results, for the results, nor for any reliance which may be placed on any such work or results.
@misc{monod2025neuralsurvdeepsurvivalanalysis,
title={NeuralSurv: Deep Survival Analysis with Bayesian Uncertainty Quantification},
author={Mélodie Monod and Alessandro Micheli and Samir Bhatt},
year={2025},
eprint={2505.11054},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2505.11054},
}

- macOS or UNIX
- This release has been checked on Ubuntu 22.04.4 LTS and macOS Sonoma 14.1.2
Clone the repository. A yml file is provided and can be used to build a conda virtual environment containing all dependencies. Create the environment using:
cd neuralsurv
conda env create -f neuralsurv.yml

The file template.py demonstrates all the steps shown below, providing a complete workflow from data preparation to model fitting, posterior sampling, visualization, and evaluation.
The NeuralSurv framework expects:
- time_train and time_test: Event or censoring times as JAX arrays.
- event_train and event_test: Event indicators (1 if event occurred, 0 if censored) as JAX arrays.
- x_train and x_test: Covariate features as JAX arrays.
Example using synthetic data:
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(12)
n_train, n_test, p = 10, 5, 3
# jr.split(key) returns only two keys by default, so request a fresh key plus six subkeys.
key, *subkeys = jr.split(key, 7)
time_train = jax.random.normal(subkeys[0], (n_train,)) * 100 + 150
event_train = jax.random.bernoulli(subkeys[1], 0.5, (n_train,)).astype(jnp.int32)
x_train = jax.random.normal(subkeys[2], (n_train, p))
time_test = jax.random.normal(subkeys[3], (n_test,)) * 100 + 150
event_test = jax.random.bernoulli(subkeys[4], 0.5, (n_test,)).astype(jnp.int32)
x_test = jax.random.normal(subkeys[5], (n_test, p))

Remember to rescale your times so that the start time is 0.
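As a rough illustration of the data requirements above, the sketch below validates raw arrays and shifts the times so that they are measured from a study origin. It assumes that "start time 0" means durations counted from that origin. The helper name prepare_survival_arrays and its start_time parameter are hypothetical and are not part of NeuralSurv.

```python
import numpy as np


def prepare_survival_arrays(time, event, x, start_time=0.0):
    """Hypothetical helper (not part of NeuralSurv): validate and shift survival data."""
    time = np.asarray(time, dtype=np.float32)
    event = np.asarray(event, dtype=np.int32)
    x = np.asarray(x, dtype=np.float32)

    # Covariates must be a matrix with one row per individual.
    if x.ndim != 2:
        raise ValueError("x must have shape (n, p)")
    if not (time.shape == event.shape == (x.shape[0],)):
        raise ValueError("time and event must both have shape (n,)")
    # Event indicators: 1 if the event occurred, 0 if censored.
    if not np.isin(event, (0, 1)).all():
        raise ValueError("event must contain only 0 or 1")

    # Measure every time from the study origin so that the start time is 0.
    time = time - np.float32(start_time)
    if (time < 0).any():
        raise ValueError("times must not precede start_time")
    return time, event, x


# Made-up example: follow-up recorded on a clock that starts at day 100.
time_raw = [120.0, 180.0, 150.0, 240.0]
event_raw = [1, 0, 1, 1]
x_raw = [[0.1, 1.2], [0.4, -0.3], [-1.0, 0.8], [0.0, 0.5]]

time_clean, event_clean, x_clean = prepare_survival_arrays(
    time_raw, event_raw, x_raw, start_time=100.0
)
print(time_clean, event_clean, x_clean.shape)
```

The resulting NumPy arrays can be converted to JAX arrays with jnp.asarray before being passed on.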
Set prior parameters for the Bayesian model:
alpha_prior = 1.0
beta_prior = 1.0
rho = jnp.float32(1.0)

Control EM and CAVI optimization iterations:
max_iter_em = 200  # maximum number of iterations of the EM algorithm
max_iter_cavi = 200  # maximum number of iterations of the CAVI algorithm
num_points_integral_em = 1000  # number of points in the trapezoidal approximation (EM)
num_points_integral_cavi = 1000  # number of points in the trapezoidal approximation (CAVI)

batch_size = 1000

n_hidden = 2
n_layers = 2
activation = jax.nn.relu

Number of posterior samples to draw:
num_samples = 1000

from model.model import NMLP, MLP
model = NMLP(mlp_main=MLP(n_hidden=n_hidden, n_layers=n_layers, activation=activation))

This is the model architecture described in the paper. You can use any other JAX model.
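The num_points_integral_em and num_points_integral_cavi settings above fix the grid size of a trapezoidal approximation. The sketch below is not taken from the NeuralSurv code. It only illustrates, on a made-up hazard h(t) = 3t^2 whose cumulative hazard H(t) = t^3 is known in closed form, how the number of grid points affects the accuracy of such an approximation. The function names trapezoid and hazard are hypothetical.

```python
import math


def trapezoid(f, a, b, num_points):
    """Composite trapezoidal rule on num_points equally spaced grid points."""
    if num_points < 2:
        raise ValueError("num_points must be at least 2")
    step = (b - a) / (num_points - 1)
    # End points get weight 1/2, interior points weight 1.
    total = 0.5 * (f(a) + f(b))
    for i in range(1, num_points - 1):
        total += f(a + i * step)
    return total * step


def hazard(t):
    """Made-up hazard h(t) = 3 t^2, so the cumulative hazard is H(t) = t^3."""
    return 3.0 * t * t


t_end = 2.0
exact_cumulative_hazard = t_end**3

for num_points in (10, 100, 1000):
    approx = trapezoid(hazard, 0.0, t_end, num_points)
    # The survival function follows from the cumulative hazard: S(t) = exp(-H(t)).
    survival = math.exp(-approx)
    error = abs(approx - exact_cumulative_hazard)
    print(f"num_points={num_points:5d}  H={approx:.6f}  S={survival:.6e}  error={error:.2e}")
```

More grid points reduce the approximation error at the price of more function evaluations per integral.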
key, step_rng = jr.split(key)
model_params_init = model.init(step_rng, jnp.array([0]), jnp.zeros(p))

from neuralsurv import NeuralSurv
neuralsurv = NeuralSurv.load_or_create(
model,
model_params_init,
alpha_prior,
beta_prior,
rho,
num_points_integral_em,
num_points_integral_cavi,
batch_size,
max_iter_em,
max_iter_cavi,
output_dir,
)

neuralsurv.fit(time_train, event_train, x_train)

After fitting the model, you can draw posterior samples:
key, step_rng = jr.split(key)
neuralsurv.get_posterior_samples(step_rng, num_samples)

You can compute the concordance index (c-index), Brier score, D-calibration and KM calibration:
neuralsurv.compute_evaluation_metrics(time_train, event_train, time_test, event_test, x_test, plot_dir)

You can obtain posterior samples of the hazard and survival functions at new times on the test set with:
time_max = max(time_train.max(), time_test.max())
delta_time = time_max / 20
num = int(time_max // delta_time) + 1
new_times = jnp.linspace(1e-6, time_max, num=num)
hazard_test = neuralsurv.predict_hazard_function(new_times, x_test)
surv_test = neuralsurv.predict_survival_function(new_times, x_test)

Dimensions of the returned arrays: individual, time, posterior sample.
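Posterior samples with these dimensions are typically summarised along the last axis. The sketch below assumes the (individual, time, posterior sample) ordering stated above and uses a randomly generated stand-in array instead of real NeuralSurv output. The names surv_samples, surv_mean, surv_lower and surv_upper are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_individuals, n_times, n_samples = 5, 21, 1000

# Stand-in for posterior survival samples: exponential survival curves
# with a random rate per individual and per posterior sample.
times = np.linspace(1e-6, 10.0, n_times)
rates = rng.gamma(shape=2.0, scale=0.05, size=(n_individuals, 1, n_samples))
surv_samples = np.exp(-rates * times[None, :, None])  # (individual, time, sample)

# Posterior mean survival curve for each individual.
surv_mean = surv_samples.mean(axis=-1)

# Pointwise 95% credible band from the 2.5% and 97.5% posterior quantiles.
surv_lower, surv_upper = np.quantile(surv_samples, [0.025, 0.975], axis=-1)

print(surv_mean.shape, surv_lower.shape, surv_upper.shape)
```

With real output, np.asarray can be applied to the JAX array first, and the mean curve and band can then be plotted against new_times.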
Reproduce the benchmark results
In main_benchmark.py, specify the directory where the results will be stored (output_dir). For example,
output_dir = "/Users/melodiemonod/projects/2025/neuralsurv/benchmark"

Run main_benchmark.py.
Reproduce results of experiment "Synthetic Data Experiment"
First, specify the following entries in config_experiment_1.json
- Dataset Directory (dataset_dir): The directory where the repository is located + 'data/data_files'.
- GPU name (devices) and index (devices_index): The name and index of your GPU device.
For example,
"dataset_dir":"/home/mm3218/git/neuralsurv/data/data_files",
"devices": ["NVIDIA RTX A6000"],
"devices_index":"0"

Second, specify the following directories at the top of the submit_job_experiment_1.sh file:
- Repository Directory (INDIR): The directory where the repository is located.
- Output Directory (OUTDIR): The directory where the results will be stored.
INDIR="/home/mm3218/git/neuralsurv"
OUTDIR="/home/mm3218/projects/2025/neuralsurv"

Third, open a terminal and navigate to the repository directory, then execute the submit_job_experiment_1.sh script:
cd neuralsurv
bash submit_job_experiment_1.sh

The script will generate folders in the output directory, one for each experiment.
Go to the output directory, locate the experiment folder and navigate into it.
cd $OUTDIR
cd $DATE-synthetic_25

Run NeuralSurv to obtain the evaluation metrics and predict the survival function:
bash $DATE-synthetic_25.sh

Repeat these steps for each experiment folder created in $OUTDIR.
To reproduce the figure and the table, run make_tables_figures/synthetic_figure.py and make_tables_figures/synthetic_table.py after setting the correct date, dataset_name, jobid and jobid_neuralsurv.
Reproduce results of experiment "Real Survival Data Experiments"
First, specify the following entries in config_experiment_2.json
- Dataset Directory (dataset_dir): The directory where the repository is located + 'data/data_files'.
- GPU name (devices) and index (devices_index): The name and index of your GPU device.
For example,
"dataset_dir":"/home/mm3218/git/neuralsurv/data/data_files",
"devices": ["NVIDIA RTX A6000"],
"devices_index":"0"

Second, specify the following directories at the top of the submit_job_experiment_2.sh file:
- Repository Directory (INDIR): The directory where the repository is located.
- Output Directory (OUTDIR): The directory where the results will be stored.
INDIR="/home/mm3218/git/neuralsurv"
OUTDIR="/home/mm3218/projects/2025/neuralsurv"

Third, open a terminal and navigate to the repository directory, then execute the submit_job_experiment_2.sh script:
cd neuralsurv
bash submit_job_experiment_2.sh

The script will generate folders in the output directory, one for each experiment.
Go to the output directory, locate the experiment folder and navigate into it.
cd $OUTDIR
cd $DATE-colon_sub_125_fold_0_layers_2_hidden_16_relu

Run NeuralSurv to obtain the evaluation metrics and predict the survival function:
bash $DATE-colon_sub_125_fold_0_layers_2_hidden_16_relu.sh

Repeat these steps for each fold of a dataset created in $OUTDIR.
To reproduce the table, run make_tables_figures/real_data_tables.py after setting the correct date, dataset_name, jobid, and the base and suffix of jobid_neuralsurv.