Issue Description
I have a small model with a vectorized plate and a pyro.markov context, and it raises a strange error:
AttributeError: 'Tensor' object has no attribute '_pyro_backward'
The strange thing is that if I delete either the a variable or the b variable, the error disappears. So each works fine on its own and they only conflict when used together, even though the two variables are set up to be independent.
This model has no practical meaning, but as far as I can tell it doesn't break any rules, so I suspect it has revealed a bug.
I searched the forum and found nothing relevant. Please help me! Thanks!
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate


@config_enumerate
def model():
    with pyro.plate('plate', size=2, dim=-1):
        a = pyro.sample('a', dist.Dirichlet(concentration=torch.ones(2)))
    for T in pyro.markov(range(2)):
        b = pyro.sample('b{}'.format(T), dist.Categorical(
            probs=torch.ones(2)), infer={'enumerate': None})


@config_enumerate(default='sequential')
def guide():
    with pyro.plate('plate', size=2, dim=-1):
        a_param = pyro.param('a_param', torch.ones(2))
        a = pyro.sample('a', dist.Dirichlet(concentration=a_param))
    for T in pyro.markov(range(2)):
        b_param = pyro.param('b_param', torch.ones(2))
        b = pyro.sample('b{}'.format(T), dist.Categorical(probs=b_param))


pyro.clear_param_store()
elbo = TraceEnum_ELBO()
print(elbo.loss(model, guide))
```
The error is as follows:
```
AttributeError                            Traceback (most recent call last)
e:\Project\sim\sim_2p_2l.py in <module>
     27 pyro.clear_param_store()
     28 elbo = TraceEnum_ELBO()
---> 29 print(elbo.loss(generative_model, guide))

c:\Users\yaowang\.conda\envs\my_env\lib\site-packages\pyro\infer\traceenum_elbo.py in loss(self, model, guide, *args, **kwargs)
    403         elbo = 0.0
    404         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
--> 405             elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
    406             if is_identically_zero(elbo_particle):
    407                 continue

c:\Users\yaowang\.conda\envs\my_env\lib\site-packages\pyro\infer\traceenum_elbo.py in _compute_dice_elbo(model_trace, guide_trace)
    212             costs.setdefault(ordering[name], []).append(cost)
    213
--> 214     return Dice(guide_trace, ordering).compute_expectation(costs)
    215
    216

c:\Users\yaowang\.conda\envs\my_env\lib\site-packages\pyro\infer\util.py in compute_expectation(self, costs)
    296                     require_backward(query)
    297                 root = ring.sumproduct(log_factors, sum_dims)
--> 298                 root._pyro_backward()
    299                 probs = {
    300                     key: query._pyro_backward_result.exp()

AttributeError: 'Tensor' object has no attribute '_pyro_backward'
```
Even stranger, if you swap the order of the two variables, the error goes away:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate


@config_enumerate
def model():
    for T in pyro.markov(range(2)):
        b = pyro.sample('b{}'.format(T), dist.Categorical(
            probs=torch.ones(2)), infer={'enumerate': None})
    with pyro.plate('plate', size=2, dim=-1):
        a = pyro.sample('a', dist.Dirichlet(concentration=torch.ones(2)))


@config_enumerate(default='sequential')
def guide():
    for T in pyro.markov(range(2)):
        b_param = pyro.param('b_param', torch.ones(2))
        b = pyro.sample('b{}'.format(T), dist.Categorical(probs=b_param))
    with pyro.plate('plate', size=2, dim=-1):
        a_param = pyro.param('a_param', torch.ones(2))
        a = pyro.sample('a', dist.Dirichlet(concentration=a_param))


pyro.clear_param_store()
elbo = TraceEnum_ELBO()
print(elbo.loss(model, guide))
```
It seems that this error is triggered by a vectorized plate followed by at least two sample sites that are sequentially enumerated in the guide. Would someone please help me out here?
My intended use of this model is to have a global variable a, a local discrete variable b which depends on a, and another discrete variable c which depends on b. To simplify the problem, I removed these interdependencies from the model above, and the error did not change. I know that in this case I should use guide-side sequential enumeration on variable b, so there seems to be no way to work around it.
I'd really appreciate any help here; I'm quite confused.