|
8 | 8 | from typing import Any, ClassVar, Coroutine, Dict, List, Literal, Optional, Union |
9 | 9 |
|
10 | 10 | from jsonpath_ng import parse |
11 | | -from langchain import PromptTemplate |
12 | | -from langchain.chat_models import AzureChatOpenAI, ChatOpenAI |
| 11 | +from langchain.chat_models import ( |
| 12 | + AzureChatOpenAI, |
| 13 | + BedrockChat, |
| 14 | + ChatAnthropic, |
| 15 | + ChatOpenAI, |
| 16 | +) |
13 | 17 | from langchain.llms import ( |
14 | 18 | AI21, |
15 | 19 | Anthropic, |
|
23 | 27 | ) |
24 | 28 | from langchain.llms.sagemaker_endpoint import LLMContentHandler |
25 | 29 | from langchain.llms.utils import enforce_stop_tokens |
| 30 | +from langchain.prompts import PromptTemplate |
| 31 | +from langchain.schema import LLMResult |
26 | 32 | from langchain.utils import get_from_dict_or_env |
27 | 33 | from pydantic import BaseModel, Extra, root_validator |
28 | 34 |
|
@@ -187,6 +193,18 @@ async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]: |
187 | 193 | _call_with_args = functools.partial(self._call, *args, **kwargs) |
188 | 194 | return await loop.run_in_executor(executor, _call_with_args) |
189 | 195 |
|
| 196 | + async def _generate_in_executor( |
| 197 | + self, *args, **kwargs |
| 198 | + ) -> Coroutine[Any, Any, LLMResult]: |
| 199 | + """ |
| 200 | + Calls self._generate() asynchronously in a separate thread for providers |
| 201 | + without an async implementation. Requires the event loop to be running. |
| 202 | + """ |
| 203 | + executor = ThreadPoolExecutor(max_workers=1) |
| 204 | + loop = asyncio.get_running_loop() |
| 205 | + _call_with_args = functools.partial(self._generate, *args, **kwargs) |
| 206 | + return await loop.run_in_executor(executor, _call_with_args) |
| 207 | + |
190 | 208 | def update_prompt_template(self, format: str, template: str): |
191 | 209 | """ |
192 | 210 | Changes the class-level prompt template for a given format. |
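
Note: the new `_generate_in_executor` helper mirrors the existing `_call_in_executor`, offloading the blocking `_generate()` call to a worker thread so the event loop stays responsive. Below is a minimal standalone sketch of that pattern; `blocking_generate` is a hypothetical stand-in, not part of this change.

```python
import asyncio
import functools
import time
from concurrent.futures import ThreadPoolExecutor


def blocking_generate(prompt: str, temperature: float = 0.0) -> str:
    """Hypothetical stand-in for a provider's synchronous _generate()."""
    time.sleep(1)  # simulate a blocking network call
    return f"completion for {prompt!r} (temperature={temperature})"


async def generate_async(prompt: str) -> str:
    # Bind keyword arguments up front; run_in_executor only forwards positionals.
    call = functools.partial(blocking_generate, prompt, temperature=0.7)
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=1) as executor:
        # The blocking call runs on the worker thread while the event loop stays free.
        return await loop.run_in_executor(executor, call)


if __name__ == "__main__":
    print(asyncio.run(generate_async("hello bedrock")))
```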
@@ -235,8 +253,28 @@ class AnthropicProvider(BaseProvider, Anthropic): |
235 | 253 | "claude-v1.0", |
236 | 254 | "claude-v1.2", |
237 | 255 | "claude-2", |
| 256 | + "claude-2.0", |
| 257 | + "claude-instant-v1", |
| 258 | + "claude-instant-v1.0", |
| 259 | + "claude-instant-v1.2", |
| 260 | + ] |
| 261 | + model_id_key = "model" |
| 262 | + pypi_package_deps = ["anthropic"] |
| 263 | + auth_strategy = EnvAuthStrategy(name="ANTHROPIC_API_KEY") |
| 264 | + |
| 265 | + |
| 266 | +class ChatAnthropicProvider(BaseProvider, ChatAnthropic): |
| 267 | + id = "anthropic-chat" |
| 268 | + name = "ChatAnthropic" |
| 269 | + models = [ |
| 270 | + "claude-v1", |
| 271 | + "claude-v1.0", |
| 272 | + "claude-v1.2", |
| 273 | + "claude-2", |
| 274 | + "claude-2.0", |
238 | 275 | "claude-instant-v1", |
239 | 276 | "claude-instant-v1.0", |
| 277 | + "claude-instant-v1.2", |
240 | 278 | ] |
241 | 279 | model_id_key = "model" |
242 | 280 | pypi_package_deps = ["anthropic"] |
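
Note: the new `ChatAnthropicProvider` exposes the same model ids through LangChain's chat interface rather than the completion interface. A rough sketch of how the two underlying LangChain classes differ (assuming LangChain 0.0.x and `ANTHROPIC_API_KEY` set in the environment; not part of this change):

```python
from langchain.chat_models import ChatAnthropic
from langchain.llms import Anthropic
from langchain.schema import HumanMessage

completion_llm = Anthropic(model="claude-2")  # surfaced here under id "anthropic"
chat_llm = ChatAnthropic(model="claude-2")    # surfaced here under id "anthropic-chat"

# The completion model takes a raw prompt string; the chat model takes messages.
print(completion_llm("Summarize asyncio in one sentence."))
print(chat_llm([HumanMessage(content="Summarize asyncio in one sentence.")]).content)
```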
@@ -576,16 +614,56 @@ class BedrockProvider(BaseProvider, Bedrock): |
576 | 614 | id = "bedrock" |
577 | 615 | name = "Amazon Bedrock" |
578 | 616 | models = [ |
579 | | - "amazon.titan-tg1-large", |
| 617 | + "amazon.titan-text-express-v1", |
580 | 618 | "anthropic.claude-v1", |
| 619 | + "anthropic.claude-v2", |
581 | 620 | "anthropic.claude-instant-v1", |
| 621 | + "ai21.j2-ultra-v1", |
| 622 | + "ai21.j2-mid-v1", |
| 623 | + "cohere.command-text-v14", |
| 624 | + ] |
| 625 | + model_id_key = "model_id" |
| 626 | + pypi_package_deps = ["boto3"] |
| 627 | + auth_strategy = AwsAuthStrategy() |
| 628 | + fields = [ |
| 629 | + TextField( |
| 630 | + key="credentials_profile_name", |
| 631 | + label="AWS profile (optional)", |
| 632 | + format="text", |
| 633 | + ), |
| 634 | + TextField(key="region_name", label="Region name (optional)", format="text"), |
| 635 | + ] |
| 636 | + |
| 637 | + async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: |
| 638 | + return await self._call_in_executor(*args, **kwargs) |
| 639 | + |
| 640 | + |
| 641 | +class BedrockChatProvider(BaseProvider, BedrockChat): |
| 642 | + id = "bedrock-chat" |
| 643 | + name = "Amazon Bedrock Chat" |
| 644 | + models = [ |
| 645 | + "amazon.titan-text-express-v1", |
| 646 | + "anthropic.claude-v1", |
582 | 647 | "anthropic.claude-v2", |
583 | | - "ai21.j2-jumbo-instruct", |
584 | | - "ai21.j2-grande-instruct", |
| 648 | + "anthropic.claude-instant-v1", |
| 649 | + "ai21.j2-ultra-v1", |
| 650 | + "ai21.j2-mid-v1", |
| 651 | + "cohere.command-text-v14", |
585 | 652 | ] |
586 | 653 | model_id_key = "model_id" |
587 | 654 | pypi_package_deps = ["boto3"] |
588 | 655 | auth_strategy = AwsAuthStrategy() |
| 656 | + fields = [ |
| 657 | + TextField( |
| 658 | + key="credentials_profile_name", |
| 659 | + label="AWS profile (optional)", |
| 660 | + format="text", |
| 661 | + ), |
| 662 | + TextField(key="region_name", label="Region name (optional)", format="text"), |
| 663 | + ] |
589 | 664 |
|
590 | 665 | async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: |
591 | 666 | return await self._call_in_executor(*args, **kwargs) |
| 667 | + |
| 668 | + async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]: |
| 669 | + return await self._generate_in_executor(*args, **kwargs) |