Skip to content

Commit

Permalink
fix: LLM - Exported the ChatMessage class
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544541329
  • Loading branch information
Ark-kun authored and copybara-github committed Jun 30, 2023
1 parent 459ba86 commit 7bf7634
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
30 changes: 20 additions & 10 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,16 @@ def test_chat(self):
output_text="Ned likes watching movies.",
),
],
message_history=[
preview_language_models.ChatMessage(
author=preview_language_models.ChatSession.USER_AUTHOR,
content="Question 1?",
),
preview_language_models.ChatMessage(
author=preview_language_models.ChatSession.MODEL_AUTHOR,
content="Answer 1.",
),
],
temperature=0.0,
)

Expand All @@ -773,11 +783,11 @@ def test_chat(self):
]
response = chat.send_message(message_text1)
assert response.text == expected_response1
assert len(chat.message_history) == 2
assert chat.message_history[0].author == chat.USER_AUTHOR
assert chat.message_history[0].content == message_text1
assert chat.message_history[1].author == chat.MODEL_AUTHOR
assert chat.message_history[1].content == expected_response1
assert len(chat.message_history) == 4
assert chat.message_history[2].author == chat.USER_AUTHOR
assert chat.message_history[2].content == message_text1
assert chat.message_history[3].author == chat.MODEL_AUTHOR
assert chat.message_history[3].content == expected_response1

gca_predict_response2 = gca_prediction_service.PredictResponse()
gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
Expand All @@ -793,11 +803,11 @@ def test_chat(self):
]
response = chat.send_message(message_text2, temperature=0.1)
assert response.text == expected_response2
assert len(chat.message_history) == 4
assert chat.message_history[2].author == chat.USER_AUTHOR
assert chat.message_history[2].content == message_text2
assert chat.message_history[3].author == chat.MODEL_AUTHOR
assert chat.message_history[3].content == expected_response2
assert len(chat.message_history) == 6
assert chat.message_history[4].author == chat.USER_AUTHOR
assert chat.message_history[4].content == message_text2
assert chat.message_history[5].author == chat.MODEL_AUTHOR
assert chat.message_history[5].content == expected_response2

# Validating the parameters
chat_temperature = 0.1
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from vertexai.language_models._language_models import (
_PreviewTextEmbeddingModel,
_PreviewTextGenerationModel,
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
Expand All @@ -31,6 +32,7 @@
TextEmbeddingModel = _PreviewTextEmbeddingModel

__all__ = [
"ChatMessage",
"ChatModel",
"ChatSession",
"CodeChatModel",
Expand Down

0 comments on commit 7bf7634

Please sign in to comment.