|
11 | 11 | import static org.mockito.Mockito.when; |
12 | 12 | import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; |
13 | 13 | import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; |
| 14 | +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE; |
| 15 | +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED; |
14 | 16 | import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED; |
15 | 17 | import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MEMORY_ID; |
16 | 18 | import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; |
|
36 | 38 | import org.mockito.Mock; |
37 | 39 | import org.mockito.Mockito; |
38 | 40 | import org.mockito.MockitoAnnotations; |
| 41 | +import org.opensearch.OpenSearchException; |
39 | 42 | import org.opensearch.ResourceNotFoundException; |
40 | 43 | import org.opensearch.Version; |
41 | 44 | import org.opensearch.action.get.GetRequest; |
@@ -171,7 +174,8 @@ public void setup() { |
171 | 174 | when(client.threadPool()).thenReturn(threadPool); |
172 | 175 | when(threadPool.getThreadContext()).thenReturn(threadContext); |
173 | 176 | when(this.clusterService.getSettings()).thenReturn(settings); |
174 | | - when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED))); |
| 177 | + when(this.clusterService.getClusterSettings()) |
| 178 | + .thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED, ML_COMMONS_AGENTIC_SEARCH_ENABLED))); |
175 | 179 |
|
176 | 180 | settings = Settings.builder().build(); |
177 | 181 | mlAgentExecutor = Mockito |
@@ -774,6 +778,67 @@ public void test_AsyncMode_IndexTask_failure() throws IOException { |
774 | 778 | Assert.assertNotNull(exceptionCaptor.getValue()); |
775 | 779 | } |
776 | 780 |
|
| 781 | + @Test |
| 782 | + public void test_query_planning_requires_agentic_search_enabled() throws IOException { |
| 783 | + // Create an MLAgent with QueryPlanningTool |
| 784 | + MLAgent mlAgentWithQueryPlanning = new MLAgent( |
| 785 | + "test", |
| 786 | + MLAgentType.FLOW.name(), |
| 787 | + "test", |
| 788 | + new LLMSpec("test_model", Map.of("test_key", "test_value")), |
| 789 | + List |
| 790 | + .of( |
| 791 | + new MLToolSpec( |
| 792 | + "QueryPlanningTool", |
| 793 | + "QueryPlanningTool", |
| 794 | + "QueryPlanningTool", |
| 795 | + Collections.emptyMap(), |
| 796 | + Collections.emptyMap(), |
| 797 | + false, |
| 798 | + Collections.emptyMap(), |
| 799 | + null, |
| 800 | + null |
| 801 | + ) |
| 802 | + ), |
| 803 | + Map.of("test", "test"), |
| 804 | + new MLMemorySpec("memoryType", "123", 0), |
| 805 | + Instant.EPOCH, |
| 806 | + Instant.EPOCH, |
| 807 | + "test", |
| 808 | + false, |
| 809 | + null |
| 810 | + ); |
| 811 | + |
| 812 | + // Create GetResponse with the MLAgent that has QueryPlanningTool |
| 813 | + XContentBuilder content = mlAgentWithQueryPlanning.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); |
| 814 | + BytesReference bytesReference = BytesReference.bytes(content); |
| 815 | + GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null); |
| 816 | + GetResponse agentGetResponse = new GetResponse(getResult); |
| 817 | + |
| 818 | + // Create a new MLAgentExecutor with agentic search disabled |
| 819 | + MLAgentExecutor mlAgentExecutorWithDisabledSearch = Mockito |
| 820 | + .spy(new MLAgentExecutor(client, sdkClient, settings, clusterService, xContentRegistry, toolFactories, memoryMap, false, null)); |
| 821 | + |
| 822 | + // Mock the agent get response |
| 823 | + Mockito.doAnswer(invocation -> { |
| 824 | + ActionListener<GetResponse> listener = invocation.getArgument(1); |
| 825 | + listener.onResponse(agentGetResponse); |
| 826 | + return null; |
| 827 | + }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); |
| 828 | + |
| 829 | + // Mock the agent runner |
| 830 | + Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithDisabledSearch).getAgentRunner(Mockito.any()); |
| 831 | + |
| 832 | + // Execute the agent |
| 833 | + mlAgentExecutorWithDisabledSearch.execute(getAgentMLInput(), agentActionListener); |
| 834 | + |
| 835 | + // Verify that the execution fails with the correct error message |
| 836 | + Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); |
| 837 | + Exception exception = exceptionCaptor.getValue(); |
| 838 | + Assert.assertTrue(exception instanceof OpenSearchException); |
| 839 | + Assert.assertEquals(exception.getMessage(), ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE); |
| 840 | + } |
| 841 | + |
777 | 842 | private AgentMLInput getAgentMLInput() { |
778 | 843 | Map<String, String> params = new HashMap<>(); |
779 | 844 | params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); |
|
0 commit comments