
LinearTrend curve workflow fails with scalar parameter dims but extra dims #1693

Open
@williambdean

Description


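The prior, curve, plot workflow (`sample_prior`, `sample_curve`, `plot_curve`) fails when extra `dims` are specified on the LinearTrend component but its parameters are scalar, i.e. their priors do not use those dims.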

```python
from pymc_marketing.mmm import LinearTrend

trend = LinearTrend(n_changepoints=4, dims=("geo", "product"))

# This is fine
coords = dict(geo=["A", "B"], product=["X", "Y", "Z"])
prior = trend.sample_prior(coords=coords)

# This fails...
curve = trend.sample_curve(prior)

trend.plot_curve(curve)
```
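A minimal sketch of the mismatch, in plain Python. It assumes (this is not confirmed in the issue) that the sampled prior only carries coordinates for the dims its priors actually use, such as a `changepoint` dim, while `sample_curve` declares the curve over `("t", *trend.dims)`. The `missing_dims` helper and the coordinate values are hypothetical and only for illustration.

```python
from collections.abc import Mapping, Sequence


def missing_dims(
    declared_dims: Sequence[str], available_coords: Mapping[str, Sequence]
) -> list[str]:
    """Return the declared dims that have no coordinate values available."""
    return [dim for dim in declared_dims if dim not in available_coords]


# Dims declared on the component: LinearTrend(..., dims=("geo", "product")).
trend_dims = ("geo", "product")

# Assumed coords carried by the prior: the default (scalar) parameters only
# vary over the changepoints, so "geo" and "product" never appear.
prior_coords = {"changepoint": [0, 1, 2, 3]}

# sample_curve declares the curve over ("t", *trend.dims).
curve_dims = ("t", *trend_dims)

print(missing_dims(curve_dims, {"t": [0.0, 1.0], **prior_coords}))
# → ['geo', 'product']
```

The coords passed to `sample_prior` are not part of this picture, which is why the first step succeeds and the second one cannot find `geo`.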

Here is the traceback:

Traceback

```
---------------------------------------------------------------
ValueError                    Traceback (most recent call last)
Cell In[7], line 2
      1 # This fails...
----> 2 curve = trend.sample_curve(prior)

File ~/GitHub/pymc-eco/pymc-marketing/pymc_marketing/mmm/linear_trend.py:400, in LinearTrend.sample_curve(self, parameters, max_value)
    398 with pm.Model(coords=coords):
    399     name = "trend"
--> 400     pm.Deterministic(
    401         name,
    402         self.apply(t),
    403         dims=("t", *cast(Dims, self.dims)),
    404     )
    406     return pm.sample_posterior_predictive(
    407         parameters,
    408         var_names=[name],
    409     ).posterior_predictive[name]

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/model/core.py:2263, in Deterministic(name, var, model, dims)
   2261 var = var.copy(model.name_for(name))
   2262 model.deterministics.append(var)
-> 2263 model.add_named_variable(var, dims)
   2265 from pymc.printing import str_for_potential_or_deterministic
   2267 var.str_repr = types.MethodType(
   2268     functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
   2269 )

File ~/mamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/model/core.py:1476, in Model.add_named_variable(self, var, dims)
   1474 for dim in dims:
   1475     if dim not in self.coords and dim is not None:
-> 1476         raise ValueError(f"Dimension {dim} is not specified in `coords`.")
   1477 if any(var.name == dim for dim in dims if dim is not None):
   1478     raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")

ValueError: Dimension geo is not specified in `coords`.
```
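One possible direction for a fix, sketched in plain Python under the assumption that the curve should only be declared over the component dims that the sampled parameters actually carry. `check_dims` mirrors the loop in PyMC's `Model.add_named_variable` shown in the traceback; `curve_dims_and_coords` is a hypothetical helper, not the pymc-marketing implementation, and the coordinate values are illustrative.

```python
from collections.abc import Mapping, Sequence


def check_dims(dims: Sequence[str], coords: Mapping[str, Sequence]) -> None:
    # Mirrors the dim check in pymc's Model.add_named_variable (traceback above).
    for dim in dims:
        if dim not in coords and dim is not None:
            raise ValueError(f"Dimension {dim} is not specified in `coords`.")


def curve_dims_and_coords(
    component_dims: Sequence[str],
    parameter_coords: Mapping[str, Sequence],
    t_values: Sequence[float],
) -> tuple[tuple[str, ...], dict[str, list]]:
    """Hypothetical helper: keep only component dims present in the parameters."""
    kept = tuple(dim for dim in component_dims if dim in parameter_coords)
    coords = {"t": list(t_values)}
    coords.update({dim: list(parameter_coords[dim]) for dim in kept})
    return ("t", *kept), coords


t = [0.0, 0.5, 1.0]
prior_coords = {"changepoint": [0, 1, 2, 3]}  # assumed: no "geo"/"product"

# Current behaviour: dims=("t", *self.dims) with coords that lack "geo".
try:
    check_dims(("t", "geo", "product"), {"t": t, **prior_coords})
except ValueError as err:
    print(err)  # Dimension geo is not specified in `coords`.

# Sketched behaviour: scalar parameters give a curve over "t" only.
dims, coords = curve_dims_and_coords(("geo", "product"), prior_coords, t)
check_dims(dims, coords)  # passes
print(dims)  # ('t',)
```

An alternative would be to broadcast the curve over `geo` and `product` using the coords passed to `sample_prior`, but per the signature in the traceback, `sample_curve(self, parameters, max_value)` does not receive those coords, so that choice is left open here.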

Metadata


- Assignees: No one assigned
- Labels: MMM; bug (Something isn't working); good second issue (Bit more involved but still doable for newcomers); model components (Related to the various model components)
- Type: No type
- Projects: No projects
- Milestone: No milestone
- Relationships: None yet
- Development: No branches or pull requests
