From 7bf7634e97dfe56c3130264eeb62a9b5d6b55cac Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Thu, 29 Jun 2023 21:03:34 -0700 Subject: [PATCH] fix: LLM - Exported the `ChatMessage` class PiperOrigin-RevId: 544541329 --- tests/unit/aiplatform/test_language_models.py | 30 ++++++++++++------- vertexai/preview/language_models.py | 2 ++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index f7bf9a9df6..5c478a637c 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -756,6 +756,16 @@ def test_chat(self): output_text="Ned likes watching movies.", ), ], + message_history=[ + preview_language_models.ChatMessage( + author=preview_language_models.ChatSession.USER_AUTHOR, + content="Question 1?", + ), + preview_language_models.ChatMessage( + author=preview_language_models.ChatSession.MODEL_AUTHOR, + content="Answer 1.", + ), + ], temperature=0.0, ) @@ -773,11 +783,11 @@ def test_chat(self): ] response = chat.send_message(message_text1) assert response.text == expected_response1 - assert len(chat.message_history) == 2 - assert chat.message_history[0].author == chat.USER_AUTHOR - assert chat.message_history[0].content == message_text1 - assert chat.message_history[1].author == chat.MODEL_AUTHOR - assert chat.message_history[1].content == expected_response1 + assert len(chat.message_history) == 4 + assert chat.message_history[2].author == chat.USER_AUTHOR + assert chat.message_history[2].content == message_text1 + assert chat.message_history[3].author == chat.MODEL_AUTHOR + assert chat.message_history[3].content == expected_response1 gca_predict_response2 = gca_prediction_service.PredictResponse() gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2) @@ -793,11 +803,11 @@ def test_chat(self): ] response = chat.send_message(message_text2, temperature=0.1) assert response.text == expected_response2 - assert len(chat.message_history) == 4 - assert chat.message_history[2].author == chat.USER_AUTHOR - assert chat.message_history[2].content == message_text2 - assert chat.message_history[3].author == chat.MODEL_AUTHOR - assert chat.message_history[3].content == expected_response2 + assert len(chat.message_history) == 6 + assert chat.message_history[4].author == chat.USER_AUTHOR + assert chat.message_history[4].content == message_text2 + assert chat.message_history[5].author == chat.MODEL_AUTHOR + assert chat.message_history[5].content == expected_response2 # Validating the parameters chat_temperature = 0.1 diff --git a/vertexai/preview/language_models.py b/vertexai/preview/language_models.py index 43447a8e50..ae41214b10 100644 --- a/vertexai/preview/language_models.py +++ b/vertexai/preview/language_models.py @@ -17,6 +17,7 @@ from vertexai.language_models._language_models import ( _PreviewTextEmbeddingModel, _PreviewTextGenerationModel, + ChatMessage, ChatModel, ChatSession, CodeChatModel, @@ -31,6 +32,7 @@ TextEmbeddingModel = _PreviewTextEmbeddingModel __all__ = [ + "ChatMessage", "ChatModel", "ChatSession", "CodeChatModel",