Skip to content

Commit

Permalink
feat: Make count_tokens generally-available at TextEmbeddingModel.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654133506
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jul 19, 2024
1 parent e5d087f commit efb8413
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
43 changes: 42 additions & 1 deletion tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4526,7 +4526,48 @@ def test_text_embedding(self):
== expected_embedding["statistics"]["truncated"]
)

def test_text_embedding_preview_count_tokens(self):
def test_text_embedding_count_tokens_ga(self):
"""Tests the text embedding model."""
aiplatform.init(
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
),
):
model = language_models.TextEmbeddingModel.from_pretrained(
"textembedding-gecko@001"
)

gca_count_tokens_response = (
gca_prediction_service_v1beta1.CountTokensResponse(
total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
"total_billable_characters"
],
)
)

with mock.patch.object(
target=prediction_service_client_v1beta1.PredictionServiceClient,
attribute="count_tokens",
return_value=gca_count_tokens_response,
):
response = model.count_tokens(["What is life?"])

assert (
response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
)
assert (
response.total_billable_characters
== _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
)

def test_text_embedding_count_tokens_preview(self):
"""Tests the text embedding model."""
aiplatform.init(
project=_TEST_PROJECT,
Expand Down
6 changes: 5 additions & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,7 +2417,11 @@ class _TunableTextEmbeddingModelMixin(_PreviewTunableTextEmbeddingModelMixin):
pass


class TextEmbeddingModel(_TextEmbeddingModel, _TunableTextEmbeddingModelMixin):
class TextEmbeddingModel(
_TextEmbeddingModel,
_TunableTextEmbeddingModelMixin,
_CountTokensMixin,
):
__module__ = "vertexai.language_models"


Expand Down

0 comments on commit efb8413

Please sign in to comment.