Getting Started

Brian Lau edited this page Mar 31, 2017 · 25 revisions

Prerequisites

MatlabStan has the following dependencies:

  1. CmdStan: 2.0.1 or greater
  2. MatlabProcessManager: 0.4.0 or greater

Installing CmdStan is covered in detail for different platforms in the CmdStan Manual. Note that CmdStan must be built, so it requires make and a C++ compiler.

MatlabProcessManager consists of two Matlab files; to install it, add them to your Matlab path.

Installation

To install MatlabStan:

  1. Obtain a copy here or clone the repo.
  2. Add the resulting folder to your Matlab path. +mstan is a package folder that does not need to be added to the path, although its parent folder does.
  3. Edit the file stan_home.m in the +mstan directory to point to the parent folder of your CmdStan installation.
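
As a rough sketch, the path setup from the steps above might look like the following. The two folder locations are placeholders (assumptions, not defaults); substitute wherever you unpacked MatlabProcessManager and MatlabStan.

```matlab
% Placeholder locations: replace with the folders on your machine.
mpm_dir   = '/path/to/MatlabProcessManager';
mstan_dir = '/path/to/MatlabStan';   % the parent folder of +mstan

% Add both folders to the Matlab path for this session.
addpath(mpm_dir);
addpath(mstan_dir);

% Check that the entry point and the package file can be found.
which stan
which mstan.stan_home
```

If `which` reports that either name is not found, recheck the folder locations before editing stan_home.m.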

Optional

Aki Vehtari's package for Pareto smoothed importance sampling leave-one-out (PSIS-LOO) cross-validation is included in the psis directory. Add this to your Matlab path as well if you want to use it.

Installing Steve Eddins's linewrap function is useful for dealing with unwrapped messages. His xUnit test framework is required if you want to run the unit tests.

Using MatlabStan

Example: eight schools

This is a classic example from Section 5.5 of Gelman et al. (2003). The following can be compared to the RStan and PyStan versions.

schools_code = {
   'data {'
   '    int<lower=0> J; // number of schools '
   '    real y[J]; // estimated treatment effects'
   '    real<lower=0> sigma[J]; // s.e. of effect estimates '
   '}'
   'parameters {'
   '    real mu; '
   '    real<lower=0> tau;'
   '    real eta[J];'
   '}'
   'transformed parameters {'
   '    real theta[J];'
   '    for (j in 1:J)'
   '        theta[j] <- mu + tau * eta[j];'
   '}'
   'model {'
   '    eta ~ normal(0, 1);'
   '    y ~ normal(theta, sigma);'
   '}'
};
  
schools_dat = struct('J',8,...
                     'y',[28 8 -3 7 -1 1 18 12],...
                     'sigma',[15 10 16 11 9 11 10 18]);

fit = stan('model_code',schools_code,'data',schools_dat);

print(fit);

eta = fit.extract('permuted',true).eta;
mean(eta)

Stan models can also be defined using a file. For example, download the file eight_schools.stan into your working directory and use the following call:

fit1 = stan('file','eight_schools.stan','data',schools_dat,'iter',1000,'chains',4);
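
If downloading the file is inconvenient, one alternative is to write the cell array from the first example to disk yourself. The helper below is a minimal sketch; the name `write_stan_model` is hypothetical and is not part of MatlabStan. Save it as write_stan_model.m somewhere on your path.

```matlab
function write_stan_model(model_code, filename)
% WRITE_STAN_MODEL  Write a cell array of model lines to a text file.
%   Hypothetical helper, not part of MatlabStan.
%   model_code : cell array of character vectors, one per line of Stan code
%   filename   : name of the file to create (overwritten if it exists)

    if ~iscellstr(model_code)
        error('model_code must be a cell array of character vectors.');
    end

    [fid, msg] = fopen(filename, 'w');
    if fid < 0
        error('Could not open %s for writing: %s', filename, msg);
    end
    % Close the file when the function exits, even if fprintf errors.
    closer = onCleanup(@() fclose(fid));

    % fprintf recycles the format, so each cell lands on its own line.
    fprintf(fid, '%s\n', model_code{:});
end
```

With `schools_code` from the first example in the workspace, `write_stan_model(schools_code, 'eight_schools.stan');` should leave a file that the `'file'` call above can read.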

Once a model is fitted, we can reuse the result as an input to stan with other data or settings. This saves us from having to compile the C++ code again (see also here). For example, if we want to sample more iterations:

fit2 = stan('fit',fit1,'data',schools_dat,'iter',10000,'chains',4);

The stan function returns a StanFit object, which contains samples from the posterior distribution. StanFit objects possess a number of methods, including print, traceplot and extract. For example, a summary of the posterior samples as well as the log-posterior (which has the name lp__) is obtained using:

print(fit2);

which should look something like this:

Inference for Stan model: eight_schools_model
4 chains: each with iter=(5000,5000,5000,5000); warmup=(0,0,0,0); thin=(1,1,1,1); 20000 iterations saved.
Warmup took (0.16, 0.18, 0.22, 0.20) seconds, 0.77 seconds total
Sampling took (0.29, 0.30, 0.30, 0.25) seconds, 1.1 seconds total
                    Mean     MCSE  StdDev        5%       50%    95%  N_Eff  N_Eff/s    R_hat
lp__            -4.8e+00  4.0e-02     2.6  -9.4e+00  -4.6e+00  -0.92   4364     3828  1.0e+00
accept_stat__    7.2e-01  7.0e-02    0.30   3.4e-02   8.5e-01    1.0     19       16  1.1e+00
stepsize__       4.7e-01  5.3e-02   0.075   3.7e-01   4.9e-01   0.58    2.0      1.8  1.5e+13
treedepth__      1.7e+00  1.8e-01    0.61   0.0e+00   2.0e+00    2.0     11     10.0  1.1e+00
n_divergent__    8.7e-03  1.7e-03   0.093   0.0e+00   0.0e+00   0.00   3095     2715  1.0e+00
mu               8.0e+00  9.4e-02     5.1  -9.5e-02   7.9e+00     16   2970     2605  1.0e+00
tau              6.7e+00  1.0e-01     5.5   5.3e-01   5.5e+00     17   2873     2520  1.0e+00
eta[1]           4.0e-01  9.1e-03    0.93  -1.2e+00   4.1e-01    1.9  10496     9206  1.0e+00
eta[2]          -2.7e-03  8.9e-03    0.87  -1.4e+00  -7.3e-03    1.4   9545     8372  1.0e+00
eta[3]          -2.0e-01  9.0e-03    0.93  -1.7e+00  -2.2e-01    1.4  10578     9279  1.0e+00
eta[4]          -3.2e-02  8.4e-03    0.89  -1.5e+00  -3.3e-02    1.4  11090     9727  1.0e+00
eta[5]          -3.5e-01  8.8e-03    0.87  -1.8e+00  -3.7e-01    1.1   9694     8503  1.0e+00
eta[6]          -2.2e-01  9.1e-03    0.90  -1.6e+00  -2.4e-01    1.3   9764     8564  1.0e+00
eta[7]           3.5e-01  8.7e-03    0.87  -1.1e+00   3.7e-01    1.7  10018     8787  1.0e+00
eta[8]           5.5e-02  8.9e-03    0.93  -1.5e+00   5.6e-02    1.6  10972     9624  1.0e+00
theta[1]         1.1e+01  1.0e-01     8.3   2.2e-01   1.0e+01     27   6320     5544  1.0e+00
theta[2]         7.9e+00  6.3e-02     6.3  -2.2e+00   7.8e+00     18  10076     8838  1.0e+00
theta[3]         6.1e+00  8.9e-02     7.8  -7.3e+00   6.6e+00     18   7582     6651  1.0e+00
theta[4]         7.7e+00  6.6e-02     6.6  -3.2e+00   7.7e+00     19   9987     8760  1.0e+00
theta[5]         5.1e+00  6.4e-02     6.4  -6.2e+00   5.6e+00     15  10122     8878  1.0e+00
theta[6]         6.1e+00  6.9e-02     6.8  -5.7e+00   6.4e+00     17   9765     8566  1.0e+00
theta[7]         1.1e+01  7.7e-02     6.8   8.4e-01   1.0e+01     23   7782     6826  1.0e+00
theta[8]         8.6e+00  1.0e-01     8.2  -3.9e+00   8.3e+00     22   6297     5523  1.0e+00
Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at 
convergence, R_hat=1).

The extract method returns a struct or struct array for the parameters of interest:

% return a struct with all parameters when none specifically requested
la = fit.extract('permuted',true);  
mu = la.mu

% return an array with requested parameter
mu2 = fit.extract('pars','mu').mu;

% returns individual chains (each array element is a chain)
a = fit.extract('permuted',false);
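
Because the unpermuted result has one array element per chain, per-chain summaries can be computed with arrayfun. The sketch below uses a small synthetic struct array as a stand-in for the output of extract, so it runs without Stan; the values are made up for illustration and `per_chain_mean` is a hypothetical helper name.

```matlab
% Synthetic stand-in for a = fit.extract('permuted',false):
% a 1x2 struct array, one element per chain, each with a field mu.
a = struct('mu', {[1; 2; 3], [2; 4; 6]});

% Mean of one named parameter within each chain (hypothetical helper).
per_chain_mean = @(s, name) arrayfun(@(c) mean(c.(name)(:)), s);

chain_means = per_chain_mean(a, 'mu')   % one value per chain: 2 and 4
```

Chains whose means differ widely from each other are a quick hint to look at R_hat in the print output.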

Plotting traces is pretty basic at the moment:

fit.traceplot;

Example: rats

This is a classic hierarchical normal model; a description and the corresponding BUGS model can be found here. The Stan model is in the file rats.stan. Fitting the model

y = [151, 145, 147, 155, 135, 159, 141, 159, 177, 134, ...
160, 143, 154, 171, 163, 160, 142, 156, 157, 152, 154, 139, 146, ...
157, 132, 160, 169, 157, 137, 153, 199, 199, 214, 200, 188, 210, ...
189, 201, 236, 182, 208, 188, 200, 221, 216, 207, 187, 203, 212, ...
203, 205, 190, 191, 211, 185, 207, 216, 205, 180, 200, 246, 249, ...
263, 237, 230, 252, 231, 248, 285, 220, 261, 220, 244, 270, 242, ...
248, 234, 243, 259, 246, 253, 225, 229, 250, 237, 257, 261, 248, ...
219, 244, 283, 293, 312, 272, 280, 298, 275, 297, 350, 260, 313, ...
273, 289, 326, 281, 288, 280, 283, 307, 286, 298, 267, 272, 285, ...
286, 303, 295, 289, 258, 286, 320, 354, 328, 297, 323, 331, 305, ...
338, 376, 296, 352, 314, 325, 358, 312, 324, 316, 317, 336, 321, ...
334, 302, 302, 323, 331, 345, 333, 316, 291, 324];
y = reshape(y,30,5);
x = [8 15 22 29 36];

rats_dat = struct('N',size(y,1),'TT',size(y,2),'x',x,'y',y,'xbar',mean(x));

rats_fit = stan('file','rats.stan','data',rats_dat,'verbose',true);
print(rats_fit);

should produce output that looks something like this (compare to the RStan run here):

Inference for Stan model: rats_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.
Warmup took (5.0, 6.6, 0.86, 0.67) seconds, 13 seconds total
Sampling took (0.28, 1.3, 0.38, 0.41) seconds, 2.4 seconds total
                Mean     MCSE  StdDev     5%   50%   95%  N_Eff  N_Eff/s    R_hat
lp__            -439  2.3e-01     7.0   -451  -438  -428    930      389  1.0e+00
accept_stat__   0.81  4.1e-02    0.18   0.45  0.86   1.0     19      8.0  1.1e+00
stepsize__      0.43  1.4e-01    0.20  0.078  0.54  0.58    2.0     0.84  4.7e+14
treedepth__      2.5  1.6e-02    0.99    2.0   2.0   5.0   4000     1672  4.8e+00
n_divergent__   0.00  0.0e+00    0.00   0.00  0.00  0.00   4000     1672      nan
alpha[1]         240  4.3e-02     2.7    235   240   244   4000     1672  1.0e+00
alpha[2]         248  4.3e-02     2.7    243   248   252   4000     1672  1.0e+00
alpha[3]         252  4.2e-02     2.7    248   252   257   4000     1672  1.0e+00
alpha[4]         233  4.3e-02     2.7    228   233   237   4000     1672  1.0e+00
alpha[5]         232  4.2e-02     2.6    227   232   236   4000     1672  1.0e+00
alpha[6]         250  4.2e-02     2.6    245   250   254   4000     1672  1.0e+00
alpha[7]         229  4.2e-02     2.7    224   229   233   4000     1672  1.0e+00
alpha[8]         248  4.2e-02     2.7    244   248   253   4000     1672  1.0e+00
alpha[9]         283  4.4e-02     2.8    279   283   288   4000     1672  1.0e+00
alpha[10]        219  4.5e-02     2.8    215   219   224   4000     1672  1.0e+00
alpha[11]        258  4.1e-02     2.6    254   258   263   4000     1672  1.0e+00
alpha[12]        228  4.3e-02     2.7    224   228   233   4000     1672  1.0e+00
alpha[13]        242  4.4e-02     2.8    238   242   247   4000     1672  1.0e+00
alpha[14]        268  4.3e-02     2.7    264   268   273   4000     1672  1.0e+00
alpha[15]        243  4.2e-02     2.7    238   243   247   4000     1672  1.0e+00
alpha[16]        245  4.2e-02     2.7    241   245   250   4000     1672  1.0e+00
alpha[17]        232  4.2e-02     2.7    228   232   236   4000     1672  1.0e+00
alpha[18]        240  4.2e-02     2.6    236   240   245   4000     1672  1.0e+00
alpha[19]        254  4.2e-02     2.6    249   254   258   4000     1672  1.0e+00
alpha[20]        242  4.2e-02     2.7    237   242   246   4000     1672  1.0e+00
alpha[21]        249  4.2e-02     2.6    244   249   253   4000     1672  1.0e+00
alpha[22]        225  4.2e-02     2.6    221   225   230   4000     1672  1.0e+00
alpha[23]        228  4.3e-02     2.7    224   229   233   4000     1672  1.0e+00
alpha[24]        245  4.2e-02     2.6    241   245   249   4000     1672  1.0e+00
alpha[25]        235  4.3e-02     2.7    230   235   239   4000     1672  1.0e+00
alpha[26]        254  4.3e-02     2.7    249   254   258   4000     1672  1.0e+00
alpha[27]        254  4.2e-02     2.7    250   254   259   4000     1672  1.0e+00
alpha[28]        243  4.2e-02     2.7    238   243   247   4000     1672  1.0e+00
alpha[29]        218  4.4e-02     2.8    213   218   222   4000     1672  1.0e+00
alpha[30]        241  4.2e-02     2.7    237   241   246   4000     1672  1.0e+00
beta[1]          6.1  3.8e-03    0.24    5.7   6.1   6.5   4000     1672  1.0e+00
beta[2]          7.1  4.1e-03    0.26    6.6   7.1   7.5   4000     1672  1.0e+00
beta[3]          6.5  3.9e-03    0.25    6.1   6.5   6.9   4000     1672  1.0e+00
beta[4]          5.3  4.2e-03    0.26    4.9   5.3   5.8   4000     1672  1.0e+00
beta[5]          6.6  3.8e-03    0.24    6.2   6.6   7.0   4000     1672  1.0e+00
beta[6]          6.2  3.7e-03    0.23    5.8   6.2   6.6   4000     1672  1.0e+00
beta[7]          6.0  3.9e-03    0.25    5.6   6.0   6.4   4000     1672  1.0e+00
beta[8]          6.4  3.9e-03    0.24    6.0   6.4   6.8   4000     1672  1.0e+00
beta[9]          7.1  4.1e-03    0.26    6.6   7.1   7.5   4000     1672  1.0e+00
beta[10]         5.8  3.8e-03    0.24    5.5   5.8   6.2   4000     1672  1.0e+00
beta[11]         6.8  4.0e-03    0.25    6.4   6.8   7.2   4000     1672  1.0e+00
beta[12]         6.1  3.9e-03    0.25    5.7   6.1   6.5   4000     1672  1.0e+00
beta[13]         6.2  3.8e-03    0.24    5.8   6.2   6.6   4000     1672  1.0e+00
beta[14]         6.7  3.9e-03    0.25    6.3   6.7   7.1   4000     1672  1.0e+00
beta[15]         5.4  4.0e-03    0.25    5.0   5.4   5.8   4000     1672  1.0e+00
beta[16]         5.9  4.0e-03    0.26    5.5   5.9   6.3   4000     1672  1.0e+00
beta[17]         6.3  3.8e-03    0.24    5.9   6.3   6.7   4000     1672  1.0e+00
beta[18]         5.8  3.9e-03    0.25    5.4   5.8   6.2   4000     1672  1.0e+00
beta[19]         6.4  3.9e-03    0.25    6.0   6.4   6.8   4000     1672  1.0e+00
beta[20]         6.1  3.9e-03    0.25    5.6   6.1   6.5   4000     1672  1.0e+00
beta[21]         6.4  3.8e-03    0.24    6.0   6.4   6.8   4000     1672  1.0e+00
beta[22]         5.9  3.9e-03    0.25    5.5   5.9   6.3   4000     1672  1.0e+00
beta[23]         5.7  4.0e-03    0.25    5.3   5.7   6.2   4000     1672  1.0e+00
beta[24]         5.9  3.9e-03    0.25    5.5   5.9   6.3   4000     1672  1.0e+00
beta[25]         6.9  4.1e-03    0.26    6.5   6.9   7.3   4000     1672  1.0e+00
beta[26]         6.5  3.8e-03    0.24    6.2   6.5   6.9   4000     1672  1.0e+00
beta[27]         5.9  3.9e-03    0.25    5.5   5.9   6.3   4000     1672  1.0e+00
beta[28]         5.8  3.9e-03    0.25    5.4   5.8   6.3   4000     1672  1.0e+00
beta[29]         5.7  3.8e-03    0.24    5.3   5.7   6.1   4000     1672  1.0e+00
beta[30]         6.1  3.9e-03    0.25    5.7   6.1   6.5   4000     1672  1.0e+00
mu_alpha         242  4.4e-02     2.8    238   242   247   4000     1672  1.0e+00
mu_beta          6.2  1.7e-03    0.11    6.0   6.2   6.4   4000     1672  1.0e+00
sigmasq_y         37  1.3e-01     5.8     29    37    48   2125      888  1.0e+00
sigmasq_alpha    219  1.0e+00      65    137   207   342   4000     1672  1.0e+00
sigmasq_beta    0.28  1.9e-03    0.10   0.15  0.26  0.47   2995     1252  1.0e+00
sigma_y          6.1  1.0e-02    0.47    5.4   6.1   6.9   2124      888  1.0e+00
sigma_alpha       15  3.3e-02     2.1     12    14    18   4000     1672  1.0e+00
sigma_beta      0.52  1.7e-03   0.093   0.38  0.51  0.68   2838     1187  1.0e+00
alpha0           106  5.8e-02     3.6    100   106   112   4000     1672  1.0e+00
Samples were drawn using hmc with nuts.
For each parameter, N_Eff is a crude measure of effective sample size,
and R_hat is the potential scale reduction factor on split chains (at 
convergence, R_hat=1).