This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Optionally dump requests in providers #847

Merged 1 commit on Jan 30, 2025
src/codegate/providers/base.py (37 additions, 0 deletions)
@@ -1,11 +1,16 @@
import datetime
import os
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union

import structlog
from fastapi import APIRouter, Request
from litellm import ModelResponse
from litellm.types.llms.openai import ChatCompletionRequest

from codegate.codegate_logging import setup_logging
from codegate.db.connection import DbRecorder
from codegate.pipeline.base import (
    PipelineContext,
@@ -19,8 +24,14 @@
from codegate.providers.normalizer.completion import CompletionNormalizer
from codegate.utils.utils import get_tool_name_from_messages

setup_logging()
logger = structlog.get_logger("codegate")

TEMPDIR = None
if os.getenv("CODEGATE_DUMP_DIR"):
    basedir = os.getenv("CODEGATE_DUMP_DIR")
    # delete=False keeps the directory after the process exits so the
    # dumps remain available for inspection (requires Python 3.12+).
    TEMPDIR = tempfile.TemporaryDirectory(prefix="codegate-", dir=basedir, delete=False)

StreamGenerator = Callable[[AsyncIterator[Any]], AsyncIterator[str]]


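A note on the TEMPDIR block above: tempfile.TemporaryDirectory only accepts the delete= keyword on Python 3.12 and later, and delete=False means the directory is never removed automatically, which is what lets the dump files outlive the process. A minimal standalone sketch of that behavior (not from this PR, assuming a hypothetical /tmp base directory):

import tempfile
from pathlib import Path

# delete=False (Python 3.12+): the directory and its contents persist
# after the TemporaryDirectory object is garbage-collected or the
# interpreter exits, so the dumps can be inspected later.
dump_root = tempfile.TemporaryDirectory(prefix="codegate-", dir="/tmp", delete=False)
(Path(dump_root.name) / "example.json").write_text("{}")
print(dump_root.name)  # e.g. /tmp/codegate-XXXXXXXX, still present afterwards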
@@ -205,6 +216,26 @@ async def _cleanup_after_streaming(
        if context.sensitive:
            context.sensitive.secure_cleanup()

    def _dump_request_response(self, prefix: str, data: Any) -> None:
        """Dump request or response data to a file if CODEGATE_DUMP_DIR is set."""
        if not TEMPDIR:
            return

        ts = datetime.datetime.now()
        fname = (
            Path(TEMPDIR.name)
            / f"{prefix}-{self.provider_route_name}-{ts.strftime('%Y%m%dT%H%M%S%f')}.json"
        )

        if isinstance(data, (dict, list)):
            import json

            with open(fname, "w") as f:
                json.dump(data, f, indent=2)
        else:
            # Anything that is not a plain dict/list is written as str().
            with open(fname, "w") as f:
                f.write(str(data))

    async def complete(
        self, data: Dict, api_key: Optional[str], is_fim_request: bool
    ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]:
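
For reference, the filenames produced by _dump_request_response encode the dump stage, the provider route name, and a microsecond-resolution timestamp. A hypothetical illustration of the scheme (provider route "openai" and all values invented for the example):

import datetime
from pathlib import Path

# Hypothetical values: provider route "openai", dump taken Jan 30, 2025.
ts = datetime.datetime(2025, 1, 30, 12, 0, 0, 123456)
fname = Path("/tmp/codegate-abc123") / f"request-openai-{ts.strftime('%Y%m%dT%H%M%S%f')}.json"
print(fname.name)  # request-openai-20250130T120000123456.json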
@@ -219,7 +250,11 @@ async def complete(
        - Execute the completion and translate the response back to the
          provider-specific format
        """
        # Dump the incoming request
        self._dump_request_response("request", data)
        normalized_request = self._input_normalizer.normalize(data)
        # Dump the normalized request
        self._dump_request_response("normalized-request", normalized_request)
        streaming = normalized_request.get("stream", False)
        input_pipeline_result = await self._run_input_pipeline(
            normalized_request,
@@ -237,6 +272,8 @@
        if is_fim_request:
            provider_request = self._fim_normalizer.denormalize(provider_request)  # type: ignore

        self._dump_request_response("provider-request", provider_request)

        # Execute the completion and translate the response
        # This gives us either a single response or a stream of responses
        # based on the streaming flag
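
Taken together, complete() now dumps at three stages: the raw incoming request, the normalized request, and the final provider request, so the transformations between them can be diffed offline. A hypothetical inspection snippet, assuming you have located the codegate- directory under your CODEGATE_DUMP_DIR:

import json
from pathlib import Path

# Hypothetical path; look for the codegate-* directory under CODEGATE_DUMP_DIR.
dump_dir = Path("/tmp/codegate-abc123")
for path in sorted(dump_dir.glob("*.json")):
    try:
        payload = json.loads(path.read_text())
        print(path.name, "->", type(payload).__name__)
    except json.JSONDecodeError:
        # Non-dict/list payloads are written with str(), not json.dump().
        print(path.name, "-> plain text")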