-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2accca9
commit 283cc10
Showing
3 changed files
with
289 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import numpy as np | ||
|
||
from pymc import Model | ||
from pymc.printing import str_for_dist, str_for_potential_or_deterministic | ||
from pytensor import Mode | ||
from pytensor.compile.sharedvalue import SharedVariable | ||
from pytensor.graph.type import Constant, Variable | ||
from rich.box import SIMPLE_HEAD | ||
from rich.table import Table | ||
|
||
|
||
def variable_expression( | ||
model: Model, | ||
var: Variable, | ||
truncate_deterministic: int | None, | ||
) -> str: | ||
"""Get the expression of a variable in a human-readable format.""" | ||
if var in model.data_vars: | ||
var_expr = "Data" | ||
elif var in model.deterministics: | ||
str_repr = str_for_potential_or_deterministic(var, dist_name="") | ||
_, var_expr = str_repr.split(" ~ ") | ||
var_expr = var_expr[1:-1] # Remove outer parentheses (f(...)) | ||
if truncate_deterministic is not None and len(var_expr) > truncate_deterministic: | ||
contents = var_expr[2:-1].split(", ") | ||
str_len = 0 | ||
for show_n, content in enumerate(contents): | ||
str_len += len(content) + 2 | ||
if str_len > truncate_deterministic: | ||
break | ||
var_expr = f"f({', '.join(contents[:show_n])}, ...)" | ||
elif var in model.potentials: | ||
var_expr = str_for_potential_or_deterministic(var, dist_name="Potential").split(" ~ ")[1] | ||
else: # basic_RVs | ||
var_expr = str_for_dist(var).split(" ~ ")[1] | ||
return var_expr | ||
|
||
|
||
def _extract_dim_value(var: SharedVariable | Constant) -> np.ndarray: | ||
if isinstance(var, SharedVariable): | ||
return var.get_value(borrow=True) | ||
else: | ||
return var.data | ||
|
||
|
||
def dims_expression(model: Model, var: Variable) -> str: | ||
"""Get the dimensions of a variable in a human-readable format.""" | ||
if (dims := model.named_vars_to_dims.get(var.name)) is not None: | ||
dim_sizes = {dim: _extract_dim_value(model.dim_lengths[dim]) for dim in dims} | ||
return " × ".join(f"{dim}[{dim_size}]" for dim, dim_size in dim_sizes.items()) | ||
else: | ||
dim_sizes = list(var.shape.eval(mode=Mode(linker="py", optimizer=None))) | ||
return f"[{', '.join(map(str, dim_sizes))}]" if dim_sizes else "" | ||
|
||
|
||
def model_parameter_count(model: Model) -> int: | ||
"""Count the number of parameters in the model.""" | ||
rv_shapes = model.eval_rv_shapes() # Includes transformed variables | ||
return np.sum([np.prod(rv_shapes[free_rv.name]).astype(int) for free_rv in model.free_RVs]) | ||
|
||
|
||
def model_table( | ||
model: Model, | ||
*, | ||
split_groups: bool = True, | ||
truncate_deterministic: int | None = None, | ||
parameter_count: bool = True, | ||
) -> Table: | ||
"""Create a rich table with a summary of the model's variables and their expressions. | ||
Parameters | ||
---------- | ||
model : Model | ||
The PyMC model to summarize. | ||
split_groups : bool | ||
If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs) | ||
will be separated by a section. | ||
truncate_deterministic : int | None | ||
If not None, truncate the expression of deterministic variables that go beyond this length. | ||
empty_dims : bool | ||
If True, show the dimensions of scalar variables as an empty list. | ||
parameter_count : bool | ||
If True, add a row with the total number of parameters in the model. | ||
Returns | ||
------- | ||
Table | ||
A rich table with the model's variables, their expressions and dims. | ||
Examples | ||
-------- | ||
.. code-block:: python | ||
import numpy as np | ||
import pymc as pm | ||
from pymc_experimental.printing import model_table | ||
coords = {"subject": range(20), "param": ["a", "b"]} | ||
with pm.Model(coords=coords) as m: | ||
x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param")) | ||
y = pm.Data("y", np.random.normal(size=(20,)), dims="subject") | ||
beta = pm.Normal("beta", mu=0, sigma=1, dims="param") | ||
mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject") | ||
sigma = pm.HalfNormal("sigma", sigma=1) | ||
y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject") | ||
table = model_table(m) | ||
table # Displays the following table in an interactive environment | ||
''' | ||
Variable Expression Dimensions | ||
───────────────────────────────────────────────────── | ||
x = Data subject[20] × param[2] | ||
y = Data subject[20] | ||
beta ~ Normal(0, 1) param[2] | ||
sigma ~ HalfNormal(0, 1) | ||
Parameter count = 3 | ||
mu = f(beta) subject[20] | ||
y_obs ~ Normal(mu, sigma) subject[20] | ||
''' | ||
Output can be explicitly rendered in a rich console or exported to text, html or svg. | ||
.. code-block:: python | ||
from rich.console import Console | ||
console = Console(record=True) | ||
console.print(table) | ||
text_export = console.export_text() | ||
html_export = console.export_html() | ||
svg_export = console.export_svg() | ||
""" | ||
table = Table( | ||
show_header=True, | ||
show_edge=False, | ||
box=SIMPLE_HEAD, | ||
highlight=False, | ||
collapse_padding=True, | ||
) | ||
table.add_column("Variable", justify="right") | ||
table.add_column("Expression", justify="left") | ||
table.add_column("Dimensions") | ||
|
||
if split_groups: | ||
groups = ( | ||
model.data_vars, | ||
model.free_RVs, | ||
model.deterministics, | ||
model.potentials, | ||
model.observed_RVs, | ||
) | ||
else: | ||
# Show variables in the order they were defined | ||
groups = (model.named_vars.values(),) | ||
|
||
for group in groups: | ||
if not group: | ||
continue | ||
|
||
for var in group: | ||
var_name = var.name | ||
sep = f'[b]{" ~" if (var in model.basic_RVs) else " ="}[/b]' | ||
var_expr = variable_expression(model, var, truncate_deterministic) | ||
dims_expr = dims_expression(model, var) | ||
if dims_expr == "[]": | ||
dims_expr = "" | ||
table.add_row(var_name + sep, var_expr, dims_expr) | ||
|
||
if parameter_count and (not split_groups or group == model.free_RVs): | ||
n_parameters = model_parameter_count(model) | ||
table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]") | ||
|
||
table.add_section() | ||
|
||
return table |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import numpy as np | ||
import pymc as pm | ||
|
||
from rich.console import Console | ||
|
||
from pymc_experimental.printing import model_table | ||
|
||
|
||
def get_text(table) -> str: | ||
console = Console(width=80) | ||
with console.capture() as capture: | ||
console.print(table) | ||
return capture.get() | ||
|
||
|
||
def test_model_table(): | ||
with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model: | ||
x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject")) | ||
y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject")) | ||
|
||
mu = pm.Normal("mu", mu=0, sigma=1) | ||
sigma = pm.HalfNormal("sigma", sigma=1) | ||
global_intercept = pm.Normal("global_intercept", mu=0, sigma=1) | ||
intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, shape=(20, 1)) | ||
beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject") | ||
|
||
mu_trial = pm.Deterministic( | ||
"mu_trial", | ||
global_intercept.squeeze() + intercept_subject + beta_subject * x_data, | ||
dims=["trial", "subject"], | ||
) | ||
noise = pm.Exponential("noise", lam=1) | ||
y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject")) | ||
|
||
pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject") | ||
|
||
table_txt = get_text(model_table(model)) | ||
expected = """ Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) [20, 1] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
noise ~ Exponential(f()) | ||
Parameter count = 44 | ||
mu_trial = f(intercept_subject, trial[6] × subject[20] | ||
beta_subject, | ||
global_intercept) | ||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
""" | ||
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()] | ||
|
||
table_txt = get_text(model_table(model, split_groups=False)) | ||
expected = """ Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) [20, 1] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
mu_trial = f(intercept_subject, trial[6] × subject[20] | ||
beta_subject, | ||
global_intercept) | ||
noise ~ Exponential(f()) | ||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
Parameter count = 44 | ||
""" | ||
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()] | ||
|
||
table_txt = get_text( | ||
model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False) | ||
) | ||
expected = """ Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) [20, 1] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
mu_trial = f(intercept_subject, ...) trial[6] × subject[20] | ||
noise ~ Exponential(f()) | ||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
""" | ||
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()] |