Skip to content

Introduction of Space Renderer #2803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mesa/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
from .components import make_plot_component, make_space_component
from .components.altair_components import make_space_altair
from .solara_viz import JupyterViz, SolaraViz
from .space_renderer import SpaceRenderer
from .user_param import Slider

__all__ = [
"CommandConsole",
"JupyterViz",
"Slider",
"SolaraViz",
"SpaceRenderer",
"draw_space",
"make_plot_component",
"make_space_altair",
Expand Down
118 changes: 113 additions & 5 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@

import asyncio
import inspect
import itertools
import threading
import time
import traceback
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Any, Literal

import reacton.core
import solara
Expand All @@ -39,6 +40,7 @@
from mesa.experimental.devs.simulator import Simulator
from mesa.mesa_logging import create_module_logger, function_logger
from mesa.visualization.command_console import CommandConsole
from mesa.visualization.space_renderer import SpaceRenderer
from mesa.visualization.user_param import Slider
from mesa.visualization.utils import force_update, update_counter

Expand All @@ -52,6 +54,7 @@
@function_logger(__name__)
def SolaraViz(
model: Model | solara.Reactive[Model],
renderer: SpaceRenderer,
components: list[reacton.core.Component]
| list[Callable[[Model], reacton.core.Component]]
| Literal["default"] = "default",
Expand All @@ -74,6 +77,7 @@ def SolaraViz(
model (Model | solara.Reactive[Model]): A Model instance or a reactive Model.
This is the main model to be visualized. If a non-reactive model is provided,
it will be converted to a reactive model.
renderer (SpaceRenderer): A SpaceRenderer instance to render the model's space.
components (list[solara.component] | Literal["default"], optional): List of solara
components or functions that return a solara component.
These components are used to render different parts of the model visualization.
Expand Down Expand Up @@ -124,11 +128,17 @@ def SolaraViz(
if not isinstance(model, solara.Reactive):
model = solara.use_reactive(model) # noqa: SH102, RUF100

# set up reactive model_parameters shared by ModelCreator and ModelController
# Set up reactive model_parameters shared by ModelCreator and ModelController
reactive_model_parameters = solara.use_reactive({})
reactive_play_interval = solara.use_reactive(play_interval)
reactive_render_interval = solara.use_reactive(render_interval)
reactive_use_threads = solara.use_reactive(use_threads)

# Make a copy of the components to avoid modifying the original list
display_components = list(components)
# Create space component based on the renderer
display_components.append(create_space_component(renderer))

with solara.AppBar():
solara.AppBarTitle(name if name else model.value.__class__.__name__)
solara.lab.ThemeToggle()
Expand Down Expand Up @@ -166,6 +176,7 @@ def set_reactive_use_threads(value):
if not isinstance(simulator, Simulator):
ModelController(
model,
renderer=renderer,
model_parameters=reactive_model_parameters,
play_interval=reactive_play_interval,
render_interval=reactive_render_interval,
Expand All @@ -175,6 +186,7 @@ def set_reactive_use_threads(value):
SimulatorController(
model,
simulator,
renderer=renderer,
model_parameters=reactive_model_parameters,
play_interval=reactive_play_interval,
render_interval=reactive_render_interval,
Expand All @@ -187,14 +199,100 @@ def set_reactive_use_threads(value):
with solara.Card("Information"):
ShowSteps(model.value)
if (
CommandConsole in components
CommandConsole in display_components
): # If command console in components show it in sidebar
components.remove(CommandConsole)
display_components.remove(CommandConsole)
additional_imports = console_kwargs.get("additional_imports", {})
with solara.Card("Command Console"):
CommandConsole(model.value, additional_imports=additional_imports)

ComponentsView(components, model.value)
# Render the main components view
ComponentsView(display_components, model.value)


def create_space_component(renderer: SpaceRenderer):
"""Create a space visualization component for the given renderer."""

def SpaceVisualizationComponent(model: Model):
"""Component that renders the model's space using the provided renderer."""
return SpaceRendererComponent(model, renderer)

return SpaceVisualizationComponent


@solara.component
def SpaceRendererComponent(
model: Model,
renderer: SpaceRenderer,
# FIXME: Manage dependencies properly
dependencies: list[Any] | None = None,
):
"""Render the space of a model using a SpaceRenderer.

Args:
model (Model): The model whose space is to be rendered.
renderer: A SpaceRenderer instance to render the model's space.
dependencies (list[any], optional): List of dependencies for the component.
"""
update_counter.get()

# update renderer's space according to the model's space/grid
renderer.space = getattr(model, "grid", getattr(model, "space", None))

if renderer.backend == "matplotlib":
# Clear the previous plotted data and agents
all_artists = [
renderer.ax.lines[:],
renderer.ax.collections[:],
renderer.ax.patches[:],
renderer.ax.images[:],
]
# Chain them together into a single iterable
for artist in itertools.chain.from_iterable(all_artists):
artist.remove()

# Draw the space structure if specified
if renderer.space_mesh:
renderer.draw_structure(**renderer.space_kwargs)

# Draw agents if specified
if renderer.agent_mesh:
renderer.draw_agents(
renderer.agent_portrayal, ax=None, **renderer.agent_kwargs
)

# Draw property layers if specified
if renderer.propertylayer_mesh:
_, cbar = renderer.draw_propertylayer(renderer.propertylayer_portrayal)
# Remove the newly generated colorbar to avoid duplication
if cbar is not None:
cbar.remove()

# Update the fig every time frame
if dependencies:
dependencies.append(update_counter.value)
else:
dependencies = [update_counter.value]

solara.FigureMatplotlib(
renderer.ax.get_figure(),
format="png",
bbox_inches="tight",
dependencies=dependencies,
)
return None
else:
if renderer.space_mesh:
renderer.draw_structure(**renderer.space_kwargs)
if renderer.agent_mesh:
renderer.draw_agents(
renderer.agent_portrayal, ax=None, **renderer.agent_kwargs
)
if renderer.propertylayer_mesh:
renderer.draw_propertylayer(renderer.propertylayer_portrayal)

solara.FigureAltair(renderer.canvas, on_click=None, on_hover=None)
return None


def _wrap_component(
Expand Down Expand Up @@ -244,6 +342,7 @@ def ComponentsView(
def ModelController(
model: solara.Reactive[Model],
*,
renderer: SpaceRenderer,
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int | solara.Reactive[int] = 100,
render_interval: int | solara.Reactive[int] = 1,
Expand All @@ -253,6 +352,7 @@ def ModelController(

Args:
model: Reactive model instance
renderer: SpaceRenderer instance to render the model's space.
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
render_interval: Controls how often the plots are updated during simulation steps.Higher value reduce update frequency.
Expand Down Expand Up @@ -331,6 +431,9 @@ def do_reset():
f"creating new {model.value.__class__} instance with {model_parameters.value}",
)
model.value = model.value = model.value.__class__(**model_parameters.value)
renderer.space = (
model.value.grid if hasattr(model.value, "grid") else model.value.space
)

@function_logger(__name__)
def do_play_pause():
Expand Down Expand Up @@ -360,6 +463,7 @@ def do_play_pause():
def SimulatorController(
model: solara.Reactive[Model],
simulator,
renderer: SpaceRenderer,
*,
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int | solara.Reactive[int] = 100,
Expand All @@ -371,6 +475,7 @@ def SimulatorController(
Args:
model: Reactive model instance
simulator: Simulator instance
renderer: SpaceRenderer instance to render the model's space.
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
render_interval: Controls how often the plots are updated during simulation steps.Higher values reduce update frequency.
Expand Down Expand Up @@ -453,6 +558,9 @@ def do_reset():
model.value = model.value = model.value.__class__(
simulator=simulator, **model_parameters.value
)
renderer.space = (
model.value.grid if hasattr(model.value, "grid") else model.value.space
)

def do_play_pause():
"""Toggle play/pause."""
Expand Down
Loading
Loading