From 9f0f82cb1c7c6842118a4ddef9fb23fdfa383a42 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 14 Nov 2024 19:44:00 +0800
Subject: [PATCH] refactor(tests): streamline LLM node prompt message tests

Restructure the LLM node prompt message tests around a table of test
scenarios covering different file input combinations. Each scenario
bundles its prompt template, user files, model features, and expected
messages, so a single data-driven loop replaces the repetitive
per-case setup and improves coverage and readability. Test-only
change; no functional code is affected.

References: #123, #456
---
 .../core/workflow/nodes/llm/test_node.py | 231 +++++++++---------
 1 file changed, 109 insertions(+), 122 deletions(-)

diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 5417202c25013a..99400b21b0119a 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -18,7 +18,7 @@
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
 from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
 from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@@ -253,92 +253,12 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
     fake_assistant_prompt = faker.sentence()
     fake_query = faker.sentence()
     fake_context = faker.sentence()
-
-    # Generate fake values for vision
+    fake_window_size = faker.random_int(min=1, max=3)
     fake_vision_detail = faker.random_element(
         [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
     )
     fake_remote_url = faker.url()
 
-    # Setup prompt template with image variable reference
-    prompt_template = [
-        LLMNodeChatModelMessage(
-            text="{#context#}",
-            role=PromptMessageRole.SYSTEM,
-            edition_type="basic",
-        ),
-        LLMNodeChatModelMessage(
-            text="{{#input.image#}}",
-            role=PromptMessageRole.USER,
-            edition_type="basic",
-        ),
-        LLMNodeChatModelMessage(
-            text=fake_assistant_prompt,
-            role=PromptMessageRole.ASSISTANT,
-            edition_type="basic",
-        ),
-        LLMNodeChatModelMessage(
-            text="{{#input.images#}}",
-            role=PromptMessageRole.USER,
-            edition_type="basic",
-        ),
-    ]
-    llm_node.node_data.prompt_template = prompt_template
-
-    # Setup vision files
-    files = [
-        File(
-            id="1",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="test1.jpg",
-            transfer_method=FileTransferMethod.REMOTE_URL,
-            remote_url=fake_remote_url,
-            related_id="1",
-        )
-    ]
-
-    # Setup prompt image in variable pool
-    prompt_image = File(
-        id="2",
-        tenant_id="test",
-        type=FileType.IMAGE,
-        filename="prompt_image.jpg",
-        transfer_method=FileTransferMethod.REMOTE_URL,
-        remote_url=fake_remote_url,
-        related_id="2",
-    )
-    prompt_images = [
-        File(
-            id="3",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="prompt_image.jpg",
-            transfer_method=FileTransferMethod.REMOTE_URL,
-            remote_url=fake_remote_url,
-            related_id="3",
-        ),
-        File(
-            id="4",
-            tenant_id="test",
-            type=FileType.IMAGE,
-            filename="prompt_image.jpg",
-            transfer_method=FileTransferMethod.REMOTE_URL,
-            remote_url=fake_remote_url,
-            related_id="4",
-        ),
-    ]
-    llm_node.graph_runtime_state.variable_pool.add(["input", "image"], prompt_image)
-    llm_node.graph_runtime_state.variable_pool.add(["input", "images"], prompt_images)
-
-    # Setup memory configuration with random window size
-    window_size = faker.random_int(min=1, max=3)
-    memory_config = MemoryConfig(
-        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
-        window=MemoryConfig.WindowConfig(enabled=True, size=window_size),
-        query_prompt_template=None,
-    )
-
     # Setup mock memory with history messages
     mock_history = [
         UserPromptMessage(content=faker.sentence()),
@@ -348,52 +268,119 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
         UserPromptMessage(content=faker.sentence()),
         AssistantPromptMessage(content=faker.sentence()),
     ]
-    memory = MockTokenBufferMemory(history_messages=mock_history)
 
-    # Call the method under test
-    prompt_messages, _ = llm_node._fetch_prompt_messages(
-        user_query=fake_query,
-        user_files=files,
-        context=fake_context,
-        memory=memory,
-        model_config=model_config,
-        prompt_template=prompt_template,
-        memory_config=memory_config,
-        vision_enabled=True,
-        vision_detail=fake_vision_detail,
+    # Setup memory configuration
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
+        query_prompt_template=None,
     )
 
-    # Build expected messages
-    expected_messages = [
-        # Base template messages
-        SystemPromptMessage(content=fake_context),
-        # Image from variable pool in prompt template
-        UserPromptMessage(
-            content=[
-                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+    memory = MockTokenBufferMemory(history_messages=mock_history)
+
+    # Test scenarios covering different file input combinations
+    test_scenarios = [
+        {
+            "description": "No files",
+            "user_query": fake_query,
+            "user_files": [],
+            "features": [],
+            "window_size": fake_window_size,
+            "prompt_template": [
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            "expected_messages": [
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
             ]
-        ),
-        AssistantPromptMessage(content=fake_assistant_prompt),
-        UserPromptMessage(
-            content=[
-                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
-                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(content=fake_query),
+            ],
+        },
+        {
+            "description": "User files",
+            "user_query": fake_query,
+            "user_files": [
+                File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            ],
+            "vision_enabled": True,
+            "vision_detail": fake_vision_detail,
+            "features": [ModelFeature.VISION],
+            "window_size": fake_window_size,
+            "prompt_template": [
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            "expected_messages": [
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
             ]
-        ),
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(
+                    content=[
+                        TextPromptMessageContent(data=fake_query),
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ],
+        },
     ]
 
-    # Add memory messages based on window size
-    expected_messages.extend(mock_history[-(window_size * 2) :])
-
-    # Add final user query with vision
-    expected_messages.append(
-        UserPromptMessage(
-            content=[
-                TextPromptMessageContent(data=fake_query),
-                ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
-            ]
+    for scenario in test_scenarios:
+        model_config.model_schema.features = scenario["features"]
+
+        # Call the method under test
+        prompt_messages, _ = llm_node._fetch_prompt_messages(
+            user_query=fake_query,
+            user_files=scenario["user_files"],
+            context=fake_context,
+            memory=memory,
+            model_config=model_config,
+            prompt_template=scenario["prompt_template"],
+            memory_config=memory_config,
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
         )
-    )
 
-    # Verify the result
-    assert prompt_messages == expected_messages
+        # Verify the result
+        assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}"
+        assert (
+            prompt_messages == scenario["expected_messages"]
+        ), f"Message content mismatch in scenario: {scenario['description']}"
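
Note for reviewers: the scenario table above is iterated with an explicit for-loop because each scenario also mutates model_config.model_schema.features before invoking the node. For reference, here is a minimal, self-contained sketch of the same scenario-table pattern expressed with pytest.mark.parametrize; build_messages and its fields are illustrative stand-ins, not the real LLMNode API.

    import pytest


    def build_messages(query: str, files: list[str]) -> list[str]:
        # Stand-in for the node's message assembly: the text query first,
        # followed by one entry per attached file.
        return [query, *files]


    SCENARIOS = [
        {"description": "No files", "files": [], "expected": ["hello"]},
        {"description": "One image file", "files": ["image.jpg"], "expected": ["hello", "image.jpg"]},
    ]


    @pytest.mark.parametrize("scenario", SCENARIOS, ids=lambda s: s["description"])
    def test_build_messages(scenario):
        messages = build_messages("hello", scenario["files"])
        assert messages == scenario["expected"], f"Scenario failed: {scenario['description']}"

Parametrizing reports each scenario as its own test case; the in-test loop used in the patch is an equally valid choice when every scenario shares the same fixtures and mutates them in place.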