Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 90 additions & 127 deletions functions/pipes/anthropic/main.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
"""
title: Anthropic Manifold Pipe
authors: justinh-rahb, christian-taillon, jfbloom22
authors: justinh-rahb, christian-taillon, jfbloom22, aaronchan0
author_url: https://github.com/justinh-rahb
funding_url: https://github.com/open-webui
version: 0.3.0
version: 0.4.0
required_open_webui_version: 0.3.17
license: MIT
"""

import os
import requests
import json
import time
from typing import List, Union, Generator, Iterator, Optional, Dict
import json
from typing import List, Union, AsyncGenerator, Iterator, Optional, Dict
from pydantic import BaseModel, Field
from open_webui.utils.misc import pop_system_message
from anthropic import AsyncAnthropic
from anthropic.types import TextBlock


class Pipe:
class Valves(BaseModel):
ANTHROPIC_API_KEY: str = Field(default="")
ANTHROPIC_ENABLE_WEB_SEARCH: bool = Field(default=False)

def __init__(self):
self.type = "manifold"
self.id = "anthropic"
self.name = "anthropic/"
self.name = "" # anthropic/"
self.valves = self.Valves(
**{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "")}
**{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", ""), "ANTHROPIC_ENABLE_WEB_SEARCH": os.getenv("ANTHROPIC_ENABLE_WEB_SEARCH", False)}
)
self.MAX_IMAGE_SIZE = 5 * 1024 * 1024 # 5MB per image

self.client: AsyncAnthropic = AsyncAnthropic(api_key=self.valves.ANTHROPIC_API_KEY)

# Model cache
self._model_cache: Optional[List[Dict[str, str]]] = None
self._model_cache_time: float = 0
self._cache_ttl = int(os.getenv("ANTHROPIC_MODEL_CACHE_TTL", "600"))

def get_anthropic_models_from_api(self, force_refresh: bool = False) -> List[Dict[str, str]]:
def get_client(self) -> AsyncAnthropic:
if self.client.api_key != self.valves.ANTHROPIC_API_KEY:
self.client: AsyncAnthropic = AsyncAnthropic(api_key=self.valves.ANTHROPIC_API_KEY)
return self.client

async def get_anthropic_models_from_api(self, force_refresh: bool = False) -> List[Dict[str, str]]:
"""
Retrieve available Anthropic models from the API.
Uses caching to reduce API calls.
Expand All @@ -48,11 +57,7 @@ def get_anthropic_models_from_api(self, force_refresh: bool = False) -> List[Dic
"""
# Check cache first
current_time = time.time()
if (
not force_refresh
and self._model_cache is not None
and (current_time - self._model_cache_time) < self._cache_ttl
):
if not force_refresh and self._model_cache is not None and (current_time - self._model_cache_time) < self._cache_ttl:
return self._model_cache

if not self.valves.ANTHROPIC_API_KEY:
Expand All @@ -64,50 +69,32 @@ def get_anthropic_models_from_api(self, force_refresh: bool = False) -> List[Dic
]

try:
headers = {
"x-api-key": self.valves.ANTHROPIC_API_KEY,
"anthropic-version": "2023-06-01",
"content-type": "application/json",
}

response = requests.get(
"https://api.anthropic.com/v1/models",
headers=headers,
timeout=10
)

if response.status_code != 200:
raise Exception(f"HTTP Error {response.status_code}: {response.text}")

data = response.json()
models = []

for model in data.get("data", []):
models.append({
"id": model["id"],
"name": model.get("display_name", model["id"]),
})

anthropic_models = await self.get_client().models.list()
models = [{"id": model.id, "name": model.display_name} for model in anthropic_models.data]

# Update cache
self._model_cache = models
self._model_cache_time = current_time

return models

except Exception as e:
print(f"Error fetching Anthropic models: {e}")
return [
{"id": "error", "name": f"Could not fetch models from Anthropic: {str(e)}"}
{
"id": "error",
"name": f"Could not fetch models from Anthropic: {str(e)}",
}
]

def get_anthropic_models(self) -> List[Dict[str, str]]:
async def get_anthropic_models(self) -> List[Dict[str, str]]:
"""
Get Anthropic models from the API.
"""
return self.get_anthropic_models_from_api()
return await self.get_anthropic_models_from_api()

def pipes(self) -> List[dict]:
return self.get_anthropic_models()
async def pipes(self) -> List[dict]:
return await self.get_anthropic_models()

def process_image(self, image_data):
"""Process image data with size validation."""
Expand All @@ -118,9 +105,7 @@ def process_image(self, image_data):
# Check base64 image size
image_size = len(base64_data) * 3 / 4 # Convert base64 size to bytes
if image_size > self.MAX_IMAGE_SIZE:
raise ValueError(
f"Image size exceeds 5MB limit: {image_size / (1024 * 1024):.2f}MB"
)
raise ValueError(f"Image size exceeds 5MB limit: {image_size / (1024 * 1024):.2f}MB")

return {
"type": "image",
Expand All @@ -137,16 +122,14 @@ def process_image(self, image_data):
content_length = int(response.headers.get("content-length", 0))

if content_length > self.MAX_IMAGE_SIZE:
raise ValueError(
f"Image at URL exceeds 5MB limit: {content_length / (1024 * 1024):.2f}MB"
)
raise ValueError(f"Image at URL exceeds 5MB limit: {content_length / (1024 * 1024):.2f}MB")

return {
"type": "image",
"source": {"type": "url", "url": url},
}

def pipe(self, body: dict) -> Union[str, Generator, Iterator]:
async def pipe(self, body: dict) -> Union[str, AsyncGenerator, Iterator]:
system_message, messages = pop_system_message(body["messages"])

processed_messages = []
Expand All @@ -166,108 +149,88 @@ def pipe(self, body: dict) -> Union[str, Generator, Iterator]:
if processed_image["source"]["type"] == "base64":
image_size = len(processed_image["source"]["data"]) * 3 / 4
total_image_size += image_size
if (
total_image_size > 100 * 1024 * 1024
): # 100MB total limit
raise ValueError(
"Total size of images exceeds 100 MB limit"
)
if total_image_size > 100 * 1024 * 1024: # 100MB total limit
raise ValueError("Total size of images exceeds 100 MB limit")
else:
processed_content = [
{"type": "text", "text": message.get("content", "")}
]
processed_content = [{"type": "text", "text": message.get("content", "")}]

processed_messages.append(
{"role": message["role"], "content": processed_content}
)
processed_messages.append({"role": message["role"], "content": processed_content})

payload = {
"model": body["model"][body["model"].find(".") + 1 :],
"messages": processed_messages,
"max_tokens": body.get("max_tokens", 4096),
"temperature": body.get("temperature", 0.8),
"top_k": body.get("top_k", 40),
"top_p": body.get("top_p", 0.9),
# "top_p": body.get("top_p", 0.9),
"stop_sequences": body.get("stop", []),
**({"system": str(system_message)} if system_message else {}),
"stream": body.get("stream", False),
}

headers = {
"x-api-key": self.valves.ANTHROPIC_API_KEY,
"anthropic-version": "2023-06-01",
"content-type": "application/json",
}

url = "https://api.anthropic.com/v1/messages"

payload["tools"] = (
[
{
"type": "web_search_20250305",
"name": "web_search",
"max_uses": 5,
"user_location": {
"type": "approximate",
"city": "San Ramon",
"region": "California",
"country": "US",
"timezone": "America/Los_Angeles",
},
}
]
if self.valves.ANTHROPIC_ENABLE_WEB_SEARCH
else []
)
try:
if body.get("stream", False):
return self.stream_response(url, headers, payload)
return self.stream_response(payload)
else:
return self.non_stream_response(url, headers, payload)
except requests.exceptions.RequestException as e:
print(f"Request failed: {e}")
return f"Error: Request failed: {e}"
return await self.non_stream_response(payload)
except Exception as e:
print(f"Error in pipe method: {e}")
return f"Error: {e}"

def stream_response(self, url, headers, payload):
async def stream_response(self, payload):
try:
with requests.post(
url, headers=headers, json=payload, stream=True, timeout=(3.05, 60)
) as response:
if response.status_code != 200:
raise Exception(
f"HTTP Error {response.status_code}: {response.text}"
)

for line in response.iter_lines():
if line:
line = line.decode("utf-8")
if line.startswith("data: "):
try:
data = json.loads(line[6:])
if data["type"] == "content_block_start":
yield data["content_block"]["text"]
elif data["type"] == "content_block_delta":
yield data["delta"]["text"]
elif data["type"] == "message_stop":
break
elif data["type"] == "message":
for content in data.get("content", []):
if content["type"] == "text":
yield content["text"]

time.sleep(
0.01
) # Delay to avoid overwhelming the client

except json.JSONDecodeError:
print(f"Failed to parse JSON: {line}")
except KeyError as e:
print(f"Unexpected data structure: {e}")
print(f"Full data: {data}")
except requests.exceptions.RequestException as e:
print(f"Request failed: {e}")
yield f"Error: Request failed: {e}"
async with self.get_client().messages.stream(
model=payload["model"], max_tokens=payload["max_tokens"], system=payload["system"], messages=payload["messages"], tools=payload["tools"]
) as stream:
input_json: str = ""
is_thinking: bool = False
async for event in stream:
if event.type == "content_block_start":
if event.content_block.type == "server_tool_use" or event.content_block.type == "tool_use":
if not is_thinking:
is_thinking = True
yield "<think>"
input_json = ""
elif event.content_block.type == "text":
if is_thinking:
yield "</think>"
is_thinking = False
elif event.type == "content_block_stop":
if event.content_block.type == "server_tool_use" or event.content_block.type == "tool_use":
input_params = ", ".join([f"{key}: {value}" for key, value in json.loads(input_json).items()])
yield f"calling {event.content_block.name} with {input_params}\n"
elif event.type == "content_block_delta":
if event.delta.type == "text_delta":
yield event.delta.text
elif event.delta.type == "input_json_delta":
input_json += event.delta.partial_json
except Exception as e:
print(f"General error in stream_response method: {e}")
yield f"Error: {e}"

def non_stream_response(self, url, headers, payload):
async def non_stream_response(self, payload):
try:
response = requests.post(
url, headers=headers, json=payload, timeout=(3.05, 60)
resp = await self.get_client().messages.create(
model=payload["model"], max_tokens=payload["max_tokens"], system=payload["system"], messages=payload["messages"], tools=payload["tools"]
)
if response.status_code != 200:
raise Exception(f"HTTP Error {response.status_code}: {response.text}")

res = response.json()
return (
res["content"][0]["text"] if "content" in res and res["content"] else ""
)
except requests.exceptions.RequestException as e:
return "\n".join([r.text if isinstance(r, TextBlock) else "" for r in resp.content])
except Exception as e:
print(f"Failed non-stream request: {e}")
return f"Error: {e}"