def _extract_value(var: SharedVariable | Constant) -> np.ndarray:
    """Return the concrete value stored in a shared variable or a constant.

    Parameters
    ----------
    var : SharedVariable | Constant
        Variable holding a value, e.g. an entry of ``model.dim_lengths``.

    Returns
    -------
    np.ndarray
        ``var.get_value(borrow=True)`` for shared variables, ``var.data`` otherwise.
    """
    if isinstance(var, SharedVariable):
        return var.get_value(borrow=True)
    return var.data


def _expression_of(str_repr: str) -> str:
    """Return what follows the first ``" ~ "`` of a ``"name ~ expression"`` string."""
    # maxsplit=1 keeps the whole right-hand side even if it contains " ~ " itself.
    return str_repr.split(" ~ ", 1)[1]


def _truncate_expression(var_expr: str, max_length: int) -> str:
    """Shorten an ``f(a, b, c)`` expression by eliding its trailing arguments.

    Arguments are kept, in order, while their cumulative length (each counted
    together with its ``", "`` separator) stays within ``max_length``. The caller
    only passes expressions that are already too long, so at least the last
    argument is always elided.

    Parameters
    ----------
    var_expr : str
        Expression of the form ``f(arg1, arg2, ...)``.
    max_length : int
        Length budget for the shown arguments.

    Returns
    -------
    str
        ``f(arg1, ..., ...)`` with the arguments that fit, or ``f(...)`` when not
        even the first argument fits.
    """
    arguments = var_expr[2:-1].split(", ")  # strip the leading "f(" and trailing ")"
    show_n = len(arguments) - 1
    str_len = 0
    for position, argument in enumerate(arguments):
        str_len += len(argument) + 2
        if str_len > max_length:
            show_n = position
            break

    shown = arguments[:show_n]
    if not shown:
        # Avoid the malformed "f(, ...)" when nothing fits.
        return "f(...)"
    return f"f({', '.join(shown)}, ...)"


def _count_parameters(model: Model) -> int:
    """Return the total number of scalar entries across the model's free RVs."""
    rv_shapes = model.eval_rv_shapes()
    # Built-in sum over ints yields the int 0 for a model without free RVs
    # (np.sum over an empty list would give the float 0.0).
    return sum(int(np.prod(rv_shapes[free_rv.name])) for free_rv in model.free_RVs)


def model_table(
    model: Model,
    split_groups: bool = True,
    truncate_deterministic: int | None = None,
    parameter_count: bool = True,
) -> Table:
    """Create a rich table with a summary of the model's variables and their expressions.

    Parameters
    ----------
    model : Model
        The PyMC model to summarize.
    split_groups : bool
        If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs)
        will be separated by a section.
    truncate_deterministic : int | None
        If not None, truncate the expression of deterministic variables that go beyond this length.
    parameter_count : bool
        If True, add a row with the total number of parameters in the model.

    Returns
    -------
    Table
        A rich table with the model's variables, their expressions and dims.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm

        from pymc_experimental.printing import model_table

        coords = {"subject": range(20), "param": ["a", "b"]}
        with pm.Model(coords=coords) as m:
            x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
            y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")

            beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
            mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
            sigma = pm.HalfNormal("sigma", sigma=1)

            y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")

        table = model_table(m)
        table  # Displays the following table in an interactive environment
        '''
         Variable  Expression         Dimensions
        ─────────────────────────────────────────────────────
              x =  Data               subject[20] × param[2]
              y =  Data               subject[20]

           beta ~  Normal(0, 1)       param[2]
          sigma ~  HalfNormal(0, 1)
                                      Parameter count = 3

             mu =  f(beta)            subject[20]

          y_obs ~  Normal(mu, sigma)  subject[20]
        '''

    Output can be explicitly rendered in a rich console or exported to text, html or svg.

    .. code-block:: python

        from rich.console import Console

        console = Console(record=True)
        console.print(table)
        text_export = console.export_text()
        html_export = console.export_html()
        svg_export = console.export_svg()

    """
    table = Table(
        show_header=True,
        show_edge=False,
        box=SIMPLE_HEAD,
        highlight=False,
        collapse_padding=True,
    )
    table.add_column("Variable", justify="right")
    table.add_column("Expression", justify="left")
    table.add_column("Dimensions")

    dim_sizes = {k: _extract_value(v) for k, v in model.dim_lengths.items()}

    # Tag every group with its kind once, so rows never need to search the
    # model's variable lists again to find out what they are.
    tagged_groups = (
        ("data", model.data_vars),
        ("free", model.free_RVs),
        ("deterministic", model.deterministics),
        ("potential", model.potentials),
        ("observed", model.observed_RVs),
    )

    # Each section is (show parameter count after it?, [(kind, variable), ...]).
    if split_groups:
        # Empty groups produce no section; the count follows the free RVs.
        sections = [
            (kind == "free", [(kind, var) for var in variables])
            for kind, variables in tagged_groups
            if variables
        ]
    else:
        # One section with everything; the count goes at its very end.
        sections = [
            (True, [(kind, var) for kind, variables in tagged_groups for var in variables])
        ]

    for count_goes_here, entries in sections:
        for kind, var in entries:
            if kind == "data":
                var_expr = "Data"
            elif kind == "deterministic":
                str_repr = str_for_potential_or_deterministic(var, dist_name="")
                var_expr = _expression_of(str_repr)[1:-1]  # Remove outer parentheses (f(...))
                if truncate_deterministic is not None and len(var_expr) > truncate_deterministic:
                    var_expr = _truncate_expression(var_expr, truncate_deterministic)
            elif kind == "potential":
                var_expr = _expression_of(
                    str_for_potential_or_deterministic(var, dist_name="Potential")
                )
            else:
                var_expr = _expression_of(str_for_dist(var))

            # NOTE(review): assumes every entry of dims is a key of model.dim_lengths;
            # confirm whether unnamed (None) dims can reach this point.
            dims = model.named_vars_to_dims.get(var.name, ())
            dims_and_sizes = " × ".join(f"{dim}[{dim_sizes[dim]}]" for dim in dims)

            # "=" marks values defined by data or an expression, "~" marks random variables.
            sep = " =" if kind in ("data", "deterministic", "potential") else " ~"
            table.add_row(f"{var.name}[b]{sep}[/b]", var_expr, dims_and_sizes)

        if parameter_count and count_goes_here:
            n_parameters = _count_parameters(model)
            table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")

        table.add_section()

    return table
subject[20] + y_data = Data trial[6] × subject[20] + + mu ~ Normal(0, 1) + sigma ~ HalfNormal(0, 1) + global_intercept ~ Normal(0, 1) + intercept_subject ~ Normal(0, 1) subject[20] + beta_subject ~ Normal(mu, sigma) subject[20] + noise ~ Exponential(f()) + Parameter count = 44 + + mu_trial = f(beta_subject, trial[6] × subject[20] + intercept_subject, + global_intercept) + + beta_subject_penalty = Potential(f(beta_subject)) subject[20] + + y ~ Normal(mu_trial, noise) trial[6] × subject[20] +""" + assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()] + + table_txt = get_text(model_table(model, split_groups=False)) + expected = """ Variable Expression Dimensions +──────────────────────────────────────────────────────────────────────────────── + x_data = Data trial[6] × subject[20] + y_data = Data trial[6] × subject[20] + mu ~ Normal(0, 1) + sigma ~ HalfNormal(0, 1) + global_intercept ~ Normal(0, 1) + intercept_subject ~ Normal(0, 1) subject[20] + beta_subject ~ Normal(mu, sigma) subject[20] + noise ~ Exponential(f()) + mu_trial = f(beta_subject, trial[6] × subject[20] + intercept_subject, + global_intercept) + beta_subject_penalty = Potential(f(beta_subject)) subject[20] + y ~ Normal(mu_trial, noise) trial[6] × subject[20] + Parameter count = 44 +""" + assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()] + + table_txt = get_text( + model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False) + ) + expected = """ Variable Expression Dimensions +──────────────────────────────────────────────────────────────────────────── + x_data = Data trial[6] × subject[20] + y_data = Data trial[6] × subject[20] + mu ~ Normal(0, 1) + sigma ~ HalfNormal(0, 1) + global_intercept ~ Normal(0, 1) + intercept_subject ~ Normal(0, 1) subject[20] + beta_subject ~ Normal(mu, sigma) subject[20] + noise ~ Exponential(f()) + mu_trial = f(beta_subject, ...) 
trial[6] × subject[20] + beta_subject_penalty = Potential(f(beta_subject)) subject[20] + y ~ Normal(mu_trial, noise) trial[6] × subject[20] +""" + assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]