Skip to content

Commit 62c6d5c

Browse files
committed
Fix rewriting of Responses API instructions
1 parent 58a58b2 commit 62c6d5c

File tree

2 files changed

+171
-0
lines changed

2 files changed

+171
-0
lines changed

src/core/app/middleware/content_rewriting_middleware.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,59 @@ def _rewrite_responses_input(self, payload: dict[str, Any]) -> bool:
111111

112112
return is_rewritten
113113

114+
def _rewrite_responses_instructions(self, payload: dict[str, Any]) -> bool:
115+
"""Rewrite OpenAI Responses API instructions in-place."""
116+
117+
instructions = payload.get("instructions")
118+
if instructions is None:
119+
return False
120+
121+
if isinstance(instructions, str):
122+
rewritten = self.rewriter.rewrite_prompt(instructions, "system")
123+
if rewritten != instructions:
124+
payload["instructions"] = rewritten
125+
return True
126+
return False
127+
128+
if isinstance(instructions, dict):
129+
text_value = instructions.get("text")
130+
if isinstance(text_value, str):
131+
rewritten_text = self.rewriter.rewrite_prompt(
132+
text_value, "system"
133+
)
134+
if rewritten_text != text_value:
135+
instructions["text"] = rewritten_text
136+
return True
137+
return False
138+
139+
if isinstance(instructions, list):
140+
is_rewritten = False
141+
for index, block in enumerate(instructions):
142+
if not isinstance(block, dict):
143+
if isinstance(block, str):
144+
rewritten_text = self.rewriter.rewrite_prompt(
145+
block, "system"
146+
)
147+
if rewritten_text != block:
148+
instructions[index] = rewritten_text
149+
is_rewritten = True
150+
continue
151+
152+
text_value = block.get("text")
153+
if not isinstance(text_value, str):
154+
continue
155+
156+
rewritten_text = self.rewriter.rewrite_prompt(
157+
text_value, "system"
158+
)
159+
if rewritten_text != text_value:
160+
block["text"] = rewritten_text
161+
is_rewritten = True
162+
163+
return is_rewritten
164+
165+
return False
166+
114167
def _rewrite_chat_response(self, payload: dict[str, Any]) -> bool:
115168
"""Rewrite OpenAI chat completion responses in-place."""
116169

@@ -234,6 +287,9 @@ async def dispatch(
234287
if self._rewrite_responses_input(data):
235288
is_rewritten = True
236289

290+
if self._rewrite_responses_instructions(data):
291+
is_rewritten = True
292+
237293
if is_rewritten:
238294
body_bytes = json.dumps(data).encode("utf-8")
239295
scope_for_next_call = dict(request.scope)

tests/integration/test_content_rewriting_middleware.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,121 @@ async def run_test():
244244

245245
asyncio.run(run_test())
246246

247+
def test_request_rewriting_responses_instructions_string(self):
248+
"""Ensure Responses API instructions strings are rewritten."""
249+
250+
rewriter = ContentRewriterService(config_path=self.test_config_dir)
251+
middleware = ContentRewritingMiddleware(app=None, rewriter=rewriter)
252+
253+
request_payload = {
254+
"instructions": "original system guidance", # matches rule in setUp
255+
"input": "user input",
256+
}
257+
258+
async def call_next(request):
259+
data = await request.json()
260+
self.assertEqual(data["instructions"], "rewritten system guidance")
261+
return Response(
262+
content=json.dumps({"ok": True}), media_type="application/json"
263+
)
264+
265+
async def receive():
266+
return {
267+
"type": "http.request",
268+
"body": json.dumps(request_payload).encode("utf-8"),
269+
"more_body": False,
270+
}
271+
272+
request = Request(
273+
{
274+
"type": "http",
275+
"method": "POST",
276+
"headers": Headers({"content-type": "application/json"}).raw,
277+
"http_version": "1.1",
278+
"server": ("testserver", 80),
279+
"client": ("testclient", 123),
280+
"scheme": "http",
281+
"root_path": "",
282+
"path": "/test",
283+
"raw_path": b"/test",
284+
"query_string": b"",
285+
},
286+
receive=receive,
287+
)
288+
289+
async def run_test():
290+
response = await middleware.dispatch(request, call_next)
291+
self.assertEqual(response.status_code, 200)
292+
293+
import asyncio
294+
295+
asyncio.run(run_test())
296+
297+
def test_request_rewriting_responses_instructions_blocks(self):
298+
"""Ensure Responses API instruction content blocks are rewritten."""
299+
300+
rewriter = ContentRewriterService(config_path=self.test_config_dir)
301+
middleware = ContentRewritingMiddleware(app=None, rewriter=rewriter)
302+
303+
request_payload = {
304+
"instructions": [
305+
{"type": "text", "text": "original system guidance"},
306+
{"type": "image", "image_url": {"url": "https://example.com"}},
307+
],
308+
"input": [
309+
{
310+
"role": "user",
311+
"content": [
312+
{"type": "text", "text": "This should remain unchanged."}
313+
],
314+
}
315+
],
316+
}
317+
318+
async def call_next(request):
319+
data = await request.json()
320+
instructions = data["instructions"]
321+
self.assertIsInstance(instructions, list)
322+
self.assertEqual(instructions[0]["text"], "rewritten system guidance")
323+
self.assertEqual(
324+
instructions[1]["image_url"]["url"], "https://example.com"
325+
)
326+
return Response(
327+
content=json.dumps({"ok": True}), media_type="application/json"
328+
)
329+
330+
async def receive():
331+
return {
332+
"type": "http.request",
333+
"body": json.dumps(request_payload).encode("utf-8"),
334+
"more_body": False,
335+
}
336+
337+
request = Request(
338+
{
339+
"type": "http",
340+
"method": "POST",
341+
"headers": Headers({"content-type": "application/json"}).raw,
342+
"http_version": "1.1",
343+
"server": ("testserver", 80),
344+
"client": ("testclient", 123),
345+
"scheme": "http",
346+
"root_path": "",
347+
"path": "/test",
348+
"raw_path": b"/test",
349+
"query_string": b"",
350+
},
351+
receive=receive,
352+
)
353+
354+
async def run_test():
355+
response = await middleware.dispatch(request, call_next)
356+
self.assertEqual(response.status_code, 200)
357+
358+
import asyncio
359+
360+
asyncio.run(run_test())
361+
247362
def test_outbound_prompt_rewriting(self):
248363
"""Verify that outbound prompts are rewritten correctly."""
249364

0 commit comments

Comments
 (0)