3 changes: 2 additions & 1 deletion pdd/llm_invoke.py
@@ -1,6 +1,7 @@
# Corrected code_under_test (llm_invoke.py)
# Added optional debugging prints in _select_model_candidates

import copy
import os
import pandas as pd
import litellm
@@ -1963,7 +1964,7 @@ def calc_strength(candidate):
# --- 5. Prepare LiteLLM Arguments ---
litellm_kwargs: Dict[str, Any] = {
"model": model_name_litellm,
"messages": formatted_messages,
"messages": copy.deepcopy(formatted_messages),
# Use a local adjustable temperature to allow provider-specific fallbacks
"temperature": current_temperature,
# Retry on transient network errors (APIError, TimeoutError, ServiceUnavailableError)
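
For context, a minimal sketch of the aliasing bug this deepcopy guards against (the helper name build_request and the SCHEMA_NOTE text below are illustrative stand-ins, not pdd APIs): if a provider-specific structured-output path appends a schema instruction to the messages list it was handed, every later fallback candidate that reuses the same list sees the accumulated instructions; copying the list per attempt keeps the shared conversation untouched.

import copy

SCHEMA_NOTE = "You must respond with valid JSON matching this schema: ..."  # illustrative text

def build_request(shared_messages, add_schema_note):
    # The fix: give each candidate its own copy of the conversation.
    messages = copy.deepcopy(shared_messages)
    if add_schema_note:
        # Hypothetical provider-specific tweak. Without the deepcopy above,
        # this append would mutate the caller's shared list and leak into
        # every later fallback attempt.
        messages.append({"role": "system", "content": SCHEMA_NOTE})
    return {"messages": messages}

shared = [{"role": "user", "content": "What is 2+2?"}]
first = build_request(shared, add_schema_note=True)      # e.g. a Groq structured-output attempt
fallback = build_request(shared, add_schema_note=False)  # e.g. the OpenAI fallback

assert all(SCHEMA_NOTE not in m["content"] for m in fallback["messages"])
assert len(shared) == 1  # the shared list is untouched
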
123 changes: 122 additions & 1 deletion tests/test_llm_invoke.py
@@ -1,5 +1,6 @@
# Corrected unit_test (tests/test_llm_invoke.py)

import copy
import pytest
import os
import pandas as pd
@@ -4755,4 +4756,124 @@ def capture_completion(**kwargs):

# time=None should be treated as 0, so no reasoning params
assert "thinking" not in captured_kwargs
assert "reasoning_effort" not in captured_kwargs
assert "reasoning_effort" not in captured_kwargs


# ---------------------------------------------------------------------------
# Issue #562: Groq structured output mutation regression tests
# ---------------------------------------------------------------------------

class _GroqMutationFixture:
"""Shared helpers for Groq message mutation tests."""

class SimpleResult(BaseModel):
answer: str
confidence: float

GROQ_SCHEMA_MARKER = "You must respond with valid JSON matching this schema"

@staticmethod
def _model(provider, model, elo, api_key):
return {
"provider": provider, "model": model, "input": 0.15,
"output": 0.60, "coding_arena_elo": elo,
"structured_output": True, "base_url": "", "api_key": api_key,
"max_tokens": "", "max_completion_tokens": "",
"reasoning_type": "none", "max_reasoning_tokens": 0,
}

@staticmethod
def _make_response(content):
resp = MagicMock()
choice = MagicMock()
choice.message.content = content
choice.finish_reason = "stop"
resp.choices = [choice]
usage = MagicMock()
usage.prompt_tokens = 100
usage.completion_tokens = 50
resp.usage = usage
return resp


class TestGroqMessageMutation(_GroqMutationFixture):
"""Regression tests for #562: Groq structured output must not corrupt
formatted_messages shared across fallback candidates."""

def _run(self, input_messages, candidates):
"""Call llm_invoke with Groq failing, capture every litellm.completion call."""
import pdd.llm_invoke as _llm_mod

groq = self._model("Groq", "groq/llama-3.3-70b-versatile", 1200, "GROQ_API_KEY")
openai = self._model("OpenAI", "gpt-4o-mini", 1100, "OPENAI_API_KEY")
models = [self._model(*c) for c in candidates] if candidates else [groq, openai]

captured = []

def side_effect(**kwargs):
captured.append(copy.deepcopy(kwargs))
if "groq/" in kwargs.get("model", ""):
raise Exception("Groq API error")
resp = self._make_response(json.dumps({"answer": "4", "confidence": 0.99}))
_llm_mod._LAST_CALLBACK_DATA["cost"] = 0.001
_llm_mod._LAST_CALLBACK_DATA["input_tokens"] = 100
_llm_mod._LAST_CALLBACK_DATA["output_tokens"] = 50
return resp

with patch.dict(os.environ, {"PDD_FORCE_LOCAL": "1",
"GROQ_API_KEY": "k", "OPENAI_API_KEY": "k"}), \
patch("pdd.llm_invoke._ensure_api_key", return_value=True), \
patch("pdd.llm_invoke._select_model_candidates", return_value=models), \
patch("pdd.llm_invoke._load_model_data", return_value=pd.DataFrame(models)), \
patch("pdd.llm_invoke.litellm") as mock_litellm:
mock_litellm.completion = MagicMock(side_effect=side_effect)
mock_litellm.cache = None
mock_litellm.drop_params = True
llm_invoke(
messages=input_messages, strength=0.5, temperature=0.0,
time=0.0, output_pydantic=self.SimpleResult, use_cloud=False,
)
return captured

def test_groq_fallback_no_schema_in_messages(self):
"""Core: after Groq fails, fallback gets clean messages (no JSON schema)."""
captured = self._run(
[{"role": "user", "content": "What is 2+2?"}],
[("Groq", "groq/llama-3.3-70b-versatile", 1200, "GROQ_API_KEY"),
("OpenAI", "gpt-4o-mini", 1100, "OPENAI_API_KEY")],
)
assert len(captured) >= 2
for msg in captured[1]["messages"]:
assert self.GROQ_SCHEMA_MARKER not in msg.get("content", ""), \
"Fallback model received Groq schema instruction"

def test_groq_fallback_system_message_not_mutated(self):
"""Dict overwrite: existing system message preserved for fallback."""
original = "You are a helpful math tutor."
captured = self._run(
[{"role": "system", "content": original},
{"role": "user", "content": "What is 2+2?"}],
[("Groq", "groq/llama-3.3-70b-versatile", 1200, "GROQ_API_KEY"),
("OpenAI", "gpt-4o-mini", 1100, "OPENAI_API_KEY")],
)
assert len(captured) >= 2
fallback_sys = captured[1]["messages"][0]
assert fallback_sys["content"] == original, \
f"System message mutated: {fallback_sys['content'][:200]}"

def test_groq_multiple_failures_no_cumulative_corruption(self):
"""Two Groq models fail; schema instructions must not accumulate."""
captured = self._run(
[{"role": "user", "content": "What is 2+2?"}],
[("Groq", "groq/llama-3.3-70b-versatile", 1200, "GROQ_API_KEY"),
("Groq", "groq/mixtral-8x7b-32768", 1150, "GROQ_API_KEY"),
("OpenAI", "gpt-4o-mini", 1100, "OPENAI_API_KEY")],
)
assert len(captured) >= 3
final = next(c for c in reversed(captured) if "groq/" not in c.get("model", ""))
schema_count = sum(
1 for m in final["messages"]
if self.GROQ_SCHEMA_MARKER in m.get("content", "")
)
assert schema_count == 0, \
f"Fallback got {schema_count} accumulated schema instructions"