Skip to content

Commit 05bff03

Browse files
pre-commit-ci[bot]DipayanDasgupta
authored andcommitted
feat: Add tooltip to Altair agent portrayal (#2795)
This feature adds a `tooltip` attribute to `AgentPortrayalStyle`, enabling agent-specific information to be displayed on hover in Altair-based visualizations. This commit addresses review feedback by: - Adding documentation to clarify the feature is Altair-only. - Raising a ValueError if tooltips are used with the Matplotlib backend. - Applying consistency, typo, and formatting fixes suggested by reviewers.
1 parent 2e1e9da commit 05bff03

File tree

5 files changed

+78
-52
lines changed

5 files changed

+78
-52
lines changed

mesa/examples/basic/boltzmann_wealth_model/app.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,23 @@ def agent_portrayal(agent):
3838
}
3939

4040

41+
def post_process(chart):
42+
"""Post-process the Altair chart to add a colorbar legend."""
43+
chart = chart.encode(
44+
color=alt.Color(
45+
"original_color:Q",
46+
scale=alt.Scale(scheme="viridis", domain=[0, 10]),
47+
legend=alt.Legend(
48+
title="Wealth",
49+
orient="right",
50+
type="gradient",
51+
gradientLength=200,
52+
),
53+
),
54+
)
55+
return chart
56+
57+
4158
model = BoltzmannWealth(50, 10, 10)
4259

4360
# The SpaceRenderer is responsible for drawing the model's space and agents.
@@ -47,13 +64,11 @@ def agent_portrayal(agent):
4764
renderer = SpaceRenderer(model, backend="altair")
4865
# Can customize the grid appearance.
4966
renderer.draw_structure(grid_color="black", grid_dash=[6, 2], grid_opacity=0.3)
50-
renderer.draw_agents(
51-
agent_portrayal=agent_portrayal,
52-
cmap="viridis",
53-
vmin=0,
54-
vmax=10,
55-
legend_title="Wealth",
56-
)
67+
renderer.draw_agents(agent_portrayal=agent_portrayal)
68+
# The post_process function is used to modify the Altair chart after it has been created.
69+
# It can be used to add legends, colorbars, or other visual elements.
70+
renderer.post_process = post_process
71+
5772

5873
# Creates a line plot component from the model's "Gini" datacollector.
5974
GiniPlot = make_plot_component("Gini")
@@ -67,4 +82,4 @@ def agent_portrayal(agent):
6782
model_params=model_params,
6883
name="Boltzmann Wealth Model",
6984
)
70-
page # noqa
85+
page # noqa

mesa/visualization/backends/altair_backend.py

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""Altair-based renderer for Mesa spaces.
2+
3+
This module provides an Altair-based renderer for visualizing Mesa model spaces,
4+
agents, and property layers with interactive charting capabilities.
5+
"""
6+
17
import warnings
28
from collections.abc import Callable
39
from dataclasses import fields
@@ -201,12 +207,21 @@ def collect_agent_data(
201207

202208
return final_data
203209

204-
205-
206210
def draw_agents(
207211
self, arguments, chart_width: int = 450, chart_height: int = 350, **kwargs
208212
):
209-
"""Draw agents using Altair backend."""
213+
"""Draw agents using Altair backend.
214+
215+
Args:
216+
arguments: Dictionary containing agent data arrays.
217+
chart_width: Width of the chart.
218+
chart_height: Height of the chart.
219+
**kwargs: Additional keyword arguments for customization.
220+
Checkout respective `SpaceDrawer` class on details how to pass **kwargs.
221+
222+
Returns:
223+
alt.Chart: The Altair chart representing the agents, or None if no agents.
224+
"""
210225
if arguments["loc"].size == 0:
211226
return None
212227

@@ -219,7 +234,8 @@ def draw_agents(
219234
"size": arguments["size"][i],
220235
"shape": arguments["shape"][i],
221236
"opacity": arguments["opacity"][i],
222-
"strokeWidth": arguments["strokeWidth"][i] / 10, # Scale for continuous domain
237+
"strokeWidth": arguments["strokeWidth"][i]
238+
/ 10, # Scale for continuous domain
223239
"original_color": arguments["color"][i],
224240
}
225241
# Add tooltip data if available
@@ -230,7 +246,11 @@ def draw_agents(
230246
# Determine fill and stroke colors
231247
if arguments["filled"][i]:
232248
record["viz_fill_color"] = arguments["color"][i]
233-
record["viz_stroke_color"] = arguments["stroke"][i] if isinstance(arguments["stroke"][i], str) else None
249+
record["viz_stroke_color"] = (
250+
arguments["stroke"][i]
251+
if isinstance(arguments["stroke"][i], str)
252+
else None
253+
)
234254
else:
235255
record["viz_fill_color"] = None
236256
record["viz_stroke_color"] = arguments["color"][i]
@@ -240,52 +260,32 @@ def draw_agents(
240260
df = pd.DataFrame(records)
241261

242262
# Ensure all columns that should be numeric are, handling potential Nones
243-
numeric_cols = ['x', 'y', 'size', 'opacity', 'strokeWidth', 'original_color']
263+
numeric_cols = ["x", "y", "size", "opacity", "strokeWidth", "original_color"]
244264
for col in numeric_cols:
245265
if col in df.columns:
246-
df[col] = pd.to_numeric(df[col], errors='coerce')
247-
266+
df[col] = pd.to_numeric(df[col], errors="coerce")
248267

249268
# Get tooltip keys from the first valid record
250269
tooltip_list = ["x", "y"]
251-
# This is the corrected line:
252270
if any(t is not None for t in arguments["tooltip"]):
253-
first_valid_tooltip = next((t for t in arguments["tooltip"] if t), None)
254-
if first_valid_tooltip:
255-
tooltip_list.extend(first_valid_tooltip.keys())
271+
first_valid_tooltip = next(
272+
(t for t in arguments["tooltip"] if t is not None), None
273+
)
274+
if first_valid_tooltip is not None:
275+
tooltip_list.extend(first_valid_tooltip.keys())
256276

257277
# Extract additional parameters from kwargs
258278
title = kwargs.pop("title", "")
259279
xlabel = kwargs.pop("xlabel", "")
260280
ylabel = kwargs.pop("ylabel", "")
261-
legend_title = kwargs.pop("legend_title", "Color")
262-
263-
# Handle custom colormapping
264-
cmap = kwargs.pop("cmap", "viridis")
265-
vmin = kwargs.pop("vmin", None)
266-
vmax = kwargs.pop("vmax", None)
281+
# FIXME: Add more parameters to kwargs
267282

268283
color_is_numeric = pd.api.types.is_numeric_dtype(df["original_color"])
269-
if color_is_numeric:
270-
color_min = vmin if vmin is not None else df["original_color"].min()
271-
color_max = vmax if vmax is not None else df["original_color"].max()
272-
273-
fill_encoding = alt.Fill(
274-
"original_color:Q",
275-
scale=alt.Scale(scheme=cmap, domain=[color_min, color_max]),
276-
legend=alt.Legend(
277-
title=legend_title,
278-
orient="right",
279-
type="gradient",
280-
gradientLength=200,
281-
),
282-
)
283-
else:
284-
fill_encoding = alt.Fill(
285-
"viz_fill_color:N",
286-
scale=None,
287-
title="Color",
288-
)
284+
fill_encoding = (
285+
alt.Fill("original_color:Q")
286+
if color_is_numeric
287+
else alt.Fill("viz_fill_color:N", scale=None, title="Color")
288+
)
289289

290290
# Determine space dimensions
291291
xmin, xmax, ymin, ymax = self.space_drawer.get_viz_limits()
@@ -316,10 +316,16 @@ def draw_agents(
316316
),
317317
title="Shape",
318318
),
319-
opacity=alt.Opacity("opacity:Q", title="Opacity", scale=alt.Scale(domain=[0, 1], range=[0, 1])),
319+
opacity=alt.Opacity(
320+
"opacity:Q",
321+
title="Opacity",
322+
scale=alt.Scale(domain=[0, 1], range=[0, 1]),
323+
),
320324
fill=fill_encoding,
321325
stroke=alt.Stroke("viz_stroke_color:N", scale=None),
322-
strokeWidth=alt.StrokeWidth("strokeWidth:Q", scale=alt.Scale(domain=[0, 1])),
326+
strokeWidth=alt.StrokeWidth(
327+
"strokeWidth:Q", scale=alt.Scale(domain=[0, 1])
328+
),
323329
tooltip=tooltip_list,
324330
)
325331
.properties(title=title, width=chart_width, height=chart_height)
@@ -431,4 +437,4 @@ def draw_propertylayer(
431437
main_charts.append(current_chart)
432438

433439
base = alt.layer(*main_charts).resolve_scale(color="independent")
434-
return base
440+
return base

mesa/visualization/backends/matplotlib_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
2828
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
2929

30-
3130
CORRECTION_FACTOR_MARKER_ZOOM = 0.01
3231

3332

@@ -141,6 +140,10 @@ def collect_agent_data(self, space, agent_portrayal, default_size=None):
141140
)
142141
else:
143142
aps = portray_input
143+
if aps.tooltip is not None:
144+
raise ValueError(
145+
"The 'tooltip' attribute in AgentPortrayalStyle is only supported by the Altair backend."
146+
)
144147
# Set defaults if not provided
145148
if aps.x is None and aps.y is None:
146149
aps.x, aps.y = self._get_agent_pos(agent, space)

mesa/visualization/components/portrayal_components.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class AgentPortrayalStyle:
5656
edgecolors: str | tuple | None = None
5757
linewidths: float | int | None = 1.0
5858
tooltip: dict | None = None
59+
"""A dictionary of data to display on hover. Note: This feature is only available with the Altair backend."""
5960

6061
def update(self, *updates_fields: tuple[str, Any]):
6162
"""Updates attributes from variable (field_name, new_value) tuple arguments.
@@ -92,7 +93,7 @@ class PropertyLayerStyle:
9293
(vmin, vmax), transparency (alpha) and colorbar visibility.
9394
9495
Note: vmin and vmax are the lower and upper bounds for the colorbar and the data is
95-
normalized between these values for color/colorbar rendering. If they are not
96+
normalized between these values for color/colormap rendering. If they are not
9697
declared the values are automatically determined from the data range.
9798
9899
Note: You can specify either a 'colormap' (for varying data) or a single
@@ -118,4 +119,4 @@ def __post_init__(self):
118119
if self.color is not None and self.colormap is not None:
119120
raise ValueError("Specify either 'color' or 'colormap', not both.")
120121
if self.color is None and self.colormap is None:
121-
raise ValueError("Specify one of 'color' or 'colormap'")
122+
raise ValueError("Specify one of 'color' or 'colormap'")

tests/test_backends.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def test_altair_backend_draw_agents():
248248
"color": np.array(["red", "blue"]),
249249
"filled": np.array([True, True]),
250250
"stroke": np.array(["black", "black"]),
251+
"tooltip": np.array([None, None]),
251252
}
252253
ab.space_drawer.get_viz_limits = MagicMock(return_value=(0, 10, 0, 10))
253254
assert ab.draw_agents(arguments) is not None

0 commit comments

Comments
 (0)