Skip to content

Commit

Permalink
Backport PR #1048: Implement streaming for /fix (#1059)
Browse files Browse the repository at this point in the history
Co-authored-by: Sanjiv Das <srdas@scu.edu>
  • Loading branch information
meeseeksmachine and srdas authored Oct 28, 2024
1 parent 5c37221 commit 117a7be
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
6 changes: 5 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ async def stream_reply(
self,
input: Input,
human_msg: HumanChatMessage,
pending_msg="Generating response",
config: Optional[RunnableConfig] = None,
):
"""
Expand All @@ -538,6 +539,9 @@ async def stream_reply(
- `config` (optional): A `RunnableConfig` object that specifies
additional configuration when streaming from the runnable.
- `pending_msg` (optional): Changes the default pending message from
"Generating response".
"""
assert self.llm_chain
assert isinstance(self.llm_chain, Runnable)
Expand All @@ -551,7 +555,7 @@ async def stream_reply(
merged_config: RunnableConfig = merge_runnable_configs(base_config, config)

# start with a pending message
with self.pending("Generating response", human_msg) as pending_message:
with self.pending(pending_msg, human_msg) as pending_message:
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
Expand Down
35 changes: 15 additions & 20 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from jupyter_ai.models import CellWithErrorSelection, HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from .base import BaseChatHandler, SlashCommandRoutingType
Expand Down Expand Up @@ -64,6 +63,7 @@ class FixChatHandler(BaseChatHandler):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_template = None

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
Expand All @@ -73,13 +73,11 @@ def create_llm_chain(
**(self.get_model_parameters(provider, provider_params)),
}
llm = provider(**unified_parameters)

self.llm = llm
# TODO: migrate this class to use a LCEL `Runnable` instead of
# `Chain`, then remove the below ignore comment.
self.llm_chain = LLMChain( # type:ignore[arg-type]
llm=llm, prompt=FIX_PROMPT_TEMPLATE, verbose=True
)
prompt_template = FIX_PROMPT_TEMPLATE

runnable = prompt_template | llm # type:ignore
self.llm_chain = runnable

async def process_message(self, message: HumanChatMessage):
if not (message.selection and message.selection.type == "cell-with-error"):
Expand All @@ -96,16 +94,13 @@ async def process_message(self, message: HumanChatMessage):
extra_instructions = message.prompt[4:].strip() or "None."

self.get_llm_chain()
with self.pending("Analyzing error", message):
assert self.llm_chain
# TODO: migrate this class to use a LCEL `Runnable` instead of
# `Chain`, then remove the below ignore comment.
response = await self.llm_chain.apredict( # type:ignore[attr-defined]
extra_instructions=extra_instructions,
stop=["\nHuman:"],
cell_content=selection.source,
error_name=selection.error.name,
error_value=selection.error.value,
traceback="\n".join(selection.error.traceback),
)
self.reply(response, message)
assert self.llm_chain

inputs = {
"extra_instructions": extra_instructions,
"cell_content": selection.source,
"traceback": "\n".join(selection.error.traceback),
"error_name": selection.error.name,
"error_value": selection.error.value,
}
await self.stream_reply(inputs, message, pending_msg="Analyzing error")

0 comments on commit 117a7be

Please sign in to comment.