diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml index e02c5517fe1f3c..4eb56bbc0e916e 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml index e20b8c4960734c..81822b162e6a16 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 3a5a42ba05b44b..a3b216ec1289ae 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,7 +1,7 @@ import base64 import io import json -from collections.abc import Generator +from collections.abc import Generator, Sequence from typing import Optional, Union, cast import anthropic @@ -21,9 +21,9 @@ from PIL import Image from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + DocumentPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -33,6 +33,7 @@ ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -86,10 +87,10 @@ def _chat_generate( self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> Union[LLMResult, Generator]: @@ -130,9 +131,17 @@ def _chat_generate( # Add the new header for claude-3-5-sonnet-20240620 model extra_headers = {} if model == "claude-3-5-sonnet-20240620": - if model_parameters.get("max_tokens") > 4096: + if model_parameters.get("max_tokens", 0) > 4096: extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" + if any( + isinstance(content, DocumentPromptMessageContent) + for prompt_message in prompt_messages + if isinstance(prompt_message.content, list) + for content in prompt_message.content + ): + extra_headers["anthropic-beta"] = "pdfs-2024-09-25" + if tools: extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] response = client.beta.tools.messages.create( @@ -505,6 +514,21 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) + elif isinstance(message_content, DocumentPromptMessageContent): + if message_content.mime_type != "application/pdf": + raise ValueError( + f"Unsupported document type {message_content.mime_type}, " + "only support application/pdf" + ) + sub_message_dict = { + "type": "document", + "source": { + "type": message_content.encode_format, + "media_type": message_content.mime_type, + "data": message_content.data, + }, + } + sub_messages.append(sub_message_dict) prompt_message_dicts.append({"role": "user", "content": sub_messages}) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message)