Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Add client detection #832

Merged
merged 5 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/codegate/clients/clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from enum import Enum


class ClientType(Enum):
"""
Enum of supported client types
"""

GENERIC = "generic" # Default client type when no specific client is detected
CLINE = "cline" # Cline client
KODU = "kodu" # Kodu client
COPILOT = "copilot" # Copilot client
OPEN_INTERPRETER = "open_interpreter" # Open Interpreter client
224 changes: 224 additions & 0 deletions src/codegate/clients/detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import re
from abc import ABC, abstractmethod
from functools import wraps
from typing import List, Optional

import structlog
from fastapi import Request

from codegate.clients.clients import ClientType

logger = structlog.get_logger("codegate")


class HeaderDetector:
"""
Base utility class for header-based detection
"""

def __init__(self, header_name: str, header_value: Optional[str] = None):
self.header_name = header_name
self.header_value = header_value

def detect(self, request: Request) -> bool:
logger.debug(
"checking header detection",
header_name=self.header_name,
header_value=self.header_value,
request_headers=dict(request.headers),
)
# Check if the header is present, if not we didn't detect the client
if self.header_name not in request.headers:
return False
# now we know that the header is present, if we don't care about the value
# we detected the client
if self.header_value is None:
return True
# finally, if we care about the value, we need to check if it matches
return request.headers[self.header_name] == self.header_value


class UserAgentDetector(HeaderDetector):
"""
A variant of the HeaderDetector that specifically looks for a user-agent pattern
"""

def __init__(self, user_agent_pattern: str):
super().__init__("user-agent")
self.pattern = re.compile(user_agent_pattern, re.IGNORECASE)

def detect(self, request: Request) -> bool:
user_agent = request.headers.get(self.header_name)
if not user_agent:
return False
return bool(self.pattern.search(user_agent))


class ContentDetector:
"""
Detector for message content patterns
"""

def __init__(self, pattern: str):
self.pattern = pattern

async def detect(self, request: Request) -> bool:
try:
data = await request.json()
for message in data.get("messages", []):
message_content = str(message.get("content", ""))
if self.pattern in message_content:
return True
# This is clearly a hack and won't be needed when we get rid of the normalizers and will
# be able to access the system message directly from the on-wire format
system_content = str(data.get("system", ""))
if self.pattern in system_content:
return True
return False
except Exception as e:
logger.error(f"Error in content detection: {str(e)}")
return False


class BaseClientDetector(ABC):
"""
Base class for all client detectors using composition of detection methods
"""

def __init__(self):
self.header_detector: Optional[HeaderDetector] = None
self.user_agent_detector: Optional[UserAgentDetector] = None
self.content_detector: Optional[ContentDetector] = None

@property
@abstractmethod
def client_name(self) -> ClientType:
"""
Returns the name of the client
"""
pass

async def detect(self, request: Request) -> bool:
"""
Tries each configured detection method in sequence
"""
# Try user agent first if configured
if self.user_agent_detector and self.user_agent_detector.detect(request):
return True

# Then try header if configured
if self.header_detector and self.header_detector.detect(request):
return True

# Finally try content if configured
if self.content_detector:
return await self.content_detector.detect(request)

return False


class ClineDetector(BaseClientDetector):
"""
Detector for Cline client based on message content
"""

def __init__(self):
super().__init__()
self.content_detector = ContentDetector("Cline")

@property
def client_name(self) -> ClientType:
return ClientType.CLINE


class KoduDetector(BaseClientDetector):
"""
Detector for Kodu client based on message content
"""

def __init__(self):
super().__init__()
self.user_agent_detector = UserAgentDetector("Kodu")
self.content_detector = ContentDetector("Kodu")

@property
def client_name(self) -> ClientType:
return ClientType.KODU


class OpenInterpreter(BaseClientDetector):
"""
Detector for Kodu client based on message content
"""

def __init__(self):
super().__init__()
self.content_detector = ContentDetector("Open Interpreter")

@property
def client_name(self) -> ClientType:
return ClientType.OPEN_INTERPRETER


class CopilotDetector(HeaderDetector):
"""
Detector for Copilot client based on user agent
"""

def __init__(self):
super().__init__("user-agent", "Copilot")

@property
def client_name(self) -> ClientType:
return ClientType.COPILOT


class DetectClient:
"""
Decorator class for detecting clients from request system messages

Usage:
@app.post("/v1/chat/completions")
@DetectClient()
async def chat_completions(request: Request):
client = request.state.detected_client
"""

def __init__(self):
self.detectors: List[BaseClientDetector] = [
ClineDetector(),
KoduDetector(),
OpenInterpreter(),
CopilotDetector(),
]

def __call__(self, func):
@wraps(func)
async def wrapper(request: Request, *args, **kwargs):
try:
client = await self.detect(request)
request.state.detected_client = client
except Exception as e:
logger.error(f"Error in client detection: {str(e)}")
request.state.detected_client = ClientType.GENERIC

return await func(request, *args, **kwargs)

return wrapper

async def detect(self, request: Request) -> ClientType:
"""
Detects the client from the request by trying each detector in sequence.
Returns the name of the first detected client, or GENERIC if no specific client is detected.
"""
for detector in self.detectors:
try:
if await detector.detect(request):
client_name = detector.client_name
logger.info(f"{client_name} client detected")
return client_name
except Exception as e:
logger.error(f"Error in {detector.client_name} detection: {str(e)}")
continue
logger.info("No particilar client detected, using generic client")
return ClientType.GENERIC
46 changes: 33 additions & 13 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from litellm import ChatCompletionRequest, ModelResponse
from pydantic import BaseModel

from codegate.clients.clients import ClientType
from codegate.db.models import Alert, Output, Prompt
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.utils.utils import get_tool_name_from_messages

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -81,6 +81,7 @@ class PipelineContext:
shortcut_response: bool = False
bad_packages_found: bool = False
secrets_found: bool = False
client: ClientType = ClientType.GENERIC

def add_code_snippet(self, snippet: CodeSnippet):
self.code_snippets.append(snippet)
Expand Down Expand Up @@ -241,12 +242,14 @@ def get_last_user_message(
@staticmethod
def get_last_user_message_block(
request: ChatCompletionRequest,
client: ClientType = ClientType.GENERIC,
) -> Optional[tuple[str, int]]:
"""
Get the last block of consecutive 'user' messages from the request.

Args:
request (ChatCompletionRequest): The chat completion request to process
client (ClientType): The client type to consider when processing the request

Returns:
Optional[str, int]: A string containing all consecutive user messages in the
Expand All @@ -261,9 +264,8 @@ def get_last_user_message_block(
messages = request["messages"]
block_start_index = None

base_tool = get_tool_name_from_messages(request)
accepted_roles = ["user", "assistant"]
if base_tool == "open interpreter":
if client == ClientType.OPEN_INTERPRETER:
# open interpreter also uses the role "tool"
accepted_roles.append("tool")

Expand Down Expand Up @@ -303,12 +305,16 @@ async def process(

class InputPipelineInstance:
def __init__(
self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool
self,
pipeline_steps: List[PipelineStep],
secret_manager: SecretsManager,
is_fim: bool,
client: ClientType = ClientType.GENERIC,
):
self.pipeline_steps = pipeline_steps
self.secret_manager = secret_manager
self.is_fim = is_fim
self.context = PipelineContext()
self.context = PipelineContext(client=client)

# we create the sesitive context here so that it is not shared between individual requests
# TODO: could we get away with just generating the session ID for an instance?
Expand All @@ -326,7 +332,6 @@ async def process_request(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
is_copilot: bool = False,
) -> PipelineResult:
"""Process a request through all pipeline steps"""
self.context.metadata["extra_headers"] = extra_headers
Expand All @@ -338,7 +343,9 @@ async def process_request(
self.context.sensitive.api_base = api_base

# For Copilot provider=openai. Use a flag to not clash with other places that may use that.
provider_db = "copilot" if is_copilot else provider
provider_db = provider
if self.context.client == ClientType.COPILOT:
provider_db = "copilot"

for step in self.pipeline_steps:
result = await step.process(current_request, self.context)
Expand Down Expand Up @@ -367,16 +374,25 @@ async def process_request(

class SequentialPipelineProcessor:
def __init__(
self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager, is_fim: bool
self,
pipeline_steps: List[PipelineStep],
secret_manager: SecretsManager,
client_type: ClientType,
is_fim: bool,
):
self.pipeline_steps = pipeline_steps
self.secret_manager = secret_manager
self.is_fim = is_fim
self.instance = self._create_instance()
self.instance = self._create_instance(client_type)

def _create_instance(self) -> InputPipelineInstance:
def _create_instance(self, client_type: ClientType) -> InputPipelineInstance:
"""Create a new pipeline instance for processing a request"""
return InputPipelineInstance(self.pipeline_steps, self.secret_manager, self.is_fim)
return InputPipelineInstance(
self.pipeline_steps,
self.secret_manager,
self.is_fim,
client_type,
)

async def process_request(
self,
Expand All @@ -386,9 +402,13 @@ async def process_request(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
is_copilot: bool = False,
) -> PipelineResult:
"""Create a new pipeline instance and process the request"""
return await self.instance.process_request(
request, provider, model, api_key, api_base, extra_headers, is_copilot
request,
provider,
model,
api_key,
api_base,
extra_headers,
)
Loading