|
5 | 5 | import logging |
6 | 6 | from abc import ABC, abstractmethod |
7 | 7 | from collections.abc import Callable |
| 8 | +from contextvars import ContextVar, Token |
8 | 9 | from typing import Any |
9 | 10 |
|
10 | 11 | from src.core.domain.chat import ToolCall |
@@ -279,31 +280,35 @@ def reset_renderer_registry() -> None: |
279 | 280 |
|
280 | 281 |
|
281 | 282 | # Context manager to temporarily override the renderer for a block of code |
282 | | -_override: str | None = None |
| 283 | +_override_var: ContextVar[str | None] = ContextVar( |
| 284 | + "tool_text_renderer_override", default=None |
| 285 | +) |
283 | 286 |
|
284 | 287 |
|
285 | 288 | class OverrideRenderer: |
286 | 289 | def __init__(self, renderer_name: str): |
287 | 290 | self.renderer_name = renderer_name |
288 | | - self.original_override = _override |
| 291 | + self._token: Token[str | None] | None = None |
289 | 292 |
|
290 | 293 | def __enter__(self) -> None: |
291 | | - global _override |
292 | | - _override = self.renderer_name |
| 294 | + self._token = _override_var.set(self.renderer_name) |
293 | 295 |
|
294 | 296 | def __exit__(self, exc_type: Any, _: Any, traceback: Any) -> None: |
295 | | - global _override |
296 | | - _override = self.original_override |
| 297 | + if self._token is not None: |
| 298 | + _override_var.reset(self._token) |
| 299 | + else: |
| 300 | + _override_var.set(None) |
297 | 301 |
|
298 | 302 |
|
299 | 303 | def render_tool_call(tool_call: ToolCall) -> str | None: |
300 | 304 | """Render a tool call using the currently active renderer.""" |
301 | | - renderer_name = _override or _renderer_registry.default_renderer |
| 305 | + current_override = _override_var.get() |
| 306 | + renderer_name = current_override or _renderer_registry.default_renderer |
302 | 307 | renderer = get_renderer(renderer_name) |
303 | 308 | text = renderer.render(tool_call) |
304 | 309 | if text: |
305 | 310 | return text |
306 | | - if (_override or "").strip().lower() in {"", "none"}: |
| 311 | + if (current_override or "").strip().lower() in {"", "none"}: |
307 | 312 | return None |
308 | 313 | fallback_name = _renderer_registry.fallback_renderer |
309 | 314 | if fallback_name and fallback_name != renderer_name: |
|
0 commit comments