@@ -4,6 +4,7 @@

import litellm
import pytest
+import tiktoken
from litellm.caching import Cache, InMemoryCache
from pytest_subtests import SubTests

@@ -15,7 +16,33 @@
    SentenceTransformerEmbeddingModel,
    SparseEmbeddingModel,
    embedding_model_factory,
+    estimate_tokens,
)
-from lmi.utils import VCR_DEFAULT_MATCH_ON
+from lmi.utils import VCR_DEFAULT_MATCH_ON, encode_image_as_url
+
+
+def test_estimate_tokens(subtests: SubTests, png_image: bytes) -> None:
+    with subtests.test(msg="text only"):
+        text_only = "Hello world"
+        text_only_estimated_token_count = estimate_tokens(text_only)
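+        # Assumption: estimate_tokens uses a ~4 chars/token heuristic,
+        # so "Hello world" (11 chars) works out to 11 / 4 = 2.75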
+        assert text_only_estimated_token_count == 2.75, (
+            "Expected a reasonable token estimate"
+        )
+        text_only_actual_token_count = len(
+            tiktoken.get_encoding("cl100k_base").encode(text_only)
+        )
+        assert text_only_estimated_token_count == pytest.approx(
+            text_only_actual_token_count, abs=1
+        ), "Estimate should be within one token of tiktoken's count"
+
+    with subtests.test(msg="multimodal"):  # Text + image
+        multimodal = [
+            "What is in this image?",
+            encode_image_as_url(image_type="png", image_data=png_image),
+        ]
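+        # Assumption: each image adds a flat 85-token cost (the common
+        # low-detail image price), so 85 + 22 chars / 4 = 85 + 5.5 = 90.5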
+        assert estimate_tokens(multimodal) == 90.5, (
+            "Expected a reasonable token estimate"
+        )


class TestLiteLLMEmbeddingModel:
@@ -231,6 +258,53 @@ async def test_router_usage(
        # Confirm use of the sentinel timeout in the Router's model_list or pass through
        assert mock_aembedding.call_args.kwargs["timeout"] == self.SENTINEL_TIMEOUT

+    @pytest.mark.asyncio
+    async def test_multimodal_embedding(
+        self, subtests: SubTests, png_image_gcs: str
+    ) -> None:
+        multimodal_model = LiteLLMEmbeddingModel(
+            name=f"{litellm.LlmProviders.VERTEX_AI.value}/multimodalembedding@001"
+        )
+
+        with subtests.test(msg="text or image only"):
+            embedding_text_only = await multimodal_model.embed_document("Some text")
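+            # Vertex AI's multimodalembedding@001 returns 1408-dimensional vectors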
+            assert len(embedding_text_only) == 1408
+            assert all(isinstance(x, float) for x in embedding_text_only)
+
+            embedding_image_only = await multimodal_model.embed_document(png_image_gcs)
+            assert len(embedding_image_only) == 1408
+            assert all(isinstance(x, float) for x in embedding_image_only)
+
+            assert embedding_image_only != embedding_text_only
+
+        with subtests.test(msg="text and image mixing"):
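+            # Note the single-element unpack: a text followed by an image is
+            # fused into one combined multimodal embedding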
+            (embedding_image_text,) = await multimodal_model.embed_documents([
+                "What is in this image?",
+                png_image_gcs,
+            ])
+            assert len(embedding_image_text) == 1408
+            assert all(isinstance(x, float) for x in embedding_image_text)
+
+            (embedding_two_images,) = await multimodal_model.embed_documents([
+                png_image_gcs,
+                png_image_gcs,
+            ])
+            assert len(embedding_two_images) == 1408
+            assert all(isinstance(x, float) for x in embedding_two_images)
+
+            assert embedding_image_text != embedding_two_images
+
+        with subtests.test(msg="batching"):
+            multimodal_model.config["batch_size"] = 1
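+            # With batch_size=1 each document is sent in its own request, so the
+            # text and the image are no longer fused into a single embedding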
+            embeddings = await multimodal_model.embed_documents([
+                "Some text",
+                png_image_gcs,
+            ])
+            assert len(embeddings) == 2
+            for embedding in embeddings:
+                assert len(embedding) == 1408
+                assert all(isinstance(x, float) for x in embedding)
+

@pytest.mark.asyncio
async def test_sparse_embedding_model(subtests: SubTests):