Skip to content

Commit 57e02ae

Browse files
pgrayy authored and jsamuel1 committed
models - openai - b64encode method (strands-agents#260)
1 parent d8d119d commit 57e02ae

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

src/strands/types/models/openai.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@ class OpenAIModel(Model, abc.ABC):
3434

3535
config: dict[str, Any]
3636

37+
@staticmethod
def b64encode(data: bytes) -> bytes:
    """Base64 encode the provided data unless it already is valid base64.

    The input is probed with a strict base64 decode. If the probe succeeds, the data is treated as already
    encoded: it is returned unchanged and a warning is logged. If the probe fails, the data is encoded and
    returned. Any raw payload that happens to consist solely of valid base64 is therefore returned as-is.

    Note, this is a temporary method used to provide a warning to users who pass in base64 encoded data. In future
    versions, images and documents will be base64 encoded on behalf of customers for consistency with the other
    providers and general convenience.

    Args:
        data: Data to encode.

    Returns:
        Base64 encoded data.
    """
    # Empty input trivially passes strict base64 validation and also encodes to itself. Return it directly so
    # callers that never base64 encoded anything do not receive a misleading warning.
    if data == b"":
        return data

    try:
        # validate=True rejects any character outside the base64 alphabet instead of silently discarding it.
        base64.b64decode(data, validate=True)
    except ValueError:
        # binascii.Error is a ValueError subclass: the input is not base64 yet, so encode it here.
        return base64.b64encode(data)

    logger.warning(
        "issue=<%s> | base64 encoded images and documents will not be accepted in future versions",
        "https://github.com/strands-agents/sdk-python/issues/252",
    )
    return data
62+
3763
@classmethod
3864
def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
3965
"""Format an OpenAI compatible content block.
@@ -60,17 +86,8 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
6086

6187
if "image" in content:
6288
mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
63-
image_bytes = content["image"]["source"]["bytes"]
64-
try:
65-
base64.b64decode(image_bytes, validate=True)
66-
logger.warning(
67-
"issue=<%s> | base64 encoded images will not be accepted in a future version",
68-
"https://github.com/strands-agents/sdk-python/issues/252",
69-
)
70-
except ValueError:
71-
image_bytes = base64.b64encode(image_bytes)
72-
73-
image_data = image_bytes.decode("utf-8")
89+
image_data = OpenAIModel.b64encode(content["image"]["source"]["bytes"]).decode("utf-8")
90+
7491
return {
7592
"image_url": {
7693
"detail": "auto",

tests/strands/types/models/test_openai.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,15 @@ def test_format_chunk_unknown_type(model):
362362

363363
with pytest.raises(RuntimeError, match="chunk_type=<unknown> | unknown type"):
364364
model.format_chunk(event)
365+
366+
367+
@pytest.mark.parametrize(
    ("data", "exp_result"),
    [
        # Raw bytes are base64 encoded.
        (b"image", b"aW1hZ2U="),
        # Bytes that are already valid base64 are returned unchanged.
        (b"aW1hZ2U=", b"aW1hZ2U="),
    ],
)
def test_b64encode(data, exp_result):
    """Check that b64encode encodes raw bytes and passes through already-encoded bytes."""
    assert SAOpenAIModel.b64encode(data) == exp_result

0 commit comments

Comments
 (0)