
csMALA

This repository contains the numerics for our paper "Statistical guarantees for stochastic Metropolis-Hastings". This includes an implementation of a stochastic Metropolis-adjusted Langevin algorithm (MALA) that draws auxiliary variables from a Bernoulli distribution $b_i \sim \mathrm{Ber}(\rho)$ for some $\rho\in (0,1]$ to construct batches of data.

Regression example

For a simple 1D regression example $\mathcal{D}_n=(x_i,y_i)_{i=1,\dots,n}$, we use the $L_2$-loss

$$L_n(\vartheta, B) = L(\vartheta, B; \mathcal{D}_n) := \frac{1}{n\rho}\sum_{i=1}^n b_i \underbrace{(y_i - f_\vartheta(x_i))^2}_{=: l_i(\vartheta)}$$

at function parameters $\vartheta$. In combination with the established MALA acceptance step, this leads to sampling from a surrogate distribution, which resembles the true posterior well only for large $\rho$ and $n$.
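As an illustration (not the repository's implementation; `f_theta`, `x`, `y`, and `rho` are placeholder names, and `y` and `f_theta(x)` are assumed to be 1-D tensors of length $n$), this loss amounts to:

```python
import torch

def stochastic_l2_loss(f_theta, x, y, rho):
    """Stochastic L2 loss L_n(theta, B) with Bernoulli batch indicators."""
    n = len(x)
    b = torch.bernoulli(torch.full((n,), rho))  # b_i ~ Ber(rho)
    l = (y - f_theta(x)) ** 2                   # pointwise losses l_i(theta)
    return (b * l).sum() / (n * rho)
```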

We propose adapting the loss to

$$\tilde{L}_n(\vartheta, B) := \frac{1}{n}\sum_{i=1}^n b_i l_i(\vartheta) + \zeta \frac{\log{\rho}}{\lambda} \sum_{i=1}^n b_i\,,$$

which ensures sampling from a marginal invariant distribution with reduced inverse temperature of $\frac{\lambda}{1-\rho}$ for small $\frac{\lambda}{n}$ compared to the true posterior. The convergence of the resulting algorithm in Kullback-Leibler divergence no longer depends on the batch size parameter $\rho$ and resembles that of full-data MALA.
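The corrected loss differs from the sketch above only in the rescaling and the additional penalty on the realized batch size $\sum_i b_i$. A sketch under the same placeholder assumptions, with `zeta` and `lam` standing in for $\zeta$ and $\lambda$:

```python
import math
import torch

def corrected_loss(f_theta, x, y, rho, zeta, lam):
    """Corrected loss tilde-L_n(theta, B): the log(rho)/lambda penalty on
    the realized batch size replaces the 1/rho rescaling above."""
    n = len(x)
    b = torch.bernoulli(torch.full((n,), rho))  # b_i ~ Ber(rho)
    l = (y - f_theta(x)) ** 2                   # pointwise losses l_i(theta)
    return (b * l).sum() / n + zeta * math.log(rho) / lam * b.sum()
```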

[Figure: scaling with the number of observations $n$]

Structure

  • src/MALA.py defines our MALA implementation and can be used in place of your usual PyTorch optimizer
  • src/util contains the risk implementations, including our corrected risk term, as well as the model definition
  • src/datasets contains toy data and a Bernoulli-sampling dataloader
  • src/uncertimators contains wrappers around MALA and probabilistic training for easier use

Basic Usage

Minimal working examples of the data generation, neural network definition, and weight sampling with MALA are provided in the first half of StochasticMH.ipynb. The second half of the notebook contains the plotting scripts for the figures displayed in the paper.
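For orientation, a hypothetical training loop is sketched below. The constructor arguments and the `step()` signature are assumptions (the actual interface in src/MALA.py may, for instance, require the loss value for the accept/reject step), the toy data is illustrative, and `corrected_loss` is the sketch from above:

```python
import torch
from src.MALA import MALA  # the repository's sampler

# toy 1D regression data with Gaussian y-noise (illustrative only)
x = torch.linspace(-1, 1, 1000)
y = torch.sin(3 * x) + 0.1 * torch.randn(1000)

model = torch.nn.Sequential(
    torch.nn.Linear(1, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
)
sampler = MALA(model.parameters(), lr=1e-3)  # hypothetical arguments

samples = []
for step in range(1000):
    sampler.zero_grad()
    loss = corrected_loss(lambda t: model(t.unsqueeze(-1)).squeeze(-1),
                          x, y, rho=0.5, zeta=1.0, lam=100.0)
    loss.backward()
    sampler.step()  # presumed Langevin proposal + stochastic MH accept/reject
    samples.append([p.detach().clone() for p in model.parameters()])
```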

To generate the weight samples required for plotting, StochasticMH.py and StochasticMH_Adam_baseline.py need to be executed for different $\rho$ and $n$. To this end, both scripts accept the following arguments (see the example invocation after this list):

  • folder: Path to the directory where the data is loaded from (if existent) and samples are saved.
  • n_points: Number of data points $n$.
  • rho: Batch size parameter $\rho$, i.e. the Bernoulli success probability.
  • lambda_factor: Temperature parameter as specified in the paper.
  • sigma_data: Standard deviation of Gaussian noise added to the data in $y$-direction.
  • num_dataloader_workers: Number of threads used for data loading. Default: 0.
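An example invocation, assuming the flag names match the parameter names above (check the scripts' argparse setup for the exact spelling):

```bash
python StochasticMH.py --folder results/ --n_points 1000 --rho 0.2 \
    --lambda_factor 1.0 --sigma_data 0.1 --num_dataloader_workers 0
```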

Citation

For more details, we refer to the paper:

```
@unpublished{bieringer2023statistical,
  title={Statistical guarantees for stochastic {M}etropolis-{H}astings},
  author={Bieringer, Sebastian and Kasieczka, Gregor and Steffen, Maximilian F and Trabs, Mathias},
  eprint={2310.09335},
  archivePrefix={arXiv},
  primaryClass={stat.ML},
  month={10},
  year={2023},
}
```

Others

Building on this work, we developed AdamMCMC, an MCMC tool for weight sampling in neural networks that leverages the advantages of the Adam optimizer.
