Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions src/core/app/middleware/json_repair_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from src.core.interfaces.response_processor_interface import (
IResponseMiddleware,
)
from src.core.services.json_repair_service import JsonRepairService
from src.core.services.json_repair_service import (
JsonRepairResult,
JsonRepairService,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -69,34 +72,36 @@ async def process(
)

try:
repaired_json = self.json_repair_service.repair_and_validate_json(
response.content,
schema=self.config.session.json_repair_schema,
strict=strict_effective,
)
if repaired_json is not None:
metrics.inc(
"json_repair.non_streaming.strict_success"
if strict_effective
else "json_repair.non_streaming.best_effort_success"
repair_result: JsonRepairResult = (
self.json_repair_service.repair_and_validate_json(
response.content,
schema=self.config.session.json_repair_schema,
strict=strict_effective,
)
else:
metrics.inc(
"json_repair.non_streaming.strict_fail"
if strict_effective
else "json_repair.non_streaming.best_effort_fail"
)
metric_suffix = (
"strict_success"
if strict_effective and repair_result.success
else (
"best_effort_success"
if repair_result.success
else (
"strict_fail" if strict_effective else "best_effort_fail"
)
)
)
metrics.inc(f"json_repair.non_streaming.{metric_suffix}")
except Exception:
metrics.inc(
"json_repair.non_streaming.strict_fail"
if strict_effective
else "json_repair.non_streaming.best_effort_fail"
)
raise
if repaired_json is not None:
if repair_result.success:
if logger.isEnabledFor(logging.INFO):
logger.info(f"JSON detected and repaired for session {session_id}")
response.content = json.dumps(repaired_json)
response.content = json.dumps(repair_result.content)
response.metadata["repaired"] = True

return response
29 changes: 19 additions & 10 deletions src/core/services/json_repair_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import json
import logging
from typing import Any, cast
from dataclasses import dataclass
from typing import Any

from json_repair import repair_json
from jsonschema import ValidationError as JsonSchemaValidationError
Expand All @@ -13,6 +14,14 @@
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class JsonRepairResult:
"""Represents the outcome of a JSON repair attempt."""

success: bool
content: Any | None


class JsonRepairService:
"""
A service to repair and validate JSON data.
Expand All @@ -25,7 +34,7 @@
json_string: str,
schema: dict[str, Any] | None = None,
strict: bool = False,
) -> dict[str, Any] | None:
) -> JsonRepairResult:
"""
Repairs a JSON string and optionally validates it against a schema.

Expand All @@ -35,13 +44,13 @@
strict: If True, raises an error if the JSON is invalid after repair.

Returns:
The repaired and validated JSON object, or None if repair fails.
JsonRepairResult describing whether repair succeeded and the content.
"""
try:
repaired_json = self.repair_json(json_string)
if schema:
if schema is not None:
self.validate_json(repaired_json, schema)
return repaired_json
return JsonRepairResult(success=True, content=repaired_json)
except JsonSchemaValidationError as e:
if strict:
raise ValidationError(
Expand All @@ -57,7 +66,7 @@
},
) from e
logger.warning("JSON schema validation failed: %s", e)
return None
return JsonRepairResult(success=False, content=repaired_json)
except (ValueError, TypeError) as e:
if strict:
raise JSONParsingError(
Expand All @@ -67,10 +76,10 @@
"error_message": str(e),
},
) from e
logger.warning(f"Failed to repair or validate JSON: {e}")
return None
logger.warning("Failed to repair or validate JSON: %s", e)
return JsonRepairResult(success=False, content=None)

Check warning on line 80 in src/core/services/json_repair_service.py

View check run for this annotation

Codecov / codecov/patch

src/core/services/json_repair_service.py#L79-L80

Added lines #L79 - L80 were not covered by tests

def repair_json(self, json_string: str) -> dict[str, Any]:
def repair_json(self, json_string: str) -> Any:
"""
Repairs a JSON string.

Expand All @@ -81,7 +90,7 @@
The repaired JSON object.
"""
repaired_string = repair_json(json_string)
return cast(dict[str, Any], json.loads(repaired_string))
return json.loads(repaired_string)

def validate_json(
self, json_object: dict[str, Any], schema: dict[str, Any]
Expand Down
27 changes: 12 additions & 15 deletions src/core/services/streaming/json_repair_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
IStreamProcessor,
StreamingContent,
)
from src.core.services.json_repair_service import JsonRepairService
from src.core.services.json_repair_service import JsonRepairResult, JsonRepairService

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -67,9 +67,9 @@ async def process(self, content: StreamingContent) -> StreamingContent:
else:
i = self._process_json_character(text, i)
if self._is_json_complete():
repaired_json, success = self._handle_json_completion()
if success:
out_parts.append(json.dumps(repaired_json))
repair_result = self._handle_json_completion()
if repair_result.success:
out_parts.append(json.dumps(repair_result.content))
else:
out_parts.append(self._buffer)
self._reset_state()
Expand Down Expand Up @@ -151,26 +151,23 @@ def _is_current_quote_escaped(self) -> bool:
def _is_json_complete(self) -> bool:
return self._json_started and self._brace_level == 0 and not self._in_string

def _handle_json_completion(self) -> tuple[Any, bool]:
repaired = None
success = False
def _handle_json_completion(self) -> JsonRepairResult:
try:
repaired = self._service.repair_and_validate_json(
result = self._service.repair_and_validate_json(
self._buffer,
schema=self._schema,
strict=self._strict_mode,
)
if repaired is not None:
success = True
except Exception as e: # pragma: no cover - strict mode rethrow
if self._strict_mode:
raise JSONParsingError(
message=f"JSON repair failed in strict mode: {e}",
details={"original_buffer": self._buffer},
) from e
logger.warning("JSON repair raised error: %s", e)
return JsonRepairResult(success=False, content=None)

if repaired is not None:
if result.success:
metrics.inc(
"json_repair.streaming.strict_success"
if self._strict_mode
Expand All @@ -185,7 +182,7 @@ def _handle_json_completion(self) -> tuple[Any, bool]:
logger.warning(
"JSON block detected but failed to repair. Flushing raw buffer."
)
return repaired, success
return result

def _log_buffer_capacity_warning(self) -> None:
if self._json_started and len(self._buffer) > self._buffer_cap_bytes:
Expand All @@ -199,16 +196,16 @@ def _flush_final_buffer(self) -> str | None:
if not self._in_string and buf.rstrip().endswith(":"):
buf = buf + " null"
self._buffer = buf
repaired_final = self._service.repair_and_validate_json(
repair_result = self._service.repair_and_validate_json(
buf, schema=self._schema, strict=self._strict_mode
)
if repaired_final is not None:
if repair_result.success:
metrics.inc(
"json_repair.streaming.strict_success"
if self._strict_mode
else "json_repair.streaming.best_effort_success"
)
return json.dumps(repaired_final)
return json.dumps(repair_result.content)
else:
metrics.inc(
"json_repair.streaming.strict_fail"
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/core/services/test_json_repair_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ async def test_process_response_empty_object(
assert processed_response.metadata.get("repaired") is True


async def test_process_response_null_payload(
json_repair_middleware: JsonRepairMiddleware,
) -> None:
response = ProcessedResponse(content="null")
processed_response = await json_repair_middleware.process(
response, "session_id", {}
)

assert processed_response.content == "null"
assert processed_response.metadata.get("repaired") is True


async def test_process_response_best_effort_failure_metrics(
json_repair_middleware: JsonRepairMiddleware,
monkeypatch: pytest.MonkeyPatch,
Expand Down
14 changes: 12 additions & 2 deletions tests/unit/core/services/test_json_repair_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,21 @@ def test_repair_and_validate_json_schema_failure_best_effort(
"required": ["a"],
}

repaired = json_repair_service.repair_and_validate_json(
result = json_repair_service.repair_and_validate_json(
'{"a": "text"}', schema=schema, strict=False
)

assert repaired is None
assert result.success is False
assert result.content == {"a": "text"}


def test_repair_and_validate_json_allows_null_payload(
json_repair_service: JsonRepairService,
) -> None:
result = json_repair_service.repair_and_validate_json("null")

assert result.success is True
assert result.content is None


def test_repair_and_validate_json_schema_failure_strict(
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/json_repair_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from src.core.domain.streaming_content import StreamingContent
from src.core.services.json_repair_service import JsonRepairService
from src.core.services.json_repair_service import JsonRepairResult, JsonRepairService
from src.core.services.streaming.json_repair_processor import JsonRepairProcessor


Expand All @@ -16,11 +16,11 @@ def repair_and_validate_json(
json_string: str,
schema: dict[str, Any] | None = None,
strict: bool = False,
) -> dict[str, Any] | None:
return None
) -> JsonRepairResult:
return JsonRepairResult(success=False, content=None)


def test_json_repair_processor_flushes_raw_buffer_when_repair_returns_none() -> None:
def test_json_repair_processor_flushes_raw_buffer_when_repair_fails() -> None:
processor = JsonRepairProcessor(
repair_service=FailingJsonRepairService(),
buffer_cap_bytes=1024,
Expand Down