Skip to content

Commit 34051ba

Browse files
authored
Improve support for custom providers (#713)
* improve support for custom providers adds 3 new provider class attributes: 1. `manages_history` 2. `unsupported_slash_commands` 3. `persona` * pre-commit * add comment about jupyternaut icon in frontend * remove 'avatar_path' from 'Persona', drop 'PersonaDescription' * pre-commit
1 parent 144cc9b commit 34051ba

File tree

13 files changed

+195
-22
lines changed

13 files changed

+195
-22
lines changed

packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from .exception import store_exception
1313
from .magics import AiMagics
1414

15+
# expose JupyternautPersona on the package root
16+
# required by `jupyter-ai`.
17+
from .models.persona import JupyternautPersona, Persona
18+
1519
# expose model providers on the package root
1620
from .providers import (
1721
AI21Provider,
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from langchain.pydantic_v1 import BaseModel
2+
3+
4+
class Persona(BaseModel):
5+
"""
6+
Model of an **agent persona**, a struct that includes the name & avatar
7+
shown on agent replies in the chat UI.
8+
9+
Each persona is specific to a single provider, set on the `persona` field.
10+
"""
11+
12+
name: str = ...
13+
"""
14+
Name of the persona, e.g. "Jupyternaut". This is used to render the name
15+
shown on agent replies in the chat UI.
16+
"""
17+
18+
avatar_route: str = ...
19+
"""
20+
The server route that should be used the avatar of this persona. This is
21+
used to render the avatar shown on agent replies in the chat UI.
22+
"""
23+
24+
25+
JUPYTERNAUT_AVATAR_ROUTE = "api/ai/static/jupyternaut.svg"
26+
JupyternautPersona = Persona(name="Jupyternaut", avatar_route=JUPYTERNAUT_AVATAR_ROUTE)

packages/jupyter-ai-magics/jupyter_ai_magics/providers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
except:
4747
from pydantic.main import ModelMetaclass
4848

49+
from .models.persona import Persona
4950

5051
CHAT_SYSTEM_PROMPT = """
5152
You are Jupyternaut, a conversational assistant living in JupyterLab to help users.
@@ -214,6 +215,30 @@ class Config:
214215
"""User inputs expected by this provider when initializing it. Each `Field` `f`
215216
should be passed in the constructor as a keyword argument, keyed by `f.key`."""
216217

218+
manages_history: ClassVar[bool] = False
219+
"""Whether this provider manages its own conversation history upstream. If
220+
set to `True`, Jupyter AI will not pass the chat history to this provider
221+
when invoked."""
222+
223+
persona: ClassVar[Optional[Persona]] = None
224+
"""
225+
The **persona** of this provider, a struct that defines the name and avatar
226+
shown on agent replies in the chat UI. When set to `None`, `jupyter-ai` will
227+
choose a default persona when rendering agent messages by this provider.
228+
229+
Because this field is set to `None` by default, `jupyter-ai` will render a
230+
default persona for all providers that are included natively with the
231+
`jupyter-ai` package. This field is reserved for Jupyter AI modules that
232+
serve a custom provider and want to distinguish it in the chat UI.
233+
"""
234+
235+
unsupported_slash_commands: ClassVar[set] = {}
236+
"""
237+
A set of slash commands unsupported by this provider. Unsupported slash
238+
commands are not shown in the help message, and cannot be used while this
239+
provider is selected.
240+
"""
241+
217242
#
218243
# instance attrs
219244
#

packages/jupyter-ai/jupyter_ai/chat_handlers/base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from dask.distributed import Client as DaskClient
1818
from jupyter_ai.config_manager import ConfigManager, Logger
1919
from jupyter_ai.models import AgentChatMessage, ChatMessage, HumanChatMessage
20+
from jupyter_ai_magics import Persona
2021
from jupyter_ai_magics.providers import BaseProvider
2122
from langchain.pydantic_v1 import BaseModel
2223

@@ -94,10 +95,21 @@ async def on_message(self, message: HumanChatMessage):
9495
`self.handle_exc()` when an exception is raised. This method is called
9596
by RootChatHandler when it routes a human message to this chat handler.
9697
"""
98+
lm_provider_klass = self.config_manager.lm_provider
99+
100+
# ensure the current slash command is supported
101+
if self.routing_type.routing_method == "slash_command":
102+
slash_command = (
103+
"/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
104+
)
105+
if slash_command in lm_provider_klass.unsupported_slash_commands:
106+
self.reply(
107+
"Sorry, the selected language model does not support this slash command."
108+
)
109+
return
97110

98111
# check whether the configured LLM can support a request at this time.
99112
if self.uses_llm and BaseChatHandler._requests_count > 0:
100-
lm_provider_klass = self.config_manager.lm_provider
101113
lm_provider_params = self.config_manager.lm_provider_params
102114
lm_provider = lm_provider_klass(**lm_provider_params)
103115

@@ -159,11 +171,18 @@ async def _default_handle_exc(self, e: Exception, message: HumanChatMessage):
159171
self.reply(response, message)
160172

161173
def reply(self, response: str, human_msg: Optional[HumanChatMessage] = None):
174+
"""
175+
Sends an agent message, usually in response to a received
176+
`HumanChatMessage`.
177+
"""
178+
persona = self.config_manager.persona
179+
162180
agent_msg = AgentChatMessage(
163181
id=uuid4().hex,
164182
time=time.time(),
165183
body=response,
166184
reply_to=human_msg.id if human_msg else "",
185+
persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
167186
)
168187

169188
for handler in self._root_chat_handlers.values():

packages/jupyter-ai/jupyter_ai/chat_handlers/default.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
44
from jupyter_ai_magics.providers import BaseProvider
5-
from langchain.chains import ConversationChain
5+
from langchain.chains import ConversationChain, LLMChain
66
from langchain.memory import ConversationBufferWindowMemory
77

88
from .base import BaseChatHandler, SlashCommandRoutingType
@@ -30,14 +30,18 @@ def create_llm_chain(
3030
llm = provider(**unified_parameters)
3131

3232
prompt_template = llm.get_chat_prompt_template()
33+
self.llm = llm
3334
self.memory = ConversationBufferWindowMemory(
3435
return_messages=llm.is_chat_provider, k=2
3536
)
3637

37-
self.llm = llm
38-
self.llm_chain = ConversationChain(
39-
llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
40-
)
38+
if llm.manages_history:
39+
self.llm_chain = LLMChain(llm=llm, prompt=prompt_template, verbose=True)
40+
41+
else:
42+
self.llm_chain = ConversationChain(
43+
llm=llm, prompt=prompt_template, verbose=True, memory=self.memory
44+
)
4145

4246
def clear_memory(self):
4347
# clear chain memory

packages/jupyter-ai/jupyter_ai/chat_handlers/help.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from uuid import uuid4
44

55
from jupyter_ai.models import AgentChatMessage, HumanChatMessage
6+
from jupyter_ai_magics import Persona
67

78
from .base import BaseChatHandler, SlashCommandRoutingType
89

9-
HELP_MESSAGE = """Hi there! I'm Jupyternaut, your programming assistant.
10+
HELP_MESSAGE = """Hi there! I'm {persona_name}, your programming assistant.
1011
You can ask me a question using the text box below. You can also use these commands:
1112
{commands}
1213
@@ -15,23 +16,36 @@
1516
"""
1617

1718

18-
def _format_help_message(chat_handlers: Dict[str, BaseChatHandler]):
19+
def _format_help_message(
20+
chat_handlers: Dict[str, BaseChatHandler],
21+
persona: Persona,
22+
unsupported_slash_commands: set,
23+
):
24+
if unsupported_slash_commands:
25+
keys = set(chat_handlers.keys()) - unsupported_slash_commands
26+
chat_handlers = {key: chat_handlers[key] for key in keys}
27+
1928
commands = "\n".join(
2029
[
2130
f"* `{command_name}` — {handler.help}"
2231
for command_name, handler in chat_handlers.items()
2332
if command_name != "default"
2433
]
2534
)
26-
return HELP_MESSAGE.format(commands=commands)
35+
return HELP_MESSAGE.format(commands=commands, persona_name=persona.name)
2736

2837

29-
def HelpMessage(chat_handlers: Dict[str, BaseChatHandler]):
38+
def build_help_message(
39+
chat_handlers: Dict[str, BaseChatHandler],
40+
persona: Persona,
41+
unsupported_slash_commands: set,
42+
):
3043
return AgentChatMessage(
3144
id=uuid4().hex,
3245
time=time.time(),
33-
body=_format_help_message(chat_handlers),
46+
body=_format_help_message(chat_handlers, persona, unsupported_slash_commands),
3447
reply_to="",
48+
persona=Persona(name=persona.name, avatar_route=persona.avatar_route),
3549
)
3650

3751

packages/jupyter-ai/jupyter_ai/config_manager.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from deepmerge import always_merger as Merger
99
from jsonschema import Draft202012Validator as Validator
1010
from jupyter_ai.models import DescribeConfigResponse, GlobalConfig, UpdateConfigRequest
11+
from jupyter_ai_magics import JupyternautPersona, Persona
1112
from jupyter_ai_magics.utils import (
1213
AnyProvider,
1314
EmProvidersDict,
@@ -452,3 +453,14 @@ def em_provider_params(self):
452453
"model_id": em_lid,
453454
**authn_fields,
454455
}
456+
457+
@property
458+
def persona(self) -> Persona:
459+
"""
460+
The current agent persona, set by the selected LM provider. If the
461+
selected LM provider is `None`, this property returns
462+
`JupyternautPersona` by default.
463+
"""
464+
lm_provider = self.lm_provider
465+
persona = getattr(lm_provider, "persona", None) or JupyternautPersona
466+
return persona

packages/jupyter-ai/jupyter_ai/extension.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
import logging
1+
import os
22
import re
33
import time
44

55
from dask.distributed import Client as DaskClient
66
from importlib_metadata import entry_points
77
from jupyter_ai.chat_handlers.learn import Retriever
8+
from jupyter_ai_magics import JupyternautPersona
89
from jupyter_ai_magics.utils import get_em_providers, get_lm_providers
910
from jupyter_server.extension.application import ExtensionApp
11+
from tornado.web import StaticFileHandler
1012
from traitlets import Dict, List, Unicode
1113

1214
from .chat_handlers import (
@@ -18,7 +20,7 @@
1820
HelpChatHandler,
1921
LearnChatHandler,
2022
)
21-
from .chat_handlers.help import HelpMessage
23+
from .chat_handlers.help import build_help_message
2224
from .completions.handlers import DefaultInlineCompletionHandler
2325
from .config_manager import ConfigManager
2426
from .handlers import (
@@ -30,6 +32,11 @@
3032
RootChatHandler,
3133
)
3234

35+
JUPYTERNAUT_AVATAR_ROUTE = JupyternautPersona.avatar_route
36+
JUPYTERNAUT_AVATAR_PATH = str(
37+
os.path.join(os.path.dirname(__file__), "static", "jupyternaut.svg")
38+
)
39+
3340

3441
class AiExtension(ExtensionApp):
3542
name = "jupyter_ai"
@@ -41,6 +48,14 @@ class AiExtension(ExtensionApp):
4148
(r"api/ai/providers?", ModelProviderHandler),
4249
(r"api/ai/providers/embeddings?", EmbeddingsModelProviderHandler),
4350
(r"api/ai/completion/inline/?", DefaultInlineCompletionHandler),
51+
# serve the default persona avatar at this path.
52+
# the `()` at the end of the URL denotes an empty regex capture group,
53+
# required by Tornado.
54+
(
55+
rf"{JUPYTERNAUT_AVATAR_ROUTE}()",
56+
StaticFileHandler,
57+
{"path": JUPYTERNAUT_AVATAR_PATH},
58+
),
4459
]
4560

4661
allowed_providers = List(
@@ -303,14 +318,36 @@ def initialize_settings(self):
303318
# Make help always appear as the last command
304319
jai_chat_handlers["/help"] = help_chat_handler
305320

306-
self.settings["chat_history"].append(
307-
HelpMessage(chat_handlers=jai_chat_handlers)
308-
)
321+
# bind chat handlers to settings
309322
self.settings["jai_chat_handlers"] = jai_chat_handlers
310323

324+
# show help message at server start
325+
self._show_help_message()
326+
311327
latency_ms = round((time.time() - start) * 1000)
312328
self.log.info(f"Initialized Jupyter AI server extension in {latency_ms} ms.")
313329

330+
def _show_help_message(self):
331+
"""
332+
Method that ensures a dynamically-generated help message is included in
333+
the chat history shown to users.
334+
"""
335+
chat_handlers = self.settings["jai_chat_handlers"]
336+
config_manager: ConfigManager = self.settings["jai_config_manager"]
337+
lm_provider = config_manager.lm_provider
338+
339+
if not lm_provider:
340+
return
341+
342+
persona = config_manager.persona
343+
unsupported_slash_commands = (
344+
lm_provider.unsupported_slash_commands if lm_provider else set()
345+
)
346+
help_message = build_help_message(
347+
chat_handlers, persona, unsupported_slash_commands
348+
)
349+
self.settings["chat_history"].append(help_message)
350+
314351
async def _get_dask_client(self):
315352
return DaskClient(processes=False, asynchronous=True)
316353

packages/jupyter-ai/jupyter_ai/models.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List, Literal, Optional, Union
22

3+
from jupyter_ai_magics import Persona
34
from jupyter_ai_magics.providers import AuthStrategy, Field
45
from langchain.pydantic_v1 import BaseModel, validator
56

@@ -34,8 +35,18 @@ class AgentChatMessage(BaseModel):
3435
id: str
3536
time: float
3637
body: str
37-
# message ID of the HumanChatMessage it is replying to
38+
3839
reply_to: str
40+
"""
41+
Message ID of the HumanChatMessage being replied to. This is set to an empty
42+
string if not applicable.
43+
"""
44+
45+
persona: Persona
46+
"""
47+
The persona of the selected provider. If the selected provider is `None`,
48+
this defaults to a description of `JupyternautPersona`.
49+
"""
3950

4051

4152
class HumanChatMessage(BaseModel):
Lines changed: 9 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)