Skip to content

Commit 326702b

Browse files
authored
Merge pull request #562 from matdev83/codex/fix-bug-in-request/response-rewriting-mvce3a
Fix rewriting of OpenAI Responses API instructions
2 parents d89f274 + 62c6d5c commit 326702b

File tree

2 files changed

+171
-0
lines changed

2 files changed

+171
-0
lines changed

src/core/app/middleware/content_rewriting_middleware.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,59 @@ def _rewrite_responses_input(self, payload: dict[str, Any]) -> bool:
165165

166166
return is_rewritten
167167

168+
def _rewrite_responses_instructions(self, payload: dict[str, Any]) -> bool:
169+
"""Rewrite OpenAI Responses API instructions in-place."""
170+
171+
instructions = payload.get("instructions")
172+
if instructions is None:
173+
return False
174+
175+
if isinstance(instructions, str):
176+
rewritten = self.rewriter.rewrite_prompt(instructions, "system")
177+
if rewritten != instructions:
178+
payload["instructions"] = rewritten
179+
return True
180+
return False
181+
182+
if isinstance(instructions, dict):
183+
text_value = instructions.get("text")
184+
if isinstance(text_value, str):
185+
rewritten_text = self.rewriter.rewrite_prompt(
186+
text_value, "system"
187+
)
188+
if rewritten_text != text_value:
189+
instructions["text"] = rewritten_text
190+
return True
191+
return False
192+
193+
if isinstance(instructions, list):
194+
is_rewritten = False
195+
for index, block in enumerate(instructions):
196+
if not isinstance(block, dict):
197+
if isinstance(block, str):
198+
rewritten_text = self.rewriter.rewrite_prompt(
199+
block, "system"
200+
)
201+
if rewritten_text != block:
202+
instructions[index] = rewritten_text
203+
is_rewritten = True
204+
continue
205+
206+
text_value = block.get("text")
207+
if not isinstance(text_value, str):
208+
continue
209+
210+
rewritten_text = self.rewriter.rewrite_prompt(
211+
text_value, "system"
212+
)
213+
if rewritten_text != text_value:
214+
block["text"] = rewritten_text
215+
is_rewritten = True
216+
217+
return is_rewritten
218+
219+
return False
220+
168221
def _rewrite_chat_response(self, payload: dict[str, Any]) -> bool:
169222
"""Rewrite OpenAI chat completion responses in-place."""
170223

@@ -304,6 +357,9 @@ async def dispatch(
304357
if self._rewrite_responses_input(data):
305358
is_rewritten = True
306359

360+
if self._rewrite_responses_instructions(data):
361+
is_rewritten = True
362+
307363
if is_rewritten:
308364
body_bytes = json.dumps(data).encode("utf-8")
309365
scope_for_next_call = dict(request.scope)

tests/integration/test_content_rewriting_middleware.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,121 @@ async def run_test():
255255

256256
asyncio.run(run_test())
257257

258+
def test_request_rewriting_responses_instructions_string(self):
259+
"""Ensure Responses API instructions strings are rewritten."""
260+
261+
rewriter = ContentRewriterService(config_path=self.test_config_dir)
262+
middleware = ContentRewritingMiddleware(app=None, rewriter=rewriter)
263+
264+
request_payload = {
265+
"instructions": "original system guidance", # matches rule in setUp
266+
"input": "user input",
267+
}
268+
269+
async def call_next(request):
270+
data = await request.json()
271+
self.assertEqual(data["instructions"], "rewritten system guidance")
272+
return Response(
273+
content=json.dumps({"ok": True}), media_type="application/json"
274+
)
275+
276+
async def receive():
277+
return {
278+
"type": "http.request",
279+
"body": json.dumps(request_payload).encode("utf-8"),
280+
"more_body": False,
281+
}
282+
283+
request = Request(
284+
{
285+
"type": "http",
286+
"method": "POST",
287+
"headers": Headers({"content-type": "application/json"}).raw,
288+
"http_version": "1.1",
289+
"server": ("testserver", 80),
290+
"client": ("testclient", 123),
291+
"scheme": "http",
292+
"root_path": "",
293+
"path": "/test",
294+
"raw_path": b"/test",
295+
"query_string": b"",
296+
},
297+
receive=receive,
298+
)
299+
300+
async def run_test():
301+
response = await middleware.dispatch(request, call_next)
302+
self.assertEqual(response.status_code, 200)
303+
304+
import asyncio
305+
306+
asyncio.run(run_test())
307+
308+
def test_request_rewriting_responses_instructions_blocks(self):
309+
"""Ensure Responses API instruction content blocks are rewritten."""
310+
311+
rewriter = ContentRewriterService(config_path=self.test_config_dir)
312+
middleware = ContentRewritingMiddleware(app=None, rewriter=rewriter)
313+
314+
request_payload = {
315+
"instructions": [
316+
{"type": "text", "text": "original system guidance"},
317+
{"type": "image", "image_url": {"url": "https://example.com"}},
318+
],
319+
"input": [
320+
{
321+
"role": "user",
322+
"content": [
323+
{"type": "text", "text": "This should remain unchanged."}
324+
],
325+
}
326+
],
327+
}
328+
329+
async def call_next(request):
330+
data = await request.json()
331+
instructions = data["instructions"]
332+
self.assertIsInstance(instructions, list)
333+
self.assertEqual(instructions[0]["text"], "rewritten system guidance")
334+
self.assertEqual(
335+
instructions[1]["image_url"]["url"], "https://example.com"
336+
)
337+
return Response(
338+
content=json.dumps({"ok": True}), media_type="application/json"
339+
)
340+
341+
async def receive():
342+
return {
343+
"type": "http.request",
344+
"body": json.dumps(request_payload).encode("utf-8"),
345+
"more_body": False,
346+
}
347+
348+
request = Request(
349+
{
350+
"type": "http",
351+
"method": "POST",
352+
"headers": Headers({"content-type": "application/json"}).raw,
353+
"http_version": "1.1",
354+
"server": ("testserver", 80),
355+
"client": ("testclient", 123),
356+
"scheme": "http",
357+
"root_path": "",
358+
"path": "/test",
359+
"raw_path": b"/test",
360+
"query_string": b"",
361+
},
362+
receive=receive,
363+
)
364+
365+
async def run_test():
366+
response = await middleware.dispatch(request, call_next)
367+
self.assertEqual(response.status_code, 200)
368+
369+
import asyncio
370+
371+
asyncio.run(run_test())
372+
258373
def test_outbound_prompt_rewriting(self):
259374
"""Verify that outbound prompts are rewritten correctly."""
260375

0 commit comments

Comments
 (0)