
Commit 1fc5a0a

Added support for multimodal embeddings, with tests

1 parent 33a389c

File tree

2 files changed: +111 −11

packages/lmi/src/lmi/embeddings.py

Lines changed: 36 additions & 10 deletions
@@ -13,7 +13,37 @@
 from lmi.cost_tracker import track_costs
 from lmi.llms import PassThroughRouter
 from lmi.rate_limiter import GLOBAL_LIMITER
-from lmi.utils import get_litellm_retrying_config
+from lmi.utils import get_litellm_retrying_config, is_encoded_image
+
+URL_ENCODED_IMAGE_TOKEN_ESTIMATE = 85  # tokens
+
+
+def estimate_tokens(
+    document: str
+    | list[str]
+    | list[litellm.ChatCompletionImageObject]
+    | list[litellm.types.llms.vertex_ai.PartType],
+) -> float:
+    """Estimate token count for rate limiting purposes."""
+    if isinstance(document, str):  # Text or a data URL
+        return (
+            URL_ENCODED_IMAGE_TOKEN_ESTIMATE
+            if is_encoded_image(document)
+            else len(document) / CHARACTERS_PER_TOKEN_ASSUMPTION
+        )
+    # For multimodal content, estimate based on text parts and add a fixed cost per image
+    token_count = 0.0
+    for part in document:
+        if isinstance(part, str):  # Part of a batch of text or data URLs
+            token_count += estimate_tokens(part)
+        # Handle different multimodal formats
+        elif part.get("type") == "image_url":  # OpenAI format
+            token_count += URL_ENCODED_IMAGE_TOKEN_ESTIMATE
+        elif (  # Gemini text format -- https://ai.google.dev/api#text-only-prompt
+            "text" in part
+        ):
+            token_count += len(part["text"]) / CHARACTERS_PER_TOKEN_ASSUMPTION  # type: ignore[typeddict-item]
+    return token_count
 
 
 class EmbeddingModes(StrEnum):
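
The arithmetic in estimate_tokens is easy to check by hand. Below is a minimal standalone sketch of the same heuristic, assuming CHARACTERS_PER_TOKEN_ASSUMPTION is 4 (consistent with the 2.75-token estimate the tests expect for the 11-character "Hello world") and assuming is_encoded_image simply detects a data URL prefix; the real helper lives in lmi.utils and may differ.

    # Standalone sketch; the constant value and the data-URL check are assumptions.
    CHARACTERS_PER_TOKEN_ASSUMPTION = 4.0  # assumed, inferred from the tests
    URL_ENCODED_IMAGE_TOKEN_ESTIMATE = 85  # flat per-image cost, as in the diff

    def is_encoded_image(text: str) -> bool:
        # Hypothetical stand-in for lmi.utils.is_encoded_image
        return text.startswith("data:image/")

    def estimate_tokens(document) -> float:
        if isinstance(document, str):
            if is_encoded_image(document):
                return URL_ENCODED_IMAGE_TOKEN_ESTIMATE
            return len(document) / CHARACTERS_PER_TOKEN_ASSUMPTION
        return sum(estimate_tokens(part) for part in document)  # str parts only

    assert estimate_tokens("Hello world") == 2.75  # 11 chars / 4
    assert estimate_tokens([
        "What is in this image?",              # 22 chars -> 5.5 tokens
        "data:image/png;base64,iVBORw0KGgo=",  # flat 85 tokens
    ]) == 90.5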
@@ -39,7 +69,7 @@ def set_mode(self, mode: EmbeddingModes) -> None:
 
     @abstractmethod
     async def embed_documents(self, texts: list[str]) -> list[list[float]]:
-        pass
+        """Embed a list of documents."""
 
     async def embed_document(self, text: str) -> list[float]:
         return (await self.embed_documents([text]))[0]
@@ -138,7 +168,7 @@ def _truncate_if_large(self, texts: list[str]) -> list[str]:
         # heuristic about ratio of tokens to characters
         conservative_char_token_ratio = 3
         maybe_too_large = max_tokens * conservative_char_token_ratio
-        if any(len(t) > maybe_too_large for t in texts):
+        if any(len(t) > maybe_too_large for t in texts if not is_encoded_image(t)):
             try:
                 enct = tiktoken.encoding_for_model("cl100k_base")
                 enc_batch = enct.encode_ordinary_batch(texts)
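
The new is_encoded_image guard matters here because a base64 data URL is enormous in characters yet cheap in estimated tokens, so the character-length heuristic would otherwise flag every image for truncation. A rough illustration, with sizes assumed:

    data_url = "data:image/png;base64," + "A" * 200_000  # ~150 KB image, assumed size
    print(len(data_url) / 4)                 # ~50,005 "tokens" by the character heuristic
    print(URL_ENCODED_IMAGE_TOKEN_ESTIMATE)  # vs. a flat 85 for rate limiting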
@@ -154,16 +184,12 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]:
         N = len(texts)
         embeddings = []
         for i in range(0, N, batch_size):
-            await self.check_rate_limit(
-                sum(
-                    len(t) / CHARACTERS_PER_TOKEN_ASSUMPTION
-                    for t in texts[i : i + batch_size]
-                )
-            )
+            batch = texts[i : i + batch_size]
+            await self.check_rate_limit(sum(estimate_tokens(t) for t in batch))
 
             response = await track_costs(self.router.aembedding)(
                 model=self.name,
-                input=texts[i : i + batch_size],
+                input=batch,
                 dimensions=self.ndim,
                 **self.config.get("kwargs", {}),
             )
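
With the loop refactored, each batch's rate-limit charge is a plain sum over estimate_tokens. A small worked example under the same assumptions as the sketch above (batch contents are illustrative):

    batch = ["Some text", "data:image/png;base64,iVBORw0KGgo="]
    # "Some text" is 9 characters -> 2.25 tokens; the data URL counts as a flat 85
    assert sum(estimate_tokens(t) for t in batch) == 87.25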

packages/lmi/tests/test_embeddings.py

Lines changed: 75 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import litellm
 import pytest
+import tiktoken
 from litellm.caching import Cache, InMemoryCache
 from pytest_subtests import SubTests
 

@@ -15,8 +16,34 @@
     SentenceTransformerEmbeddingModel,
     SparseEmbeddingModel,
     embedding_model_factory,
+    estimate_tokens,
 )
-from lmi.utils import VCR_DEFAULT_MATCH_ON
+from lmi.utils import VCR_DEFAULT_MATCH_ON, encode_image_as_url
+
+
+def test_estimate_tokens(subtests: SubTests, png_image: bytes) -> None:
+    with subtests.test(msg="text only"):
+        text_only = "Hello world"
+        text_only_estimated_token_count = estimate_tokens(text_only)
+        assert text_only_estimated_token_count == 2.75, (
+            "Expected a reasonable token estimate"
+        )
+        text_only_actual_token_count = len(
+            tiktoken.get_encoding("cl100k_base").encode(text_only)
+        )
+        assert text_only_estimated_token_count == pytest.approx(
+            text_only_actual_token_count, abs=1
+        ), "Estimation should be within one token of tiktoken's count"
+
+    with subtests.test(msg="multimodal"):  # Text + image
+        multimodal = [
+            "What is in this image?",
+            encode_image_as_url(image_type="png", image_data=png_image),
+        ]
+        assert estimate_tokens(multimodal) == 90.5, (
+            "Expected a reasonable token estimate"
+        )
 
 
 class TestLiteLLMEmbeddingModel:
@@ -231,6 +258,53 @@ async def test_router_usage(
         # Confirm use of the sentinel timeout in the Router's model_list or pass through
         assert mock_aembedding.call_args.kwargs["timeout"] == self.SENTINEL_TIMEOUT
 
+    @pytest.mark.asyncio
+    async def test_multimodal_embedding(
+        self, subtests: SubTests, png_image_gcs: str
+    ) -> None:
+        multimodal_model = LiteLLMEmbeddingModel(
+            name=f"{litellm.LlmProviders.VERTEX_AI.value}/multimodalembedding@001"
+        )
+
+        with subtests.test(msg="text or image only"):
+            embedding_text_only = await multimodal_model.embed_document("Some text")
+            assert len(embedding_text_only) == 1408
+            assert all(isinstance(x, float) for x in embedding_text_only)
+
+            embedding_image_only = await multimodal_model.embed_document(png_image_gcs)
+            assert len(embedding_image_only) == 1408
+            assert all(isinstance(x, float) for x in embedding_image_only)
+
+            assert embedding_image_only != embedding_text_only
+
+        with subtests.test(msg="text and image mixing"):
+            (embedding_image_text,) = await multimodal_model.embed_documents([
+                "What is in this image?",
+                png_image_gcs,
+            ])
+            assert len(embedding_image_text) == 1408
+            assert all(isinstance(x, float) for x in embedding_image_text)
+
+            (embedding_two_images,) = await multimodal_model.embed_documents([
+                png_image_gcs,
+                png_image_gcs,
+            ])
+            assert len(embedding_two_images) == 1408
+            assert all(isinstance(x, float) for x in embedding_two_images)
+
+            assert embedding_image_text != embedding_two_images
+
+        with subtests.test(msg="batching"):
+            multimodal_model.config["batch_size"] = 1
+            embeddings = await multimodal_model.embed_documents([
+                "Some text",
+                png_image_gcs,
+            ])
+            assert len(embeddings) == 2
+            for embedding in embeddings:
+                assert len(embedding) == 1408
+                assert all(isinstance(x, float) for x in embedding)
+
 
 @pytest.mark.asyncio
 async def test_sparse_embedding_model(subtests: SubTests):
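
The tests also rely on encode_image_as_url from lmi.utils, which this diff does not show. A plausible minimal sketch of its shape, offered only as an assumption (the actual implementation may differ):

    import base64

    def encode_image_as_url(image_type: str, image_data: bytes) -> str:
        # Hypothetical sketch of lmi.utils.encode_image_as_url: pack raw bytes
        # into a base64 data URL, which estimate_tokens then counts as one image.
        encoded = base64.b64encode(image_data).decode()
        return f"data:image/{image_type};base64,{encoded}"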
