This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Fix FIM for OpenRouter #1097

Merged · 2 commits · Feb 19, 2025
src/codegate/providers/openai/provider.py (2 additions & 1 deletion)

```diff
@@ -18,8 +18,9 @@ class OpenAIProvider(BaseProvider):
     def __init__(
         self,
         pipeline_factory: PipelineFactory,
+        # Enable receiving other completion handlers from children, i.e. OpenRouter and LM Studio
+        completion_handler: LiteLLmShim = LiteLLmShim(stream_generator=sse_stream_generator),
     ):
-        completion_handler = LiteLLmShim(stream_generator=sse_stream_generator)
         super().__init__(
             OpenAIInputNormalizer(),
             OpenAIOutputNormalizer(),
```
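The deleted line hard-coded the handler inside the constructor; turning it into a keyword argument with the old value as its default keeps existing behavior while letting subclasses inject their own shim. A minimal sketch of how a child provider might use the new parameter (the `LMStudioProvider` class is hypothetical, shown only to illustrate the injection point):

```python
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
from codegate.providers.openai import OpenAIProvider


class LMStudioProvider(OpenAIProvider):
    """Hypothetical child provider that injects its own completion handler."""

    def __init__(self, pipeline_factory: PipelineFactory):
        super().__init__(
            pipeline_factory,
            # Any compatible LiteLLmShim can be passed here; omitting the
            # argument falls back to the default SSE-streaming shim.
            completion_handler=LiteLLmShim(stream_generator=sse_stream_generator),
        )
```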
src/codegate/providers/openrouter/provider.py (37 additions & 5 deletions)

```diff
@@ -2,12 +2,14 @@
 from typing import Dict

 from fastapi import Header, HTTPException, Request
+from litellm import atext_completion
 from litellm.types.llms.openai import ChatCompletionRequest

 from codegate.clients.clients import ClientType
 from codegate.clients.detector import DetectClient
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.fim_analyzer import FIMAnalyzer
+from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
 from codegate.providers.normalizer.completion import CompletionNormalizer
 from codegate.providers.openai import OpenAIProvider

```
```diff
@@ -20,15 +22,45 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
         return super().normalize(data)

     def denormalize(self, data: ChatCompletionRequest) -> Dict:
-        if data.get("had_prompt_before", False):
-            del data["had_prompt_before"]
-
-        return data
+        """
+        Denormalize a FIM OpenRouter request into the format accepted by atext_completion.
+        """
+        denormalized_data = super().denormalize(data)
+        # We are forcing atext_completion, which expects a "prompt" key in the data,
+        # so build one in case it is not present.
+        if "prompt" in data:
+            return denormalized_data
+        custom_prompt = ""
+        for msg_dict in denormalized_data.get("messages", []):
+            content_obj = msg_dict.get("content")
+            if not content_obj:
+                continue
+            if isinstance(content_obj, list):
+                for content_dict in content_obj:
+                    custom_prompt += (
+                        content_dict.get("text", "") if isinstance(content_dict, dict) else ""
+                    )
+            elif isinstance(content_obj, str):
+                custom_prompt += content_obj
+
+        # Erase the original "messages" key and replace it with "prompt".
+        del denormalized_data["messages"]
+        denormalized_data["prompt"] = custom_prompt
+
+        return denormalized_data


 class OpenRouterProvider(OpenAIProvider):
     def __init__(self, pipeline_factory: PipelineFactory):
-        super().__init__(pipeline_factory)
+        super().__init__(
+            pipeline_factory,
+            # We get FIM requests at /completions, but LiteLLM forces /chat/completions,
+            # which returns "choices":[{"delta":{"content":"some text"}}]
+            # instead of the "choices":[{"text":"some text"}] expected by the client (Continue).
+            completion_handler=LiteLLmShim(
+                stream_generator=sse_stream_generator, fim_completion_func=atext_completion
+            ),
+        )
         self._fim_normalizer = OpenRouterNormalizer()

     @property
```
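To see what the new `denormalize` does to a FIM request, its message-flattening logic can be exercised in isolation. A sketch under the assumption that message content arrives either as a plain string or as an OpenAI-style list of content parts (the `flatten_messages` helper is illustrative, not part of the codebase):

```python
from typing import Dict, List


def flatten_messages(messages: List[Dict]) -> str:
    """Concatenate message contents into one prompt string, as denormalize does."""
    prompt = ""
    for msg in messages:
        content = msg.get("content")
        if not content:
            continue
        if isinstance(content, list):
            # OpenAI-style content parts: keep only the "text" fields.
            for part in content:
                prompt += part.get("text", "") if isinstance(part, dict) else ""
        elif isinstance(content, str):
            prompt += content
    return prompt


# Both content shapes collapse into a single "prompt" for atext_completion.
messages = [
    {"role": "user", "content": "def add(a, b):\n"},
    {"role": "user", "content": [{"type": "text", "text": "    return a + b"}]},
]
assert flatten_messages(messages) == "def add(a, b):\n    return a + b"
```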
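The comment in `__init__` captures the heart of the fix: a /chat/completions stream carries incremental deltas, while the /completions stream produced via `atext_completion` carries plain `text` chunks, which is the shape FIM clients such as Continue parse. A sketch contrasting the two chunk payloads (the values are illustrative):

```python
# Chat-completions stream chunk (LiteLLM's default /chat/completions route):
chat_chunk = {"choices": [{"delta": {"content": "some text"}}]}

# Text-completions stream chunk (atext_completion), the shape Continue expects for FIM:
text_chunk = {"choices": [{"text": "some text"}]}

# A FIM client reads choices[0]["text"]; the chat shape has no such key,
# so completions would render as empty before this fix.
assert text_chunk["choices"][0]["text"] == "some text"
assert "text" not in chat_chunk["choices"][0]
```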