Skip to content

Commit b02e1d2

Browse files
committed
Upgraded LangChain, fixed prompts for Bedrock
1 parent 8f3bfe0 commit b02e1d2

File tree

9 files changed

+160
-27
lines changed

9 files changed

+160
-27
lines changed

docs/source/users/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ Jupyter AI supports the following model providers:
116116
|---------------------|----------------------|----------------------------|---------------------------------|
117117
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
118118
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` |
119+
| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `anthropic` |
119120
| Bedrock | `amazon-bedrock` | N/A | `boto3` |
120121
| Cohere | `cohere` | `COHERE_API_KEY` | `cohere` |
121122
| Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
@@ -464,6 +465,7 @@ We currently support the following language model providers:
464465

465466
- `ai21`
466467
- `anthropic`
468+
- `anthropic-chat`
467469
- `cohere`
468470
- `huggingface_hub`
469471
- `openai`

packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
AnthropicProvider,
1616
AzureChatOpenAIProvider,
1717
BaseProvider,
18+
BedrockChatProvider,
1819
BedrockProvider,
20+
ChatAnthropicProvider,
1921
ChatOpenAINewProvider,
2022
ChatOpenAIProvider,
2123
CohereProvider,

packages/jupyter-ai-magics/jupyter_ai_magics/magics.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from IPython.display import HTML, JSON, Markdown, Math
1414
from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers
1515
from langchain.chains import LLMChain
16+
from langchain.schema import HumanMessage
1617

1718
from .parsers import (
1819
CellArgs,
@@ -125,6 +126,12 @@ def __init__(self, shell):
125126
"no longer supported. Instead, please use: "
126127
"`from langchain.chat_models import ChatOpenAI`",
127128
)
129+
# suppress warning when using old Anthropic provider
130+
warnings.filterwarnings(
131+
"ignore",
132+
message="This Anthropic LLM is deprecated. Please use "
133+
"`from langchain.chat_models import ChatAnthropic` instead",
134+
)
128135

129136
self.providers = get_lm_providers()
130137

@@ -410,6 +417,9 @@ def _get_provider(self, provider_id: Optional[str]) -> BaseProvider:
410417

411418
return self.providers[provider_id]
412419

420+
def _is_chat_model(self, provider_id: str) -> bool:
421+
return provider_id in ["anthropic-chat", "bedrock-chat"]
422+
413423
def display_output(self, output, display_format, md):
414424
# build output display
415425
DisplayClass = DISPLAYS_BY_FORMAT[display_format]
@@ -529,8 +539,12 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
529539
ip = get_ipython()
530540
prompt = prompt.format_map(FormatDict(ip.user_ns))
531541

532-
# generate output from model via provider
533-
result = provider.generate([prompt])
542+
if self._is_chat_model(provider.id):
543+
result = provider.generate([[HumanMessage(content=prompt)]])
544+
else:
545+
# generate output from model via provider
546+
result = provider.generate([prompt])
547+
534548
output = result.generations[0][0].text
535549

536550
# if openai-chat, append exchange to transcript

packages/jupyter-ai-magics/jupyter_ai_magics/providers.py

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union
99

1010
from jsonpath_ng import parse
11-
from langchain import PromptTemplate
12-
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
11+
from langchain.chat_models import (
12+
AzureChatOpenAI,
13+
BedrockChat,
14+
ChatAnthropic,
15+
ChatOpenAI,
16+
)
1317
from langchain.llms import (
1418
AI21,
1519
Anthropic,
@@ -23,6 +27,8 @@
2327
)
2428
from langchain.llms.sagemaker_endpoint import LLMContentHandler
2529
from langchain.llms.utils import enforce_stop_tokens
30+
from langchain.prompts import PromptTemplate
31+
from langchain.schema import LLMResult
2632
from langchain.utils import get_from_dict_or_env
2733
from pydantic import BaseModel, Extra, root_validator
2834

@@ -187,6 +193,18 @@ async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
187193
_call_with_args = functools.partial(self._call, *args, **kwargs)
188194
return await loop.run_in_executor(executor, _call_with_args)
189195

196+
async def _generate_in_executor(
197+
self, *args, **kwargs
198+
) -> Coroutine[Any, Any, LLMResult]:
199+
"""
200+
Calls self._call() asynchronously in a separate thread for providers
201+
without an async implementation. Requires the event loop to be running.
202+
"""
203+
executor = ThreadPoolExecutor(max_workers=1)
204+
loop = asyncio.get_running_loop()
205+
_call_with_args = functools.partial(self._generate, *args, **kwargs)
206+
return await loop.run_in_executor(executor, _call_with_args)
207+
190208
def update_prompt_template(self, format: str, template: str):
191209
"""
192210
Changes the class-level prompt template for a given format.
@@ -235,8 +253,28 @@ class AnthropicProvider(BaseProvider, Anthropic):
235253
"claude-v1.0",
236254
"claude-v1.2",
237255
"claude-2",
256+
"claude-2.0",
257+
"claude-instant-v1",
258+
"claude-instant-v1.0",
259+
"claude-instant-v1.2",
260+
]
261+
model_id_key = "model"
262+
pypi_package_deps = ["anthropic"]
263+
auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY")
264+
265+
266+
class ChatAnthropicProvider(BaseProvider, ChatAnthropic):
267+
id = "anthropic-chat"
268+
name = "ChatAnthropic"
269+
models = [
270+
"claude-v1",
271+
"claude-v1.0",
272+
"claude-v1.2",
273+
"claude-2",
274+
"claude-2.0",
238275
"claude-instant-v1",
239276
"claude-instant-v1.0",
277+
"claude-instant-v1.2",
240278
]
241279
model_id_key = "model"
242280
pypi_package_deps = ["anthropic"]
@@ -576,16 +614,56 @@ class BedrockProvider(BaseProvider, Bedrock):
576614
id = "bedrock"
577615
name = "Amazon Bedrock"
578616
models = [
579-
"amazon.titan-tg1-large",
617+
"amazon.titan-text-express-v1",
580618
"anthropic.claude-v1",
619+
"anthropic.claude-v2",
581620
"anthropic.claude-instant-v1",
621+
"ai21.j2-ultra-v1",
622+
"ai21.j2-mid-v1",
623+
"cohere.command-text-v14",
624+
]
625+
model_id_key = "model_id"
626+
pypi_package_deps = ["boto3"]
627+
auth_strategy = AwsAuthStrategy()
628+
fields = [
629+
TextField(
630+
key="credentials_profile_name",
631+
label="AWS profile (optional)",
632+
format="text",
633+
),
634+
TextField(key="region_name", label="Region name (optional)", format="text"),
635+
]
636+
637+
async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
638+
return await self._call_in_executor(*args, **kwargs)
639+
640+
641+
class BedrockChatProvider(BaseProvider, BedrockChat):
642+
id = "bedrock-chat"
643+
name = "Amazon Bedrock Chat"
644+
models = [
645+
"amazon.titan-text-express-v1",
646+
"anthropic.claude-v1",
582647
"anthropic.claude-v2",
583-
"ai21.j2-jumbo-instruct",
584-
"ai21.j2-grande-instruct",
648+
"anthropic.claude-instant-v1",
649+
"ai21.j2-ultra-v1",
650+
"ai21.j2-mid-v1",
651+
"cohere.command-text-v14",
585652
]
586653
model_id_key = "model_id"
587654
pypi_package_deps = ["boto3"]
588655
auth_strategy = AwsAuthStrategy()
656+
fields = [
657+
TextField(
658+
key="credentials_profile_name",
659+
label="AWS profile (optional)",
660+
format="text",
661+
),
662+
TextField(key="region_name", label="Region name (optional)", format="text"),
663+
]
589664

590665
async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
591666
return await self._call_in_executor(*args, **kwargs)
667+
668+
async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
669+
return await self._generate_in_executor(*args, **kwargs)

packages/jupyter-ai-magics/pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dependencies = [
2424
"ipython",
2525
"pydantic~=1.0",
2626
"importlib_metadata>=5.2.0",
27-
"langchain==0.0.277",
27+
"langchain==0.0.306",
2828
"typing_extensions>=4.5.0",
2929
"click~=8.0",
3030
"jsonpath-ng>=1.5.3,<2",
@@ -44,7 +44,7 @@ test = [
4444

4545
all = [
4646
"ai21",
47-
"anthropic~=0.2.10",
47+
"anthropic~=0.3.0",
4848
"cohere",
4949
"gpt4all",
5050
"huggingface_hub",
@@ -66,6 +66,8 @@ openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider"
6666
azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider"
6767
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
6868
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
69+
anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"
70+
amazon-bedrock-chat = "jupyter_ai_magics:BedrockChatProvider"
6971

7072
[project.entry-points."jupyter_ai.embeddings_model_providers"]
7173
cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"

packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class AskChatHandler(BaseChatHandler):
1212
"""Processes messages prefixed with /ask. This actor will
1313
send the message as input to a RetrieverQA chain, that
14-
follows the Retrieval and Generation (RAG) tehnique to
14+
follows the Retrieval-Augmented Generation (RAG) technique to
1515
query the documents from the index, and sends this context
1616
to the LLM to generate the final reply.
1717
"""
@@ -29,7 +29,7 @@ def create_llm_chain(
2929
self.llm = provider(**provider_params)
3030
self.chat_history = []
3131
self.llm_chain = ConversationalRetrievalChain.from_llm(
32-
self.llm, self._retriever
32+
self.llm, self._retriever, verbose=True
3333
)
3434

3535
async def _process_message(self, message: HumanChatMessage):

packages/jupyter-ai/jupyter_ai/chat_handlers/default.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1-
from typing import Dict, List, Type
1+
from typing import Any, Dict, List, Type
22

33
from jupyter_ai.models import ChatMessage, ClearMessage, HumanChatMessage
4-
from jupyter_ai_magics.providers import BaseProvider
5-
from langchain import ConversationChain
4+
from jupyter_ai_magics.providers import (
5+
BaseProvider,
6+
BedrockChatProvider,
7+
BedrockProvider,
8+
)
9+
from langchain.chains import ConversationChain
610
from langchain.memory import ConversationBufferWindowMemory
711
from langchain.prompts import (
812
ChatPromptTemplate,
913
HumanMessagePromptTemplate,
1014
MessagesPlaceholder,
1115
SystemMessagePromptTemplate,
1216
)
13-
from langchain.schema import AIMessage
17+
from langchain.schema import AIMessage, ChatMessage
18+
from langchain.schema.messages import BaseMessage
1419

1520
from .base import BaseChatHandler
1621

@@ -26,6 +31,20 @@
2631
""".strip()
2732

2833

34+
class HistoryPlaceholderTemplate(MessagesPlaceholder):
35+
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
36+
values = super().format_messages(**kwargs)
37+
corrected_values = []
38+
for v in values:
39+
if isinstance(v, AIMessage):
40+
corrected_values.append(
41+
ChatMessage(role="Assistant", content=v.content)
42+
)
43+
else:
44+
corrected_values.append(v)
45+
return corrected_values
46+
47+
2948
class DefaultChatHandler(BaseChatHandler):
3049
def __init__(self, chat_history: List[ChatMessage], *args, **kwargs):
3150
super().__init__(*args, **kwargs)
@@ -36,16 +55,32 @@ def create_llm_chain(
3655
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
3756
):
3857
llm = provider(**provider_params)
39-
prompt_template = ChatPromptTemplate.from_messages(
40-
[
41-
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
42-
provider_name=llm.name, local_model_id=llm.model_id
43-
),
44-
MessagesPlaceholder(variable_name="history"),
45-
HumanMessagePromptTemplate.from_template("{input}"),
46-
AIMessage(content=""),
47-
]
48-
)
58+
if provider == BedrockChatProvider or provider == BedrockProvider:
59+
prompt_template = ChatPromptTemplate.from_messages(
60+
[
61+
ChatMessage(
62+
role="Instructions",
63+
content=SYSTEM_PROMPT.format(
64+
provider_name=llm.name, local_model_id=llm.model_id
65+
),
66+
),
67+
HistoryPlaceholderTemplate(variable_name="history"),
68+
HumanMessagePromptTemplate.from_template("{input}"),
69+
ChatMessage(role="Assistant", content=""),
70+
]
71+
)
72+
else:
73+
prompt_template = ChatPromptTemplate.from_messages(
74+
[
75+
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format(
76+
provider_name=llm.name, local_model_id=llm.model_id
77+
),
78+
MessagesPlaceholder(variable_name="history"),
79+
HumanMessagePromptTemplate.from_template("{input}"),
80+
AIMessage(content=""),
81+
]
82+
)
83+
4984
self.llm = llm
5085
self.llm_chain = ConversationChain(
5186
llm=llm, prompt=prompt_template, verbose=True, memory=self.memory

packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
IndexMetadata,
1616
)
1717
from jupyter_core.paths import jupyter_data_dir
18-
from langchain import FAISS
1918
from langchain.schema import BaseRetriever, Document
2019
from langchain.text_splitter import (
2120
LatexTextSplitter,
2221
MarkdownTextSplitter,
2322
PythonCodeTextSplitter,
2423
RecursiveCharacterTextSplitter,
2524
)
25+
from langchain.vectorstores import FAISS
2626

2727
from .base import BaseChatHandler
2828

packages/jupyter-ai/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dependencies = [
2828
"openai~=0.26",
2929
"aiosqlite>=0.18",
3030
"importlib_metadata>=5.2.0",
31-
"langchain==0.0.277",
31+
"langchain==0.0.306",
3232
"tiktoken", # required for OpenAIEmbeddings
3333
"jupyter_ai_magics",
3434
"dask[distributed]",

0 commit comments

Comments (0)