Skip to content

Empirical distribution does not accept inference data #5838

Open
@ferrine

Description

@ferrine

Description of your problem

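Can't initialize the Empirical approximation from the `InferenceData` object returned by `pm.sample()`: the constructor raises an `AttributeError` (see the traceback below).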

with pm.Model() as model:
    v = pm.Normal("a")
    trace = pm.sample()
    emp = pm.Empirical(trace)

Please provide the full traceback.

Complete error traceback
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_2602197/4225585213.py in <cell line: 1>()
      2     v = pm.Normal("a")
      3     trace = pm.sample()
----> 4     emp = pm.Empirical(trace)

~/dev/pymc3/pymc/variational/approximations.py in __init__(self, trace, size, **kwargs)
    566         if kwargs.get("local_rv", None) is not None:
    567             raise opvi.LocalGroupError("Empirical approximation does not support local variables")
--> 568         super().__init__(trace=trace, size=size, **kwargs)
    569 
    570     def evaluate_over_trace(self, node):

~/dev/pymc3/pymc/variational/approximations.py in __init__(self, *args, **kwargs)
    526                 ]
    527             )
--> 528         super().__init__(groups, model=kwargs.get("model"))
    529 
    530     def __getattr__(self, item):

~/dev/pymc3/pymc/variational/opvi.py in __init__(self, groups, model)
   1354                 raise GroupError("No approximation is specified for the rest variables")
   1355             else:
-> 1356                 rest.__init_group__(unseen_free_RVs)
   1357                 self.groups.append(rest)
   1358         self.model = model

~/.miniconda3/envs/pymc/lib/python3.9/site-packages/aesara/configparser.py in res(*args, **kwargs)
     45         def res(*args, **kwargs):
     46             with self:
---> 47                 return f(*args, **kwargs)
     48 
     49         return res

~/dev/pymc3/pymc/variational/approximations.py in __init_group__(self, group)
    208     def __init_group__(self, group):
    209         super().__init_group__(group)
--> 210         self._check_trace()
    211         if not self._check_user_params(spec_kw=dict(s=-1)):
    212             self.shared_params = self.create_shared_params(

~/dev/pymc3/pymc/variational/approximations.py in _check_trace(self)
    240         trace = self._kwargs.get("trace", None)
    241         if trace is not None and not all(
--> 242             [self.model.rvs_to_values[var].name in trace.varnames for var in self.group]
    243         ):
    244             raise ValueError("trace has not all free RVs in the group")

~/dev/pymc3/pymc/variational/approximations.py in <listcomp>(.0)
    240         trace = self._kwargs.get("trace", None)
    241         if trace is not None and not all(
--> 242             [self.model.rvs_to_values[var].name in trace.varnames for var in self.group]
    243         ):
    244             raise ValueError("trace has not all free RVs in the group")

AttributeError: 'InferenceData' object has no attribute 'varnames'

Please provide any additional information below.

Versions and main components

  • PyMC/PyMC3 Version: '4.0.0b6'

Metadata

Metadata

Assignees

No one assigned

    Labels

    VI (Variational Inference)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions