Skip to content

Commit

Permalink
Merge branch 'vb'
Browse files Browse the repository at this point in the history
* vb:
  ENH: Enable Stan's variational algorithm for approximate posterior
  initial commit variational inference
  • Loading branch information
brian-lau committed May 13, 2017
2 parents 1d435f8 + febeb16 commit b2189ae
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 57 deletions.
2 changes: 1 addition & 1 deletion +mstan/parse_stan_params.m
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
function str = parse_stan_params(s,root)
import mstan.parse_stan_params

branch = {'sample' 'optimize' 'diagnose' 'static' 'nuts' 'nesterov' 'bfgs' 'lbfgs'};
branch = {'sample' 'optimize' 'variational' 'diagnose' 'static' 'nuts' 'nesterov' 'bfgs' 'lbfgs'};
if nargin == 2
branch = branch(~strcmp(branch,root));
fn = fieldnames(s);
Expand Down
3 changes: 1 addition & 2 deletions +mstan/stan_home.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@
% o some way to manage fileseparators?
function d = stan_home()

d = '/Users/brian/Downloads/cmdstan';
d = '/Users/brian/Documents/Code/cmdstan-2.15.0';
%d = 'C:\Users\brian\Downloads\cmdstan';
%d = '/Users/brian/Documents/Code/Stan_2.4.0';
26 changes: 26 additions & 0 deletions +mstan/stan_params.m
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,32 @@
'iter',{{{'numeric'} {'scalar','>',0}}},...
'save_iterations',{{{'logical'} {'scalar'}}});

params.variational = struct(...
'algorithm','meanfield',...
'iter',10000,...
'grad_samples',1,...
'elbo_samples',100,...
'eta',1,...
'adapt',struct(...
'engaged',true,...
'iter',50),...
'tol_rel_obj',0.01,...
'eval_elbo',100,...
'output_samples',1000);

valid.variational = struct(...
'algorithm',{{'meanfield' 'fullrank'}},...
'iter',{{{'numeric'} {'scalar','>',0}}},...
'grad_samples',{{{'numeric'} {'scalar','>',0}}},...
'elbo_samples',{{{'numeric'} {'scalar','>',0}}},...
'eta',{{{'numeric'} {'scalar','>',0}}},...
'adapt',struct(...
'engaged',{{{'logical'} {'scalar'}}},...
'iter',{{{'numeric'} {'scalar','>',0}}}),...
'tol_rel_obj',{{{'numeric'} {'scalar','>',0}}},...
'eval_elbo',{{{'numeric'} {'scalar','>',0}}},...
'output_samples',{{{'numeric'} {'scalar','>',0}}});

params.diagnose = struct(...
'test','gradient');
valid.diagnose = struct(...
Expand Down
58 changes: 58 additions & 0 deletions Examples/variational.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
schools_code = {
'data {'
' int<lower=0> J; // number of schools '
' real y[J]; // estimated treatment effects'
' real<lower=0> sigma[J]; // s.e. of effect estimates '
'}'
'parameters {'
' real mu; '
' real<lower=0> tau;'
' real eta[J];'
'}'
'transformed parameters {'
' real theta[J];'
' for (j in 1:J)'
' theta[j] <- mu + tau * eta[j];'
'}'
'model {'
' eta ~ normal(0, 1);'
' y ~ normal(theta, sigma);'
'}'
};

schools_dat = struct('J',8,...
'y',[28 8 -3 7 -1 1 18 12],...
'sigma',[15 10 16 11 9 11 10 18]);

model = StanModel('verbose',true,'model_code',schools_code,'data',schools_dat);
model.compile();

fit_vb = model.vb();

print(fit_vb);

% http://www.slideshare.net/yutakashino/automatic-variational-inference-in-stan-nips2015yomi20160120
% compare to slide 32
% Inference for Stan model: 8schools.
% 1 chains, each with iter=2000, warmup=0, thin=1;
% post-warmup draws per chain=2000, total post-warmup draws=2000.
%
% mean sd 2.5% 25% 50% 75% 97.5%
% mu 7.75 4.63 -1.46 4.78 7.73 10.83 16.88
% tau 4.61 3.73 0.87 2.17 3.61 5.83 14.79
% eta[1] 0.34 0.99 -1.70 -0.33 0.37 0.99 2.26
% eta[2] -0.10 0.87 -1.74 -0.68 -0.11 0.48 1.59
% eta[3] -0.28 0.93 -2.12 -0.91 -0.28 0.33 1.55
% eta[4] 0.00 0.84 -1.64 -0.55 0.00 0.55 1.65
% eta[5] -0.34 0.96 -2.27 -1.02 -0.32 0.32 1.47
% eta[6] -0.27 0.94 -2.14 -0.93 -0.24 0.36 1.50
% eta[7] 0.49 0.95 -1.47 -0.14 0.48 1.12 2.33
% eta[8] 0.03 1.00 -1.91 -0.65 0.00 0.70 2.03
% theta[1] 9.34 7.41 -4.07 4.84 9.09 13.21 23.96
% theta[2] 7.20 6.93 -5.85 3.12 7.13 11.24 20.15
% theta[3] 6.39 7.01 -8.45 2.37 6.75 10.54 19.78
% theta[4] 7.84 6.99 -5.38 3.61 7.71 11.93 21.65
% theta[5] 6.24 7.59 -9.41 2.29 6.52 10.91 20.04
% theta[6] 6.43 7.11 -8.70 2.52 6.82 10.84 18.97
% theta[7] 10.01 7.48 -4.31 5.51 9.35 13.70 26.54
% theta[8] 7.79 7.59 -7.57 3.62 7.76 11.88 22.41
30 changes: 22 additions & 8 deletions StanFit.m
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function check(self)
if ~isempty(self.processes)
if any([self.processes.running])
for i = 1:numel(self.processes)
if self.processes(i).running;
if self.processes(i).running
fprintf('%s \t %s\n',self.processes(i).id,self.processes(i).stdout{end});
end
end
Expand Down Expand Up @@ -228,12 +228,22 @@ function process_exit_success(self,src)
elseif strcmp(self.model.method,'sample')
[hdr,flatNames,flatSamples,pos] = mstan.read_stan_csv(...
self.output_file{ind},self.model.inc_warmup);
elseif strcmp(self.model.method,'variational')
[hdr,flatNames,flatSamples] = mstan.read_stan_csv(...
self.output_file{ind},true);
% lp__ is a legacy feature that is no longer used
temp = strcmp(flatNames,'lp__');
flatNames(temp) = [];
flatSamples(:,temp) = [];
end
[names,dims,samples] = mstan.parse_flat_samples(flatNames,flatSamples);

if strcmp(self.model.method,'optimize')
exp_warmup = 0;
exp_iter = 1;
elseif strcmp(self.model.method,'variational')
exp_warmup = 0;
exp_iter = size(flatSamples,1); % FIXME
else
% Account for thinning
if self.model.inc_warmup
Expand Down Expand Up @@ -282,11 +292,9 @@ function process_exit_failure(self,src)

function str = print(self,varargin)
% TODO:
% o this should allow multiple files and regexp.
% x this does not work when method=optim, should shortcut
%
% note that passing regexp through in the command does not work,
% need to implment search in matlab
% o this should allow regexp.
% passing regexp through in the command does not work,
% need to implment search in matlab
% TODO: allow print parameters
% FIXME: ugh, if multiple fits were done with same output names
% print will just give the results from the last one. should
Expand All @@ -312,11 +320,17 @@ function process_exit_failure(self,src)
file = p.Results.file;
end

if mstan.check_ver(self.model.stan_version,'2.8.0')
command = 'stansummary';
else
command = 'print';
end

if ischar(file)
command = [self.model.stan_home filesep 'bin/print --sig_figs='...
command = [self.model.stan_home filesep 'bin' filesep command ' --sig_figs='...
num2str(p.Results.sig_figs) ' ' file];
elseif iscell(file)
command = [self.model.stan_home filesep 'bin/print --sig_figs='...
command = [self.model.stan_home filesep 'bin' filesep command ' --sig_figs='...
num2str(p.Results.sig_figs) ' ' sprintf('%s ',file{:})];
end
p = processManager('command',command,...
Expand Down
120 changes: 85 additions & 35 deletions StanModel.m
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
% file - string, optional
% The string passed is the filename containing the Stan model.
% method - string, optional
% {'sample' 'optimize'}, default = 'sample'
% {'sample' 'optimize' 'variational'}, default = 'sample'
% model_code - string, optional
% String, or cell array of strings containing Stan model.
% Ignored if 'file' is passed in.
Expand Down Expand Up @@ -50,6 +50,7 @@
% algorithm - string, optional
% If method = 'sample', {'NUTS','HMC'}, default = 'NUTS'
% If method = 'optimize', {'BFGS','NESTEROV' 'NEWTON'}, default = 'BFGS'
% If method = 'variational', {'MEANFIELD','FULLRANK'}, default = 'MEANFIELD'
% sample_file - string, optional
% Name of file(s) where samples for all parameters are saved.
% Default = 'output.csv'.
Expand Down Expand Up @@ -172,7 +173,7 @@
p.addParamValue('model_code',{});
p.addParamValue('working_dir',pwd);
p.addParamValue('method','sample',@(x) any(strcmp(x,...
{'sample' 'optimize' 'diagnose'})));
{'sample' 'optimize' 'variational' 'diagnose'})));
p.addParamValue('chains',4);
p.addParamValue('sample_file','',@ischar);
p.addParamValue('verbose',false,@islogical);
Expand All @@ -188,6 +189,7 @@
'processManager (https://github.com/brian-lau/MatlabProcessManager) is required');
end

% TODO: move this into stan_version
count = 0;
while 1 % FIXME, occasionally stanc does not return version?
try
Expand Down Expand Up @@ -308,6 +310,7 @@ function set(self,varargin)
&& exist(fullfile(fa.Name,'bin'),'dir')
self.stan_home = fa.Name;
else
% TODO make this message more informative
error('StanModel:stan_home:InputFormat',...
'Does not look like a proper stan setup');
end
Expand Down Expand Up @@ -566,32 +569,43 @@ function select_file(self)

function set.algorithm(self,algorithm)
algorithm = lower(algorithm);
if strcmp(self.method,'sample')
if strcmp(algorithm,'hmc')
algorithm = 'static';
end
if any(strcmp(self.validators.sample.hmc.engine,algorithm))
self.params.sample.hmc.engine = algorithm;
else
error('StanModel:algorithm:InputFormat',...
'Unknown algorithm for sampler');
end
elseif strcmp(self.method,'optimize')
if any(strcmp(self.validators.optimize.algorithm,algorithm))
self.params.optimize.algorithm = algorithm;
else
error('StanModel:algorithm:InputFormat',...
'Unknown algorithm for optimizer');
end
switch lower(self.method)
case 'optimize'
if any(strcmp(self.validators.optimize.algorithm,algorithm))
self.params.optimize.algorithm = algorithm;
else
error('StanModel:algorithm:InputFormat',...
'Unknown algorithm for optimizer');
end
case 'sample'
if strcmp(algorithm,'hmc')
algorithm = 'static';
end
if any(strcmp(self.validators.sample.hmc.engine,algorithm))
self.params.sample.hmc.engine = algorithm;
else
error('StanModel:algorithm:InputFormat',...
'Unknown algorithm for sampler');
end
case 'variational'
if any(strcmp(self.validators.variational.algorithm,algorithm))
self.params.variational.algorithm = algorithm;
else
error('StanModel:algorithm:InputFormat',...
'Unknown algorithm for variational inference');
end
end
end

function algorithm = get.algorithm(self)
if strcmp(self.method,'sample')
algorithm = [self.params.sample.algorithm ':' ...
self.params.sample.hmc.engine];
elseif strcmp(self.method,'optimize')
algorithm = self.params.optimize.algorithm;
switch lower(self.method)
case 'optimize'
algorithm = self.params.optimize.algorithm;
case 'sample'
algorithm = [self.params.sample.algorithm ':' ...
self.params.sample.hmc.engine];
case 'variational'
algorithm = self.params.variational.algorithm;
end
end

Expand Down Expand Up @@ -632,15 +646,18 @@ function select_file(self)
end

function control = get.control(self)
if strcmp(self.method,'sample')
control = self.params.sample.adapt;
if strncmp(self.algorithm,'hmc',3)
control.metric = self.params.sample.hmc.metric;
control.stepsize = self.params.sample.hmc.stepsize;
control.stepsize_jitter = self.params.sample.hmc.stepsize_jitter;
end
elseif strcmp(self.method,'optimize')
control = [];
switch lower(self.method)
case 'optimize'
control = [];
case 'sample'
control = self.params.sample.adapt;
if strncmp(self.algorithm,'hmc',3)
control.metric = self.params.sample.hmc.metric;
control.stepsize = self.params.sample.hmc.stepsize;
control.stepsize_jitter = self.params.sample.hmc.stepsize_jitter;
end
case 'variational'
control = [];
end
end

Expand Down Expand Up @@ -707,7 +724,7 @@ function select_file(self)

function fit = sampling(self,varargin)
if nargout == 0
error('stan:sampling:OutputFormat',...
error('StanModel:sampling:OutputFormat',...
'Need to assign the fit to a variable');
end
self.set(varargin{:});
Expand Down Expand Up @@ -753,7 +770,7 @@ function select_file(self)

function fit = optimizing(self,varargin)
if nargout == 0
error('stan:optimizing:OutputFormat',...
error('StanModel:optimizing:OutputFormat',...
'Need to assign the fit to a variable');
end
self.set(varargin{:});
Expand Down Expand Up @@ -784,6 +801,39 @@ function select_file(self)
p.start();
end

function fit = vb(self,varargin)
if nargout == 0
error('StanModel:vb:OutputFormat',...
'Need to assign the fit to a variable');
end
self.set(varargin{:});
self.method = 'variational';
if ~self.is_compiled
if self.verbose
fprintf('We have to compile the model first...\n');
end
self.compile();
end

if self.verbose
fprintf('Stan is performing variational inference ...\n');
end

p = processManager('id',self.sample_file,...
'command',sprintf('%s',self.command{:}),...
'workingDir',self.working_dir,...
'wrap',100,...
'keepStdout',true,...
'pollInterval',1,...
'printStdout',self.verbose,...
'autoStart',false);

fit = StanFit('model',copy(self),'processes',p,...
'output_file',{fullfile(self.working_dir,self.sample_file)},...
'verbose',self.verbose);
p.start();
end

function diagnose(self)
error('not done');
end
Expand Down
4 changes: 3 additions & 1 deletion Tests/TestExtract.m
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ function test_extract_all(self)
lp__ = ss.lp__;

ver = fit.model.stan_version;
if mstan.check_ver(ver,'2.2.0')
if mstan.check_ver(ver,'2.15.0')
assertEqual(fieldnames(ss),{'lp__' 'accept_stat__' 'stepsize__' 'treedepth__' 'n_leapfrog__' 'divergent__' 'energy__' 'alpha' 'beta'}');
elseif mstan.check_ver(ver,'2.2.0')
assertEqual(fieldnames(ss),{'lp__' 'accept_stat__' 'stepsize__' 'treedepth__' 'n_leapfrog__' 'n_divergent__' 'alpha' 'beta'}');
elseif mstan.check_ver(ver,'2.1.0')
assertEqual(fieldnames(ss),{'lp__' 'accept_stat__' 'stepsize__' 'treedepth__' 'n_divergent__' 'alpha' 'beta'}');
Expand Down
3 changes: 2 additions & 1 deletion mcmc.m
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,14 @@ function append_helper(self,s,data,chain_ind)
end
end

% FIXME for INCLUDE WARMUP
if p.Results.permuted
% TODO: ability to return permuted samples when we have warmup?
out = struct;
for i = 1:numel(names)
temp = cat(1,self.samples.(names{i})); %VERTCAT???
sz = size(temp);
temp = temp(self.permute_index(1:sz(1)),:);
temp = temp(self.permute_index(1:max(sz)),:);
out.(names{i}) = reshape(temp,sz);
end
% TODO: check that this is expected behavior!!
Expand Down
Loading

0 comments on commit b2189ae

Please sign in to comment.