Skip to content

Multiprocessing failure: 'Can't pickle local object' #3034

Closed
@helmutsimon

Description

@helmutsimon

Previous issues have reported multiprocessing failures, e.g. #1033 (now closed), with some comments noting that the problem occurs in "more complex" cases.

The following code examples appear to show that the problem occurs only when both of these conditions hold:

1. Custom distributions are used via DensityDist.
2. The model is defined and sampled inside a function.

The model is a simple one: estimating the parameters of a multinomial distribution using a Dirichlet prior.

The following example (1) uses a custom distribution, but the model is not called from within a function. It works correctly.

import pymc3
from pymc3 import *
from pymc3.distributions.multivariate import Dirichlet, Multinomial
import numpy as np
from scipy.special import gammaln
from theano import config
import theano.tensor as tt
import warnings
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns


# Print library versions so the report can be matched to specific releases.
print('numpy version ', np.__version__)
print('PyMC3 version ', pymc3.__version__)
# Theano settings: verbose exception reports, and disable the `warn.round`
# warning.
config.exception_verbosity='high'
config.warn.round = False
warnings.simplefilter("ignore", FutureWarning)
# Number of categories; also the length of the probability vector `probs`.
n = 4

# Example (1): custom DensityDist prior whose log-density function is defined
# at module level, i.e. NOT inside a function. Reported in this issue to
# sample correctly.
with Model() as model1:
    # Concentration vector: n equal entries of 1/n (sums to 1).
    prior = np.ones(n) / n
    
    # Log-density of a symmetric Dirichlet(1/n, ..., 1/n):
    #   lgamma(sum(a)) - sum(lgamma(a_i)) + sum((a_i - 1) * log(value_i))
    # With a_i = 1/n, sum(a) = 1 and lgamma(1) = 0, leaving the expression below.
    # NOTE(review): the default `value=prior` is presumably unused when PyMC3
    # evaluates the logp on the model variable -- confirm.
    def dirich_logpdf(value=prior):
        return -n * gammaln(1/n) + (-1 + 1/n) * tt.log(value).sum()
    
    # Stick-breaking transform, presumably so the sampler works on an
    # unconstrained space while `probs` stays on the simplex.
    stick = distributions.transforms.StickBreaking()
    # Custom prior over the n-vector `probs`, started at the uniform vector.
    probs = DensityDist('probs', dirich_logpdf, shape=n, testval=np.array(prior), transform=stick)
    # Observed category counts (4 entries, matching n).
    data = np.array([5, 7, 1, 0])
    sfs_obs = Multinomial('sfs_obs', n=np.sum(data), p=probs, observed=data)
    
# Metropolis sampling: 100000 draws after 10000 tuning steps. No `njobs` is
# passed, so PyMC3's default number of jobs is used.
with model1:
    step = Metropolis()
    trace = sample(100000, tune=10000, step=step)

# Plot and summarise the posterior of `probs`. The trailing ';' is a notebook
# idiom that suppresses echoing the return value.
# NOTE(review): `plt.show` is not documented to take a plot object; the likely
# intent is `forestplot(...)` followed by `plt.show()` -- confirm.
traceplot(trace, [probs]);
plt.show(forestplot(trace, varnames=['probs']))
print(summary(trace))

The following example (2) is the same as (1), except that the model is defined and sampled inside a function. It fails with "Can't pickle local object" unless njobs=1 is passed to sample().

import pymc3
from pymc3 import *
from pymc3.distributions.multivariate import Dirichlet, Multinomial
import numpy as np
from scipy.special import gammaln
from theano import config
import theano.tensor as tt
import warnings
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

def run_MCMC(n):
    """Build and sample the example-(1) model inside a function (example (2)).

    Reported in this issue to fail with "Can't pickle local object" unless
    ``njobs=1`` is passed to ``sample``. ``dirich_logpdf`` below is a nested
    function (a local object that closes over ``n``), and Python's ``pickle``
    cannot serialize local functions. That is presumably what breaks when
    sampling uses multiple processes.

    Parameters
    ----------
    n : int
        Number of categories; the length of ``probs``. Presumably must equal
        ``len(data)`` (4) for the Multinomial likelihood -- confirm.

    Returns
    -------
    The trace returned by ``sample``.
    """
    # Print library versions so the report can be matched to specific releases.
    print('numpy version ', np.__version__)
    print('PyMC3 version ', pymc3.__version__)
    # Theano settings: verbose exception reports; disable the `warn.round`
    # warning. NOTE(review): these mutate global Theano/warnings state on every call.
    config.exception_verbosity='high'
    config.warn.round = False
    warnings.simplefilter("ignore", FutureWarning)

    with Model() as model1:
        # Concentration vector: n equal entries of 1/n (sums to 1).
        prior = np.ones(n) / n

        # Symmetric Dirichlet(1/n, ..., 1/n) log-density; the lgamma(sum(a))
        # term is lgamma(1) = 0 and drops out. Because it is defined inside
        # run_MCMC it is a local object -- presumably the one the
        # "Can't pickle local object" error refers to.
        def dirich_logpdf(value=prior):
            return -n * gammaln(1/n) + (-1 + 1/n) * tt.log(value).sum()

        # Stick-breaking transform, presumably so the sampler works on an
        # unconstrained space while `probs` stays on the simplex.
        stick = distributions.transforms.StickBreaking()
        # Custom prior over the n-vector `probs`, started at the uniform vector.
        probs = DensityDist('probs', dirich_logpdf, shape=n, testval=np.array(prior), transform=stick)
        # Observed category counts (4 entries).
        data = np.array([5, 7, 1, 0])
        sfs_obs = Multinomial('sfs_obs', n=np.sum(data), p=probs, observed=data)

    # Same sampler settings as example (1); no `njobs` is passed here.
    with model1:
        step = Metropolis()
        trace = sample(100000, tune=10000, step=step) # works with njobs=1

    # Plot and summarise the posterior of `probs`.
    # NOTE(review): `plt.show` is not documented to take a plot object; the
    # likely intent is `forestplot(...)` followed by `plt.show()` -- confirm.
    traceplot(trace, [probs]);
    plt.show(forestplot(trace, varnames=['probs']))
    print(summary(trace))
    return trace
                          
                          
# Run example (2) with n = 4 categories, matching the 4 observed counts.
run_MCMC(4)

The following example (3) is the same as (2), except that it uses the built-in Dirichlet distribution instead of a custom DensityDist. It works correctly.

import pymc3
from pymc3 import *
from pymc3.distributions.multivariate import Dirichlet, Multinomial
import numpy as np
from scipy.special import gammaln
from theano import config
import theano.tensor as tt
import warnings
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

def run_MCMC(n):
    """Control case (example (3)): example (2) with the built-in Dirichlet prior.

    Identical to example (2) except that ``probs`` uses PyMC3's ``Dirichlet``
    instead of a ``DensityDist`` with a nested log-density function. Reported
    in this issue to work, with no ``njobs`` argument passed to ``sample``.

    Parameters
    ----------
    n : int
        Number of categories; the length of ``probs``. Presumably must equal
        ``len(data)`` (4) for the Multinomial likelihood -- confirm.

    Returns
    -------
    The trace returned by ``sample``.
    """
    # Print library versions so the report can be matched to specific releases.
    print('numpy version ', np.__version__)
    print('PyMC3 version ', pymc3.__version__)
    # Theano settings: verbose exception reports; disable the `warn.round`
    # warning. NOTE(review): these mutate global Theano/warnings state on every call.
    config.exception_verbosity='high'
    config.warn.round = False
    warnings.simplefilter("ignore", FutureWarning)

    with Model() as model1:
        # Concentration vector: n equal entries of 1/n (sums to 1).
        prior = np.ones(n) / n

        # NOTE(review): `stick` is never used -- the Dirichlet below is created
        # without `transform=stick`; apparently left over from example (2).
        stick = distributions.transforms.StickBreaking()
        # Same Dirichlet(1/n, ..., 1/n) prior as the custom logp in example (2),
        # via the built-in distribution, so no nested function is involved.
        probs = Dirichlet('probs', a=prior)
        # Observed category counts (4 entries).
        data = np.array([5, 7, 1, 0])
        sfs_obs = Multinomial('sfs_obs', n=np.sum(data), p=probs, observed=data)

    # Same sampler settings as examples (1) and (2); no `njobs` is passed.
    with model1:
        step = Metropolis()
        trace = sample(100000, tune=10000, step=step)

    # Plot and summarise the posterior of `probs`.
    # NOTE(review): `plt.show` is not documented to take a plot object; the
    # likely intent is `forestplot(...)` followed by `plt.show()` -- confirm.
    traceplot(trace, [probs]);
    plt.show(forestplot(trace, varnames=['probs']))
    print(summary(trace))
    return trace
                          
                          
# Run the control example (3) with n = 4 categories, matching the 4 observed counts.
run_MCMC(4)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions