Skip to content

Commit a4c38e0

Browse files
authored
Merge pull request #296 from tcdent/vision-tool
Update vision tool to use Anthropic
2 parents ca495c9 + da7c83c commit a4c38e0

File tree

4 files changed

+147
-47
lines changed

4 files changed

+147
-47
lines changed
Lines changed: 86 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,111 @@
1-
"""Vision tool for analyzing images using OpenAI's Vision API."""
2-
1+
from typing import IO, Optional
2+
import os
3+
from pathlib import Path
34
import base64
4-
from typing import Optional
5+
import tempfile
56
import requests
6-
from openai import OpenAI
7+
import anthropic
78

89
__all__ = ["analyze_image"]
910

11+
PROMPT = os.getenv('VISION_PROMPT', "What's in this image?")
12+
MODEL = os.getenv('VISION_MODEL', "claude-3-5-sonnet-20241022")
13+
MAX_TOKENS: int = int(os.getenv('VISION_MAX_TOKENS', 1024))
1014

11-
def analyze_image(image_path_url: str) -> str:
12-
"""
13-
Analyze an image using OpenAI's Vision API.
15+
MEDIA_TYPES = {
16+
"jpg": "image/jpeg",
17+
"jpeg": "image/jpeg",
18+
"png": "image/png",
19+
"gif": "image/gif",
20+
"webp": "image/webp",
21+
}
22+
ALLOWED_MEDIA_TYPES = list(MEDIA_TYPES.keys())
1423

15-
Args:
16-
image_path_url: Local path or URL to the image
24+
# image sizes that will not be resized
25+
# TODO is there any value in resizing pre-upload?
26+
# 1:1 1092x1092 px
27+
# 3:4 951x1268 px
28+
# 2:3 896x1344 px
29+
# 9:16 819x1456 px
30+
# 1:2 784x1568 px
1731

18-
Returns:
19-
str: Description of the image contents
20-
"""
21-
client = OpenAI()
2232

23-
if not image_path_url:
24-
return "Image Path or URL is required."
33+
def _get_media_type(image_filename: str) -> Optional[str]:
34+
"""Get the media type from an image filename."""
35+
for ext, media_type in MEDIA_TYPES.items():
36+
if image_filename.endswith(ext):
37+
return media_type
38+
return None
39+
2540

26-
if "http" in image_path_url:
27-
return _analyze_web_image(client, image_path_url)
28-
return _analyze_local_image(client, image_path_url)
41+
def _encode_image(image_handle: IO) -> str:
42+
"""Encode a file handle to base64."""
43+
return base64.b64encode(image_handle.read()).decode("utf-8")
2944

3045

31-
def _analyze_web_image(client: OpenAI, image_path_url: str) -> str:
32-
response = client.chat.completions.create(
33-
model="gpt-4-vision-preview",
46+
def _make_anthropic_request(image_handle: IO, media_type: str) -> anthropic.types.Message:
    """Send a single-image vision prompt to the Anthropic Messages API.

    The image is read from *image_handle*, base64-encoded, and submitted
    together with the configured ``PROMPT`` as one user message.
    """
    encoded_data = _encode_image(image_handle)
    image_block = {  # type: ignore
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            "data": encoded_data,
        },
    }
    text_block = {  # type: ignore
        "type": "text",
        "text": PROMPT,
    }
    client = anthropic.Anthropic()
    return client.messages.create(
        model=MODEL,
        max_tokens=MAX_TOKENS,
        messages=[{"role": "user", "content": [image_block, text_block]}],
    )
45-
return response.choices[0].message.content # type: ignore[return-value]
4673

4774

48-
def _analyze_local_image(client: OpenAI, image_path: str) -> str:
49-
base64_image = _encode_image(image_path)
50-
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {client.api_key}"}
51-
payload = {
52-
"model": "gpt-4-vision-preview",
53-
"messages": [
54-
{
55-
"role": "user",
56-
"content": [
57-
{"type": "text", "text": "What's in this image?"},
58-
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
59-
],
60-
}
61-
],
62-
"max_tokens": 300,
63-
}
64-
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
65-
return response.json()["choices"][0]["message"]["content"]
75+
def _analyze_web_image(image_url: str, media_type: str) -> str:
    """Analyze an image fetched from a URL.

    Args:
        image_url: HTTP(S) URL of the image.
        media_type: Media type previously resolved for the URL.

    Returns:
        str: Description of the image contents.

    Raises:
        requests.HTTPError: If the download fails with a 4xx/5xx status.
    """
    # Bounded timeout: the previous call had none and could hang forever.
    response = requests.get(image_url, timeout=30)
    # Fail fast on an error status rather than base64-encoding an HTML
    # error page and sending it to the API as "image" data.
    response.raise_for_status()
    # Buffer the download in a temp file so the request helper can work on
    # a file handle, the same as the local-file code path.
    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(response.content)
        temp_file.flush()
        temp_file.seek(0)
        message = _make_anthropic_request(temp_file, media_type)
    return message.content[0].text  # type: ignore
6683

6784

68-
def _encode_image(image_path: str) -> str:
85+
def _analyze_local_image(image_path: str, media_type: str) -> str:
    """Analyze an image stored on the local filesystem.

    Opens *image_path* in binary mode, submits it to the Anthropic API, and
    returns the text of the first content block of the response.
    """
    with open(image_path, "rb") as file_handle:
        api_response = _make_anthropic_request(file_handle, media_type)
        first_block = api_response.content[0]
    return first_block.text  # type: ignore
90+
91+
92+
def analyze_image(image_path_or_url: str) -> str:
    """
    Analyze an image using Anthropic's vision-capable Messages API.

    Args:
        image_path_or_url: Local path or URL to the image.

    Returns:
        str: Description of the image contents, or a human-readable error
            message when the input is missing or has an unsupported type.
    """
    if not image_path_or_url:
        return "Image Path or URL is required."

    media_type = _get_media_type(image_path_or_url)
    if not media_type:
        return f"Unsupported image type use {ALLOWED_MEDIA_TYPES}."

    # startswith() instead of the previous `"http" in ...` check, which
    # misclassified local paths that merely contain "http"
    # (e.g. "/tmp/http_snapshot.png") as remote URLs.
    if image_path_or_url.startswith(("http://", "https://")):
        return _analyze_web_image(image_path_or_url, media_type)
    return _analyze_local_image(image_path_or_url, media_type)

agentstack/_tools/vision/config.json

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
"name": "vision",
33
"category": "image-analysis",
44
"env": {
5-
"OPENAI_API_KEY": null
5+
"ANTHROPIC_API_KEY": null,
6+
"VISION_PROMPT": null,
7+
"VISION_MODEL": null,
8+
"VISION_MAX_TOKENS": null
69
},
710
"dependencies": [
8-
"openai>=1.0.0",
11+
"anthropic>=0.45.2",
912
"requests>=2.31.0"
1013
],
1114
"tools": ["analyze_image"]

tests/fixtures/test_image.jpg

35.5 KB
Loading

tests/tools/test_tool_vision.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
from pathlib import Path
3+
import unittest
4+
from agentstack._tools import ToolConfig
5+
6+
7+
TEST_IMAGE_PATH: Path = Path(__file__).parent.parent / 'fixtures/test_image.jpg'


class VisionToolTest(unittest.TestCase):
    """Tests for the Anthropic-backed vision tool."""

    def setUp(self):
        # Local imports: only needed for dependency installation here.
        import subprocess
        import sys

        tool = ToolConfig.from_tool_name('vision')
        for dependency in tool.dependencies:
            # Invoke pip via the current interpreter with an argument list.
            # The previous os.system(f"pip install {dependency}") let the
            # shell parse specs like "anthropic>=0.45.2", where ">" became
            # an output redirection and the version pin was lost.
            subprocess.run([sys.executable, "-m", "pip", "install", dependency], check=False)

        try:
            from agentstack._tools import vision
        except ImportError as e:
            self.skipTest(str(e))

    def test_get_media_type(self):
        from agentstack._tools.vision import _get_media_type

        self.assertEqual(_get_media_type("image.jpg"), "image/jpeg")
        self.assertEqual(_get_media_type("image.jpeg"), "image/jpeg")
        self.assertEqual(_get_media_type("http://google.com/image.png"), "image/png")
        self.assertEqual(_get_media_type("/foo/bar/image.gif"), "image/gif")
        self.assertEqual(_get_media_type("image.webp"), "image/webp")
        self.assertEqual(_get_media_type("document.pdf"), None)

    def test_encode_image(self):
        from agentstack._tools.vision import _encode_image

        with open(TEST_IMAGE_PATH, "rb") as image_file:
            encoded_image = _encode_image(image_file)
        # Leftover debug print removed; the assertion is the check.
        self.assertTrue(isinstance(encoded_image, str))

    def test_analyze_image_web_live(self):
        from agentstack._tools.vision import analyze_image

        if not os.environ.get('ANTHROPIC_API_KEY'):
            self.skipTest("ANTHROPIC_API_KEY not set")

        image_url = "https://github.com/AgentOps-AI/AgentStack/blob/7c1bf897742cfb58f4942a2547be70a0a1bb767a/tests/fixtures/test_image.jpg?raw=true"
        result = analyze_image(image_url)
        self.assertTrue(isinstance(result, str))

    def test_analyze_image_local_live(self):
        from agentstack._tools.vision import analyze_image

        if not os.environ.get('ANTHROPIC_API_KEY'):
            self.skipTest("ANTHROPIC_API_KEY not set")

        result = analyze_image(str(TEST_IMAGE_PATH))
        self.assertTrue(isinstance(result, str))

0 commit comments

Comments
 (0)