-"""Vision tool for analyzing images using OpenAI's Vision API."""
-
+from typing import IO, Optional
+import os
+from pathlib import Path
 import base64
-from typing import Optional
+import tempfile
 import requests
-from openai import OpenAI
+import anthropic
 
 __all__ = ["analyze_image"]
 
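+# Defaults below are overridable via environment variables.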
+PROMPT = os.getenv('VISION_PROMPT', "What's in this image?")
+MODEL = os.getenv('VISION_MODEL', "claude-3-5-sonnet-20241022")
+MAX_TOKENS: int = int(os.getenv('VISION_MAX_TOKENS', '1024'))
 
-def analyze_image(image_path_url: str) -> str:
-    """
-    Analyze an image using OpenAI's Vision API.
+MEDIA_TYPES = {
+    "jpg": "image/jpeg",
+    "jpeg": "image/jpeg",
+    "png": "image/png",
+    "gif": "image/gif",
+    "webp": "image/webp",
+}
+ALLOWED_MEDIA_TYPES = list(MEDIA_TYPES.keys())
 
-    Args:
-        image_path_url: Local path or URL to the image
+# Image sizes (per aspect ratio) that will not be resized by the API:
+# TODO is there any value in resizing pre-upload?
+# 1:1   1092x1092 px
+# 3:4   951x1268 px
+# 2:3   896x1344 px
+# 9:16  819x1456 px
+# 1:2   784x1568 px
 
-    Returns:
-        str: Description of the image contents
-    """
-    client = OpenAI()
 
-    if not image_path_url:
-        return "Image Path or URL is required."
+def _get_media_type(image_filename: str) -> Optional[str]:
+    """Get the media type from an image filename."""
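+    # Match on the file extension; unrecognized names fall through to None.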
+    for ext, media_type in MEDIA_TYPES.items():
+        if image_filename.lower().endswith(f".{ext}"):
+            return media_type
+    return None
+
 
-    if "http" in image_path_url:
-        return _analyze_web_image(client, image_path_url)
-    return _analyze_local_image(client, image_path_url)
+def _encode_image(image_handle: IO) -> str:
+    """Encode a file handle to base64."""
+    return base64.b64encode(image_handle.read()).decode("utf-8")
 
 
-def _analyze_web_image(client: OpenAI, image_path_url: str) -> str:
-    response = client.chat.completions.create(
-        model="gpt-4-vision-preview",
+def _make_anthropic_request(image_handle: IO, media_type: str) -> anthropic.types.Message:
+    """Make a request to the Anthropic API using an image."""
+    client = anthropic.Anthropic()
+    data = _encode_image(image_handle)
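+    # Image block first, then the text prompt (the order used in Anthropic's vision examples).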
+    return client.messages.create(
+        model=MODEL,
+        max_tokens=MAX_TOKENS,
         messages=[
             {
                 "role": "user",
                 "content": [
-                    {"type": "text", "text": "What's in this image?"},
-                    {"type": "image_url", "image_url": {"url": image_path_url}},
+                    {  # type: ignore
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": media_type,
+                            "data": data,
+                        },
+                    },
+                    {  # type: ignore
+                        "type": "text",
+                        "text": PROMPT,
+                    },
                 ],
             }
         ],
-        max_tokens=300,
     )
-    return response.choices[0].message.content  # type: ignore[return-value]
 
 
-def _analyze_local_image(client: OpenAI, image_path: str) -> str:
-    base64_image = _encode_image(image_path)
-    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {client.api_key}"}
-    payload = {
-        "model": "gpt-4-vision-preview",
-        "messages": [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": "What's in this image?"},
-                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
-                ],
-            }
-        ],
-        "max_tokens": 300,
-    }
-    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
-    return response.json()["choices"][0]["message"]["content"]
+def _analyze_web_image(image_url: str, media_type: str) -> str:
+    """Analyze an image from a URL."""
+    with tempfile.NamedTemporaryFile() as temp_file:
+        temp_file.write(requests.get(image_url).content)
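+        # Flush and rewind so the bytes just written can be read back for encoding.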
+        temp_file.flush()
+        temp_file.seek(0)
+        response = _make_anthropic_request(temp_file, media_type)
+        return response.content[0].text  # type: ignore
 
 
-def _encode_image(image_path: str) -> str:
+def _analyze_local_image(image_path: str, media_type: str) -> str:
+    """Analyze an image from a local file."""
     with open(image_path, "rb") as image_file:
-        return base64.b64encode(image_file.read()).decode("utf-8")
+        response = _make_anthropic_request(image_file, media_type)
+        return response.content[0].text  # type: ignore
+
+
+def analyze_image(image_path_or_url: str) -> str:
+    """
+    Analyze an image using the Anthropic API.
+
+    Args:
+        image_path_or_url: Local path or URL to the image.
+
+    Returns:
+        str: Description of the image contents.
+    """
+    if not image_path_or_url:
+        return "Image Path or URL is required."
+
+    media_type = _get_media_type(image_path_or_url)
+    if not media_type:
+        return f"Unsupported image type; use one of {ALLOWED_MEDIA_TYPES}."
+
+    if image_path_or_url.startswith(("http://", "https://")):
+        return _analyze_web_image(image_path_or_url, media_type)
+    return _analyze_local_image(image_path_or_url, media_type)
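
A minimal usage sketch, assuming the module is importable as vision_tool (hypothetical name) and ANTHROPIC_API_KEY is set in the environment; the file name and URL below are placeholders:

    from vision_tool import analyze_image  # hypothetical module name

    print(analyze_image("photo.jpg"))                    # local file -> _analyze_local_image
    print(analyze_image("https://example.com/cat.png"))  # URL -> _analyze_web_image
    print(analyze_image("notes.txt"))                    # unsupported extension -> error string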