IndexError when using dims argument #193

@j-ros

Describe the bug
Model sampling throws an IndexError when the dims argument is passed to a BART instance, even when the dimension has the correct length.

To Reproduce

import numpy as np
import pandas as pd
import pymc as pm
import pymc_bart as pmb
import pytensor.tensor as pt
from scipy.special import logit

# Read sample data
data_df = pd.read_csv(
    "https://raw.githubusercontent.com/juanitorduz/website_projects/master/data/retention_data.csv",
    parse_dates=["cohort", "period"],
)

# Processing data
eps = np.finfo(float).eps
train_data_red_df = data_df.query("cohort_age > 0").reset_index(drop=True)
train_obs_idx = train_data_red_df.index.to_numpy()
train_n_users = train_data_red_df["n_users"].to_numpy()
train_n_active_users = train_data_red_df["n_active_users"].to_numpy()
train_retention = train_data_red_df["retention"].to_numpy()
train_retention_logit = logit(train_retention + eps)
train_data_red_df["month"] = train_data_red_df["period"].dt.strftime("%m").astype(int)
features: list[str] = ["age", "cohort_age", "month"]
x_train = train_data_red_df[features]

# Model
with pm.Model(coords={"feature": features}) as model:
    # --- Data ---
    model.add_coord(name="obs", values=train_obs_idx, mutable=True)
    x = pm.MutableData(name="x", value=x_train, dims=("obs", "feature"))
    n_users = pm.MutableData(name="n_users", value=train_n_users, dims="obs")
    n_active_users = pm.MutableData(name="n_active_users", value=train_n_active_users, dims="obs")

    # --- Parametrization ---
    # The BART component models the image of the retention rate under the
    # logit transform so that the range is not constrained to [0, 1].
    mu = pmb.BART(
        name="mu",
        X=x,
        Y=train_retention_logit,
        dims="obs",
    )
    # We use the inverse logit transform to get the retention rate back into [0, 1].
    p = pm.Deterministic(name="p", var=pm.math.invlogit(mu), dims="obs")
    # We add a small epsilon to avoid numerical issues.
    p = pt.switch(pt.eq(p, 0), eps, p)
    p = pt.switch(pt.eq(p, 1), 1 - eps, p)

    # --- Likelihood ---
    n_active_users_estimated = pm.Binomial(
        name="n_active_users_estimated",
        n=n_users,
        p=p,
        observed=n_active_users,
        dims="obs",
    )

pm.model_to_graphviz(model=model)

# Fit model
with model:
    idata = pm.sample(draws=100, chains=1)
    posterior_predictive = pm.sample_posterior_predictive(trace=idata)

This throws the following error:

IndexError: tuple index out of range
Apply node that caused the error: BART_rv{"(i00,i01),(i10),(),(),(),(i50)->(o00)"}(RNG(<Generator(PCG64) at 0x1721017E0>), [], x, [-1.609437 ... .51268651], 100, 0.95, 2.0, [])
Toposort index: 0
Inputs types: [RandomGeneratorType, TensorType(int64, shape=(0,)), TensorType(float64, shape=(None, None)), TensorType(float64, shape=(1128,)), TensorType(int8, shape=()), TensorType(float64, shape=()), TensorType(float32, shape=()), TensorType(float64, shape=(0,))]
Inputs shapes: ['No shapes', (0,), (1128, 3), (1128,), (), (), (), (0,)]
Inputs strides: ['No strides', (0,), (8, 9024), (8,), (), (), (), (0,)]
Inputs values: [Generator(PCG64) at 0x1721017E0, array([], dtype=int64), 'not shown', 'not shown', array(100, dtype=int8), array(0.95), array(2., dtype=float32), array([], dtype=float64)]
Outputs clients: [[output[1](BART_rv{"(i00,i01),(i10),(),(),(),(i50)->(o00)"}.0)], [Second(mu, [-3.09309984])]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/opt/anaconda3/envs/pymc_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/opt/anaconda3/envs/pymc_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/opt/anaconda3/envs/pymc_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/sf/p5xckpjx73ld0k7wlwr9ftw40000gn/T/ipykernel_5378/3496386461.py", line 40, in <module>
    mu = pmb.BART(
  File "/opt/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc_bart/bart.py", line 173, in __new__
    return super().__new__(cls, name, *params, **kwargs)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/distributions/distribution.py", line 536, in __new__
    rv_out = cls.dist(*args, **kwargs)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc_bart/bart.py", line 177, in dist
    return super().dist(params, **kwargs)
  File "/opt/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/distributions/distribution.py", line 618, in dist
    rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

Inspecting the shapes of x (1128, 3), train_retention_logit (1128,), and obs (1128,), everything seems correct. The error is not thrown if the dims argument is not passed to BART, i.e.:

mu = pmb.BART(
    name="mu",
    X=x,
    Y=train_retention_logit,
)
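For reference, the shape inspection mentioned above can be written down as a small check. This is only a sketch: `check_dims` is a hypothetical helper, not part of PyMC or pymc-bart, and the shapes and coordinate lengths are hard-coded from the values reported in this issue rather than read from the model.

```python
# Hypothetical helper, not part of PyMC or pymc-bart: compares an array shape
# against the lengths of the named coordinates it is supposed to span.
def check_dims(shape, dims, coord_lengths):
    """Return a list of mismatch messages (empty when shape and dims agree)."""
    if len(shape) != len(dims):
        return [f"rank mismatch: shape {shape} vs dims {dims}"]
    problems = []
    for axis, (size, dim) in enumerate(zip(shape, dims)):
        expected = coord_lengths[dim]
        if size != expected:
            problems.append(f"axis {axis}: size {size} != len({dim!r}) = {expected}")
    return problems


# Coordinate lengths and shapes as reported in this issue.
coord_lengths = {"obs": 1128, "feature": 3}
variables = {
    "x": ((1128, 3), ("obs", "feature")),
    "train_retention_logit": ((1128,), ("obs",)),
}

# Map each variable name to its list of mismatches.
report = {
    name: check_dims(shape, dims, coord_lengths)
    for name, (shape, dims) in variables.items()
}
print(report)
```

An empty list for each variable is consistent with the observation that the data shapes match the coords, which suggests (but does not prove) that the failure comes from how dims is handled for BART rather than from the inputs.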

Expected behavior
Sampling with no errors.

Additional context

python 3.10.15
pymc 5.16.2
pymc-bart 0.7.0
