Skip to content

fix: use the latest user messages block instead of single message #585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data/archived.jsonl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{"name":"@prefix/archived-npm-dummy","type":"npm","description":"Dummy archived to test with encoded package name on npm"}
{"name":"archived-npm-dummy","type":"npm","description":"Dummy archived to test with simple package name on npm"}
{"name":"@prefix/archived-pypi-dummy","type":"pypi","description":"Dummy archived to test with encoded package name on pypi"}
{"name":"archived-pypi-dummy","type":"pypi","description":"Dummy archived to test with simple package name on pypi"}
{"name":"archived_pypi_dummy","type":"pypi","description":"Dummy archived to test with simple package name on pypi"}
{"name":"@prefix/archived-maven-dummy","type":"maven","description":"Dummy archived to test with encoded package name on maven"}
{"name":"archived-maven-dummy","type":"maven","description":"Dummy archived to test with simple package name on maven"}
{"name":"github.com/archived-go-dummy","type":"npm","description":"Dummy archived to test with encoded package name on go"}
Expand Down
2 changes: 1 addition & 1 deletion data/deprecated.jsonl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{"name":"@prefix/deprecated-npm-dummy","type":"npm","description":"Dummy deprecated to test with encoded package name on npm"}
{"name":"deprecated-npm-dummy","type":"npm","description":"Dummy deprecated to test with simple package name on npm"}
{"name":"@prefix/deprecated-pypi-dummy","type":"pypi","description":"Dummy deprecated to test with encoded package name on pypi"}
{"name":"deprecated-pypi-dummy","type":"pypi","description":"Dummy deprecated to test with simple package name on pypi"}
{"name":"deprecated_pypi_dummy","type":"pypi","description":"Dummy deprecated to test with simple package name on pypi"}
{"name":"@prefix/deprecated-maven-dummy","type":"maven","description":"Dummy deprecated to test with encoded package name on maven"}
{"name":"deprecated-maven-dummy","type":"maven","description":"Dummy deprecated to test with simple package name on maven"}
{"name":"github.com/deprecated-go-dummy","type":"npm","description":"Dummy deprecated to test with encoded package name on go"}
Expand Down
2 changes: 1 addition & 1 deletion data/malicious.jsonl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{"name":"@prefix/malicious-npm-dummy","type":"npm","description":"Dummy malicious to test with encoded package name on npm"}
{"name":"malicious-npm-dummy","type":"npm","description":"Dummy malicious to test with simple package name on npm"}
{"name":"@prefix/malicious-pypi-dummy","type":"pypi","description":"Dummy malicious to test with encoded package name on pypi"}
{"name":"malicious-pypi-dummy","type":"pypi","description":"Dummy malicious to test with simple package name on pypi"}
{"name":"malicious_pypi_dummy","type":"pypi","description":"Dummy malicious to test with simple package name on pypi"}
{"name":"@prefix/malicious-maven-dummy","type":"maven","description":"Dummy malicious to test with encoded package name on maven"}
{"name":"malicious-maven-dummy","type":"maven","description":"Dummy malicious to test with simple package name on maven"}
{"name":"github.com/malicious-go-dummy","type":"go","description":"Dummy malicious to test with encoded package name on go"}
Expand Down
110 changes: 54 additions & 56 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tree-sitter-python = ">=0.23.6"
tree-sitter-rust = ">=0.23.2"
sqlite-vec-sl-tmp = "^0.0.4"
alembic = ">=1.14.0"
pygments = "^2.19.1"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.4.0"
Expand Down
50 changes: 48 additions & 2 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,54 @@ def get_last_user_message(
return None
for i in reversed(range(len(request["messages"]))):
if request["messages"][i]["role"] == "user":
content = request["messages"][i]["content"]
return content, i
content = request["messages"][i]["content"] # type: ignore
return str(content), i

return None

@staticmethod
def get_last_user_message_block(
request: ChatCompletionRequest,
) -> Optional[str]:
"""
Get the last block of consecutive 'user' messages from the request.

Args:
request (ChatCompletionRequest): The chat completion request to process

Returns:
Optional[str]: A string containing all consecutive user messages in the
last user message block, separated by newlines, or None if
no user message block is found.
"""
if request.get("messages") is None:
return None

user_messages = []
messages = request["messages"]

# Iterate in reverse to find the last block of consecutive 'user' messages
for i in reversed(range(len(messages))):
if messages[i]["role"] == "user" or messages[i]["role"] == "assistant":
content_str = None
if "content" in messages[i]:
content_str = messages[i]["content"] # type: ignore
else:
continue

if messages[i]["role"] == "user":
user_messages.append(content_str)
# specifically for Aider, when "ok." block is found, stop
if content_str == "Ok." and messages[i]["role"] == "assistant":
break
else:
# Stop when a message with a different role is encountered
if user_messages:
break

# Reverse the collected user messages to preserve the original order
if user_messages:
return "\n".join(reversed(user_messages))

return None

Expand Down
41 changes: 25 additions & 16 deletions src/codegate/pipeline/codegate_context_retriever/codegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,44 +59,53 @@ async def process(
"""
Use RAG DB to add context to the user request
"""
# Get the latest user messages
user_messages = self.get_latest_user_messages(request)

# Nothing to do if the user_messages string is empty
if len(user_messages) == 0:
# Get the latest user message
user_message = self.get_last_user_message_block(request)
if not user_message:
return PipelineResult(request=request)

# Create storage engine object
storage_engine = StorageEngine()

# Extract any code snippets
snippets = extract_snippets(user_messages)
snippets = extract_snippets(user_message)

bad_snippet_packages = []
if len(snippets) > 0:
snippet_language = snippets[0].language
# Collect all packages referenced in the snippets
snippet_packages = []
for snippet in snippets:
snippet_packages.extend(
PackageExtractor.extract_packages(snippet.code, snippet.language)
PackageExtractor.extract_packages(snippet.code, snippet.language) # type: ignore
)
logger.info(f"Found {len(snippet_packages)} packages in code snippets.")

logger.info(
f"Found {len(snippet_packages)} packages "
f"for language {snippet_language} in code snippets."
)
# Find bad packages in the snippets
bad_snippet_packages = await storage_engine.search(
language=snippets[0].language, packages=snippet_packages
)
language=snippet_language, packages=snippet_packages
) # type: ignore
logger.info(f"Found {len(bad_snippet_packages)} bad packages in code snippets.")

# Remove code snippets from the user messages and search for bad packages
# in the rest of the user query/messsages
user_messages = re.sub(r"```.*?```", "", user_messages, flags=re.DOTALL)

# Vector search to find bad packages
bad_packages = await storage_engine.search(query=user_messages, distance=0.5, limit=100)
user_messages = re.sub(r"```.*?```", "", user_message, flags=re.DOTALL)
user_messages = re.sub(r"⋮...*?⋮...\n\n", "", user_messages, flags=re.DOTALL)

# split messages into double newlines, to avoid passing so many content in the search
split_messages = user_messages.split("\n\n")
collected_bad_packages = []
for item_message in split_messages:
# Vector search to find bad packages
bad_packages = await storage_engine.search(query=item_message, distance=0.5, limit=100)
if bad_packages and len(bad_packages) > 0:
collected_bad_packages.extend(bad_packages)

# All bad packages
all_bad_packages = bad_snippet_packages + bad_packages
all_bad_packages = bad_snippet_packages + collected_bad_packages

logger.info(f"Adding {len(all_bad_packages)} bad packages to the context.")

Expand All @@ -119,7 +128,7 @@ async def process(
# Add the context to the last user message
# Format: "Context: {context_str} \n Query: {last user message content}"
message = new_request["messages"][last_user_idx]
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}'
context_msg = f'Context: {context_str} \n\n Query: {message["content"]}' # type: ignore
message["content"] = context_msg

logger.debug("Final context message", context_message=context_msg)
Expand Down
17 changes: 14 additions & 3 deletions src/codegate/pipeline/extract_snippets/extract_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import structlog
from litellm.types.llms.openai import ChatCompletionRequest
from pygments.lexers import guess_lexer

from codegate.pipeline.base import CodeSnippet, PipelineContext, PipelineResult, PipelineStep

Expand Down Expand Up @@ -65,6 +66,8 @@ def ecosystem_from_message(message: str) -> Optional[str]:
"ts": "typescript",
"tsx": "typescript",
"go": "go",
"rs": "rust",
"java": "java",
}
return language_mapping.get(message, None)

Expand All @@ -82,6 +85,7 @@ def extract_snippets(message: str) -> List[CodeSnippet]:
# Regular expression to find code blocks

snippets: List[CodeSnippet] = []
available_languages = ["python", "javascript", "typescript", "go", "rust", "java"]

# Find all code block matches
for match in CODE_BLOCK_PATTERN.finditer(message):
Expand All @@ -105,6 +109,14 @@ def extract_snippets(message: str) -> List[CodeSnippet]:
filename = filename.strip()
# Determine language from the filename
lang = ecosystem_from_filepath(filename)
if lang is None:
# try to guess it from the code
lexer = guess_lexer(content)
if lexer and lexer.name:
lang = lexer.name.lower()
# only add available languages
if lang not in available_languages:
lang = None

snippets.append(CodeSnippet(filepath=filename, code=content, language=lang))

Expand All @@ -129,10 +141,9 @@ async def process(
request: ChatCompletionRequest,
context: PipelineContext,
) -> PipelineResult:
last_user_message = self.get_last_user_message(request)
if not last_user_message:
msg_content = self.get_last_user_message_block(request)
if not msg_content:
return PipelineResult(request=request, context=context)
msg_content, _ = last_user_message
snippets = extract_snippets(msg_content)

logger.info(f"Extracted {len(snippets)} code snippets from the user message")
Expand Down
Loading
Loading