
Raise informative error message if MvNormal is instantiated without shape #4379

Closed
@MarcoGorelli

Description


Currently, we have

>>> pm.MvNormal.dist(np.ones(2), np.eye(2)).random()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/marco/pymc3-dev/pymc3/distributions/multivariate.py", line 273, in random
    param = np.broadcast_to(param, shape=output_shape + dist_shape[-1:])
  File "<__array_function__ internals>", line 5, in broadcast_to
  File "/home/marco/miniconda3/envs/pymc3-dev-py38/lib/python3.8/site-packages/numpy/lib/stride_tricks.py", line 180, in broadcast_to
    return _broadcast_to(array, shape, subok=subok, readonly=True)
  File "/home/marco/miniconda3/envs/pymc3-dev-py38/lib/python3.8/site-packages/numpy/lib/stride_tricks.py", line 118, in _broadcast_to
    raise ValueError('cannot broadcast a non-scalar to a scalar array')
ValueError: cannot broadcast a non-scalar to a scalar array
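The sketch below reproduces the failing broadcast using only NumPy, to show why the error mentions a "scalar array". It assumes two things inferred from the traceback line rather than checked against the PyMC3 source: output_shape is size + dist_shape, and dist_shape defaults to an empty tuple when shape is omitted. broadcast_target is a hypothetical helper written for this illustration.

```python
import numpy as np


def broadcast_target(size, dist_shape):
    # Assumption: mirrors the traceback line, with output_shape = size + dist_shape
    output_shape = size + dist_shape
    return output_shape + dist_shape[-1:]


cov = np.eye(2)  # the covariance passed in the report

# No shape given: dist_shape is assumed to be (), so the target shape is ()
no_shape_target = broadcast_target((), ())
try:
    # A 2x2 array cannot be broadcast to a 0-d (scalar) shape
    np.broadcast_to(cov, shape=no_shape_target)
except ValueError as err:
    print("without shape:", err)

# shape=2 given: the target becomes (2, 2) and the covariance broadcasts fine
with_shape_target = broadcast_target((), (2,))
print("with shape:", np.broadcast_to(cov, shape=with_shape_target).shape)
```

Under these assumptions, the message is about NumPy's broadcasting rather than about the missing shape argument, which is why it is hard to act on.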

which isn't a very clear error message: it doesn't say that the missing shape argument is the cause. MvNormal could instead raise an informative error message when shape isn't passed. Furthermore, the line

        vals = pm.MvNormal('vals', mu=mu, chol=chol, observed=data)

from the MvNormal docstring could be amended to include shape as well.
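A minimal sketch of both suggestions follows. check_mvnormal_shape is a hypothetical function, not part of the PyMC3 API; it shows the kind of early check that could replace the NumPy error. The observed-data part uses made-up stand-ins for the docstring's mu and data (10 draws in 3 dimensions) and assumes that, for observed data, the shape to pass is simply data.shape.

```python
import numpy as np


def check_mvnormal_shape(mu, shape=None):
    """Hypothetical check (not PyMC3's API): fail early with a clear
    message when a multivariate normal is created without a shape."""
    mu = np.asarray(mu)
    # Normalise None to (), an int like 2 to (2,), and sequences to tuples
    if shape is None:
        shape = ()
    elif np.isscalar(shape):
        shape = (int(shape),)
    else:
        shape = tuple(int(s) for s in shape)
    if not shape:
        hint = f" (e.g. shape={mu.shape[-1]})" if mu.ndim else ""
        raise ValueError(
            "MvNormal was created without a shape; pass shape explicitly" + hint
        )
    # The last entry of shape is the event dimension and must match mu
    if mu.ndim and shape[-1] != mu.shape[-1]:
        raise ValueError(
            f"The last entry of shape ({shape[-1]}) must equal "
            f"the length of mu ({mu.shape[-1]})"
        )
    return shape


# 1) The call from the report would now fail with an actionable message
try:
    check_mvnormal_shape(np.ones(2))
except ValueError as err:
    print(err)  # MvNormal was created without a shape; pass shape explicitly (e.g. shape=2)

# 2) Hypothetical stand-ins for the docstring's mu and observed data
rng = np.random.default_rng(42)
mu = np.zeros(3)
true_cov = np.array([[1.0, 0.5, 0.1],
                     [0.5, 2.0, 0.2],
                     [0.1, 0.2, 1.0]])
data = rng.multivariate_normal(mu, true_cov, size=10)

# For observed data, shape is (number of observations, dimension of mu)
print(check_mvnormal_shape(mu, shape=data.shape))  # (10, 3)

# The amended docstring line could then read (chol as in the docstring):
#     vals = pm.MvNormal('vals', mu=mu, chol=chol, observed=data, shape=data.shape)
```

Checking the last entry of shape against mu also catches a mismatched shape, not only a missing one; whether PyMC3 should enforce that is a separate design question.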


Suggested by @Sayam753 in pymc-devs/pymc-examples#11.
