Skip to content

Commit 7105db6

Browse files
committed
Fix tool text renderer override isolation
1 parent c03159e commit 7105db6

File tree

2 files changed

+60
-8
lines changed

2 files changed

+60
-8
lines changed

src/core/services/tool_text_renderer.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
from abc import ABC, abstractmethod
77
from collections.abc import Callable
8+
from contextvars import ContextVar, Token
89
from typing import Any
910

1011
from src.core.domain.chat import ToolCall
@@ -279,31 +280,35 @@ def reset_renderer_registry() -> None:
279280

280281

281282
# Context manager to temporarily override the renderer for a block of code
282-
_override: str | None = None
283+
_override_var: ContextVar[str | None] = ContextVar(
284+
"tool_text_renderer_override", default=None
285+
)
283286

284287

285288
class OverrideRenderer:
286289
def __init__(self, renderer_name: str):
287290
self.renderer_name = renderer_name
288-
self.original_override = _override
291+
self._token: Token[str | None] | None = None
289292

290293
def __enter__(self) -> None:
291-
global _override
292-
_override = self.renderer_name
294+
self._token = _override_var.set(self.renderer_name)
293295

294296
def __exit__(self, exc_type: Any, _: Any, traceback: Any) -> None:
295-
global _override
296-
_override = self.original_override
297+
if self._token is not None:
298+
_override_var.reset(self._token)
299+
else:
300+
_override_var.set(None)
297301

298302

299303
def render_tool_call(tool_call: ToolCall) -> str | None:
300304
"""Render a tool call using the currently active renderer."""
301-
renderer_name = _override or _renderer_registry.default_renderer
305+
current_override = _override_var.get()
306+
renderer_name = current_override or _renderer_registry.default_renderer
302307
renderer = get_renderer(renderer_name)
303308
text = renderer.render(tool_call)
304309
if text:
305310
return text
306-
if (_override or "").strip().lower() in {"", "none"}:
311+
if (current_override or "").strip().lower() in {"", "none"}:
307312
return None
308313
fallback_name = _renderer_registry.fallback_renderer
309314
if fallback_name and fallback_name != renderer_name:
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import asyncio
2+
import json
3+
4+
import pytest
5+
from src.core.domain.chat import FunctionCall, ToolCall
6+
from src.core.services.tool_text_renderer import (
7+
OverrideRenderer,
8+
render_tool_call,
9+
reset_renderer_registry,
10+
)
11+
12+
13+
@pytest.mark.asyncio
14+
async def test_override_is_session_isolated() -> None:
15+
"""Ensure renderer overrides do not leak across concurrent sessions."""
16+
reset_renderer_registry()
17+
tool_call = ToolCall(
18+
id="call-1",
19+
function=FunctionCall(
20+
name="shell",
21+
arguments=json.dumps({"command": ["echo", "hello"]}),
22+
),
23+
)
24+
25+
start_override = asyncio.Event()
26+
release_override = asyncio.Event()
27+
28+
async def session_with_override() -> str | None:
29+
with OverrideRenderer("markdown"):
30+
start_override.set()
31+
await release_override.wait()
32+
return render_tool_call(tool_call)
33+
34+
async def concurrent_session() -> str | None:
35+
await start_override.wait()
36+
result = render_tool_call(tool_call)
37+
release_override.set()
38+
return result
39+
40+
override_result, default_result = await asyncio.gather(
41+
session_with_override(),
42+
concurrent_session(),
43+
)
44+
45+
assert override_result is not None and "```bash" in override_result
46+
assert default_result is None
47+
assert render_tool_call(tool_call) is None

0 commit comments

Comments
 (0)