Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

fix: open interpreter not properly working with ollama #845

Merged
merged 9 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from codegate.db.models import Alert, Output, Prompt
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.utils.utils import get_tool_name_from_messages

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -260,14 +261,20 @@ def get_last_user_message_block(
messages = request["messages"]
block_start_index = None

base_tool = get_tool_name_from_messages(request)
accepted_roles = ["user", "assistant"]
if base_tool == "open interpreter":
# open interpreter also uses the role "tool"
accepted_roles.append("tool")

# Iterate in reverse to find the last block of consecutive 'user' messages
for i in reversed(range(len(messages))):
if messages[i]["role"] == "user" or messages[i]["role"] == "assistant":
if messages[i]["role"] in accepted_roles:
content_str = messages[i].get("content")
if content_str is None:
continue

if messages[i]["role"] == "user":
if messages[i]["role"] in ["user", "tool"]:
user_messages.append(content_str)
block_start_index = i

Expand Down
57 changes: 41 additions & 16 deletions src/codegate/pipeline/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import shlex
from typing import Optional

from litellm import ChatCompletionRequest

Expand Down Expand Up @@ -46,6 +47,42 @@ async def codegate_cli(command):
return await out_func(command[1:])


def _get_cli_from_cline(
    codegate_regex: re.Pattern[str], last_user_message_str: str
) -> Optional[re.Match[str]]:
    """
    Look for a codegate command in a message that may be wrapped in XML-like tags.

    When the message contains a ``<task>...</task>`` or ``<feedback>...</feedback>``
    pair, only the text inside the first such pair is inspected; otherwise the
    whole message is used. Any remaining tags are removed before matching.

    Args:
        codegate_regex: Compiled pattern applied at the start of the cleaned text.
        last_user_message_str: Raw text of the last user message.

    Returns:
        The match object produced by ``codegate_regex.match`` on the cleaned
        text, or None when the pattern does not match.
    """
    # The back-reference \1 forces the closing tag to pair with the opening one.
    tagged = re.search(r"<(task|feedback)>(.*?)</\1>", last_user_message_str, re.DOTALL)

    # Prefer the tag contents; fall back to the full message when no pair exists.
    candidate = tagged.group(2) if tagged else last_user_message_str

    # Drop every other tag, then trim so the pattern sees the first real word.
    without_tags = re.sub(r"<[^>]+>", "", candidate.strip())
    return codegate_regex.match(without_tags.strip())


def _get_cli_from_open_interpreter(last_user_message_str: str) -> Optional[re.Match[str]]:
    """
    Look for a codegate command in the final ``### User:`` section of a message.

    The message is scanned for sections introduced by ``### User:``; each section
    runs until the next line starting with ``###`` or the end of the text. Only
    the last section is checked, and it must consist of the word ``codegate``
    (case-insensitive) optionally followed by arguments on the same line.

    Args:
        last_user_message_str: Raw text of the last user message.

    Returns:
        A match whose group 1 holds the trimmed argument text (possibly empty),
        or None when there is no ``### User:`` section or its last occurrence
        is not a codegate command.
    """
    sections = re.findall(r"### User:\s*(.*?)(?=\n###|\Z)", last_user_message_str, re.DOTALL)
    if not sections:
        return None

    # Earlier sections are ignored; only the most recent one is considered.
    final_section = sections[-1].strip()
    return re.match(r"^codegate\s*(.*?)\s*$", final_section, re.IGNORECASE)


class CodegateCli(PipelineStep):
"""Pipeline step that handles codegate cli."""

Expand Down Expand Up @@ -83,25 +120,13 @@ async def process(
codegate_regex = re.compile(r"^codegate(?:\s+(.*))?", re.IGNORECASE)

if base_tool and base_tool in ["cline", "kodu"]:
# Check if there are <task> or <feedback> tags
tag_match = re.search(
r"<(task|feedback)>(.*?)</\1>", last_user_message_str, re.DOTALL
)
if tag_match:
# Extract the content between the tags
stripped_message = tag_match.group(2).strip()
else:
# If no <task> or <feedback> tags, use the entire message
stripped_message = last_user_message_str.strip()

# Remove all other XML tags and trim whitespace
stripped_message = re.sub(r"<[^>]+>", "", stripped_message).strip()

# Check if "codegate" is the first word
match = codegate_regex.match(stripped_message)
match = _get_cli_from_cline(codegate_regex, last_user_message_str)
elif base_tool == "open interpreter":
match = _get_cli_from_open_interpreter(last_user_message_str)
else:
# Check if "codegate" is the first word in the message
match = codegate_regex.match(last_user_message_str)

if match:
command = match.group(1) or ""
command = command.strip()
Expand Down
67 changes: 38 additions & 29 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def generate_context_str(self, objects: list[object], context: PipelineContext)
)
return context_str

async def process(
async def process( # noqa: C901
self, request: ChatCompletionRequest, context: PipelineContext
) -> PipelineResult:
"""
Expand Down Expand Up @@ -100,9 +100,9 @@ async def process(
)

# split messages into double newlines, to avoid passing so many content in the search
split_messages = re.split(r"</?task>|(\n\n)", user_messages)
split_messages = re.split(r"</?task>|\n|\\n", user_messages)
collected_bad_packages = []
for item_message in split_messages:
for item_message in filter(None, map(str.strip, split_messages)):
# Vector search to find bad packages
bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100)
if bad_packages and len(bad_packages) > 0:
Expand All @@ -128,30 +128,39 @@ async def process(
new_request = request.copy()

# perform replacement in all the messages starting from this index
for i in range(last_user_idx, len(new_request["messages"])):
message = new_request["messages"][i]
message_str = str(message["content"]) # type: ignore
context_msg = message_str
# Add the context to the last user message
base_tool = get_tool_name_from_messages(request)
if base_tool in ["cline", "kodu"]:
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
if match:
task_content = match.group(1) # Content within <task>...</task>
rest_of_message = match.group(2).strip() # Content after </task>, if any

# Embed the context into the task block
updated_task_content = (
f"<task>Context: {context_str}"
+ f"Query: {task_content.strip()}</task>"
)

# Combine updated task content with the rest of the message
context_msg = updated_task_content + rest_of_message

else:
context_msg = f"Context: {context_str} \n\n Query: {message_str}" # type: ignore

new_request["messages"][i]["content"] = context_msg
logger.debug("Final context message", context_message=context_msg)
base_tool = get_tool_name_from_messages(request)
if base_tool != "open interpreter":
for i in range(last_user_idx, len(new_request["messages"])):
message = new_request["messages"][i]
message_str = str(message["content"]) # type: ignore
context_msg = message_str
# Add the context to the last user message
if base_tool in ["cline", "kodu"]:
match = re.search(r"<task>\s*(.*?)\s*</task>(.*)", message_str, re.DOTALL)
if match:
task_content = match.group(1) # Content within <task>...</task>
rest_of_message = match.group(
2
).strip() # Content after </task>, if any

# Embed the context into the task block
updated_task_content = (
f"<task>Context: {context_str}"
+ f"Query: {task_content.strip()}</task>"
)

# Combine updated task content with the rest of the message
context_msg = updated_task_content + rest_of_message
else:
context_msg = f"Context: {context_str} \n\n Query: {message_str}"
new_request["messages"][i]["content"] = context_msg
logger.debug("Final context message", context_message=context_msg)
else:
#  just add a message in the end
new_request["messages"].append(
{
"content": context_str,
"role": "assistant",
}
)
return PipelineResult(request=new_request, context=context)
7 changes: 7 additions & 0 deletions src/codegate/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ def _is_fim_request(self, request: Request, data: Dict) -> bool:
"""
Determine if the request is FIM by the URL or the data of the request.
"""
# first check if we are in specific tools to discard FIM
prompt = data.get("prompt", "")
tools = ["cline", "kodu", "open interpreter"]
for tool in tools:
if tool in prompt.lower():
# those tools can never be FIM
return False
# Avoid more expensive inspection of body by just checking the URL.
if self._is_fim_request_url(request):
return True
Expand Down
1 change: 0 additions & 1 deletion src/codegate/providers/ollama/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
# with a direct HTTP request. Since Ollama is local this is not critical.
# if normalized_data.get("stream", False):
# normalized_data["stream_options"] = {"include_usage": True}

return ChatCompletionRequest(**normalized_data)

def denormalize(self, data: ChatCompletionRequest) -> Dict:
Expand Down
14 changes: 10 additions & 4 deletions src/codegate/providers/ollama/completion_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
logger = structlog.get_logger("codegate")


async def ollama_stream_generator(
async def ollama_stream_generator( # noqa: C901
stream: AsyncIterator[ChatResponse], base_tool: str
) -> AsyncIterator[str]:
"""OpenAI-style SSE format"""
Expand All @@ -21,9 +21,7 @@ async def ollama_stream_generator(
# TODO We should wire in the client info so we can respond with
# the correct format and start to handle multiple clients
# in a more robust way.
if base_tool != "cline":
yield f"{chunk.model_dump_json()}\n"
else:
if base_tool in ["cline", "kodu"]:
# First get the raw dict from the chunk
chunk_dict = chunk.model_dump()
# Create response dictionary in OpenAI-like format
Expand Down Expand Up @@ -64,6 +62,14 @@ async def ollama_stream_generator(
response[field] = chunk_dict[field]

yield f"\ndata: {json.dumps(response)}\n"
else:
# if we do not have response, we set it
chunk_dict = chunk.model_dump()
if "response" not in chunk_dict:
chunk_dict["response"] = chunk_dict.get("message", {}).get("content", "\n")
if not chunk_dict["response"]:
chunk_dict["response"] = "\n"
yield f"{json.dumps(chunk_dict)}\n"
except Exception as e:
logger.error(f"Error in stream generator: {str(e)}")
yield f"\ndata: {json.dumps({'error': str(e), 'type': 'error', 'choices': []})}\n"
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_tool_name_from_messages(data):
Returns:
str: The name of the tool found in the messages, or None if no match is found.
"""
tools = ["Cline", "Kodu"]
tools = ["Cline", "Kodu", "Open Interpreter", "Aider"]
for message in data.get("messages", []):
message_content = str(message.get("content", ""))
for tool in tools:
Expand Down
Loading