Description
There have been previous issues raised regarding multiprocessing failure, e.g. #1033 (now closed), with some comments that the problem occurs in "more complex" cases.
The following code examples would appear to demonstrate that the problem occurs precisely when both of the following hold:
1. Custom distributions are used via DensityDist.
2. The Model is called from within a function.
The model is a simple one: estimating the parameters of a multinomial distribution using a Dirichlet prior.
The following example (1) uses a custom distribution, but the model is not called from within a function. It works correctly.
import pymc3
from pymc3 import *
from pymc3.distributions.multivariate import Dirichlet, Multinomial
import numpy as np
from scipy.special import gammaln
from theano import config
import theano.tensor as tt
import warnings
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
print('numpy version ', np.__version__)
print('PyMC3 version ', pymc3.__version__)
config.exception_verbosity='high'
config.warn.round = False
warnings.simplefilter("ignore", FutureWarning)
# Example (1): custom DensityDist, model built at module top level.
# Reported to work correctly, including with multiprocess sampling.
n = 4

with Model() as model1:
    # Symmetric Dirichlet(1/n, ..., 1/n) prior over the n probabilities.
    prior = np.ones(n) / n

    # Custom log-density for the symmetric Dirichlet prior.
    # NOTE(review): the numpy-array default for `value` is only a
    # placeholder; DensityDist supplies the actual tensor at evaluation.
    def dirich_logpdf(value=prior):
        return -n * gammaln(1/n) + (-1 + 1/n) * tt.log(value).sum()

    # Stick-breaking transform maps the simplex to unconstrained space.
    stick = distributions.transforms.StickBreaking()
    probs = DensityDist('probs', dirich_logpdf, shape=n,
                        testval=np.array(prior), transform=stick)

    # Observed multinomial counts; total count is the sum of the data.
    data = np.array([5, 7, 1, 0])
    sfs_obs = Multinomial('sfs_obs', n=np.sum(data), p=probs, observed=data)

with model1:
    step = Metropolis()
    trace = sample(100000, tune=10000, step=step)

traceplot(trace, [probs])
plt.show(forestplot(trace, varnames=['probs']))
print(summary(trace))
The following example (2) is the same as (1), except the model is called from within a function. It will fail unless njobs=1
is set in the sample statement.
import pymc3
from pymc3 import *
from pymc3.distributions.multivariate import Dirichlet, Multinomial
import numpy as np
from scipy.special import gammaln
from theano import config
import theano.tensor as tt
import warnings
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
def run_MCMC(n):
    """Build and sample the Dirichlet-multinomial model from inside a function.

    Example (2): identical to example (1) except that everything happens
    inside this function. Reported to fail under multiprocess sampling
    unless njobs=1 is passed to sample().
    NOTE(review): likely because `dirich_logpdf` is a function-local
    closure that cannot be pickled for worker processes — confirm.
    """
    print('numpy version ', np.__version__)
    print('PyMC3 version ', pymc3.__version__)
    config.exception_verbosity = 'high'
    config.warn.round = False
    warnings.simplefilter("ignore", FutureWarning)

    with Model() as model1:
        # Symmetric Dirichlet(1/n, ..., 1/n) prior over the n probabilities.
        prior = np.ones(n) / n

        # Custom log-density for the symmetric Dirichlet prior; the
        # numpy default for `value` is a placeholder only.
        def dirich_logpdf(value=prior):
            return -n * gammaln(1/n) + (-1 + 1/n) * tt.log(value).sum()

        stick = distributions.transforms.StickBreaking()
        probs = DensityDist('probs', dirich_logpdf, shape=n,
                            testval=np.array(prior), transform=stick)

        data = np.array([5, 7, 1, 0])
        sfs_obs = Multinomial('sfs_obs', n=np.sum(data), p=probs,
                              observed=data)

    with model1:
        step = Metropolis()
        trace = sample(100000, tune=10000, step=step)  # works with njobs=1

    traceplot(trace, [probs])
    plt.show(forestplot(trace, varnames=['probs']))
    print(summary(trace))
    return trace

run_MCMC(4)
The following example (3) is the same as (2), except it uses the built-in Dirichlet distribution rather than a custom distribution. It works correctly.
import pymc3
from pymc3 import *
from pymc3.distributions.multivariate import Dirichlet, Multinomial
import numpy as np
from scipy.special import gammaln
from theano import config
import theano.tensor as tt
import warnings
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
def run_MCMC(n):
    """Build and sample the model using the built-in Dirichlet prior.

    Example (3): identical to example (2) except that the prior is the
    built-in Dirichlet distribution instead of a custom DensityDist.
    Reported to work correctly, including with multiprocess sampling.
    """
    print('numpy version ', np.__version__)
    print('PyMC3 version ', pymc3.__version__)
    config.exception_verbosity = 'high'
    config.warn.round = False
    warnings.simplefilter("ignore", FutureWarning)

    with Model() as model1:
        # Symmetric Dirichlet(1/n, ..., 1/n) prior; no custom log-density
        # or explicit transform is needed for the built-in distribution.
        prior = np.ones(n) / n
        probs = Dirichlet('probs', a=prior)

        data = np.array([5, 7, 1, 0])
        sfs_obs = Multinomial('sfs_obs', n=np.sum(data), p=probs,
                              observed=data)

    with model1:
        step = Metropolis()
        trace = sample(100000, tune=10000, step=step)

    traceplot(trace, [probs])
    plt.show(forestplot(trace, varnames=['probs']))
    print(summary(trace))
    return trace

run_MCMC(4)