
Avoiding recompilation in Stan models

Brian Lau edited this page Jun 3, 2017 · 3 revisions

Compiling a model takes time, so if the same model will be used repeatedly, it is worth compiling it only once.

Within a session, you can avoid recompiling a model in two ways. The simplest is to pass an existing fit object in the call to stan:

model_code = {
'data {'
'    int<lower=0> N;'
'    int<lower=0,upper=1> y[N];'
'}'
'parameters {'
'    real<lower=0,upper=1> theta;'
'}'
'model {'
'    for (n in 1:N)'
'        y[n] ~ bernoulli(theta);'
'}'
};

data = struct('N',10,'y',[0, 1, 0, 0, 0, 0, 0, 0, 0, 1]);

% This call will compile the model
fit = stan('model_code',model_code,'data',data);
print(fit);

new_data = struct('N',10,'y',[0, 1, 0, 1, 0, 1, 0, 1, 1, 1]);

% Passing in the existing StanFit object skips recompilation
fit2 = stan('fit',fit,'data',new_data);
print(fit2);

Alternatively, we can create a StanModel instance and call its compile method first:

sm = StanModel('model_code',model_code);
sm.compile();

% Subsequent calls to sampling will skip recompilation
fit3 = sm.sampling('data',data);
print(fit3);

fit4 = sm.sampling('data',data);
print(fit4);

For example, this can be used to fit the same model to many data sets:

% True values for theta
theta = 0.1:.2:1;

% Set this to the number of cores on your machine to avoid warnings
ncores = 4;

for i = 1:numel(theta)
   % Generate some fake data
   data = struct('N',10,'y',double(rand(1,10)<theta(i)));
   
   % Sample each data set using the model compiled above
   fit(i) = stan('fit',sm,'data',data,'chains',min(i,ncores),'iter',150000);
   
   fprintf('Model: %s, id: %s, #chains=%g, seed=%g\n',...
      fit(i).model.model_name,fit(i).model.id,...
      fit(i).model.chains,fit(i).model.seed);
end

% Run the following commands after all fits have completed
% to compare the true values of theta with the posterior means
theta
arrayfun(@(x) mean(x.extract.theta),fit)
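
As a sanity check on the sampled means, the model above has a closed-form posterior. Because theta is declared on [0,1] with no explicit prior, Stan treats the prior as uniform, so the posterior is Beta(1 + sum(y), 1 + N - sum(y)) with mean (1 + sum(y)) / (N + 2). The following is a minimal sketch in base MATLAB, using the two data sets from the first example. The helper post_mean is a hypothetical name introduced here for illustration and is not part of MatlabStan.

```matlab
% Analytic posterior mean of theta for the Bernoulli model with an
% implicit uniform prior: Beta(1 + s, 1 + N - s) has mean (1 + s)/(N + 2).
% post_mean is a hypothetical helper, not a MatlabStan function.
post_mean = @(y) (1 + sum(y)) / (numel(y) + 2);

% The two data sets used in the first example above
y1 = [0, 1, 0, 0, 0, 0, 0, 0, 0, 1];
y2 = [0, 1, 0, 1, 0, 1, 0, 1, 1, 1];

fprintf('Analytic posterior mean, data:     %.4f\n', post_mean(y1)); % 0.2500
fprintf('Analytic posterior mean, new_data: %.4f\n', post_mean(y2)); % 0.5833
```

The posterior means of theta reported for fit and fit2 should be close to these values, up to Monte Carlo error. The same check can be applied in the loop if each simulated y is stored alongside its fit.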