|
| 1 | +/* |
| 2 | + * Copyright OpenSearch Contributors |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | + |
| 6 | +package org.opensearch.ml.rest; |
| 7 | + |
| 8 | +import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD; |
| 9 | + |
| 10 | +import java.io.IOException; |
| 11 | +import java.util.List; |
| 12 | +import java.util.Map; |
| 13 | + |
| 14 | +import org.apache.hc.core5.http.HttpHeaders; |
| 15 | +import org.apache.hc.core5.http.io.entity.StringEntity; |
| 16 | +import org.apache.hc.core5.http.message.BasicHeader; |
| 17 | +import org.junit.After; |
| 18 | +import org.junit.Before; |
| 19 | +import org.junit.Test; |
| 20 | +import org.opensearch.client.Response; |
| 21 | +import org.opensearch.ml.common.agent.MLAgent; |
| 22 | +import org.opensearch.ml.common.agent.MLToolSpec; |
| 23 | +import org.opensearch.ml.utils.TestHelper; |
| 24 | + |
| 25 | +public class RestQueryPlanningToolIT extends MLCommonsRestTestCase { |
| 26 | + |
| 27 | + private static final String IRIS_INDEX = "iris_data"; |
| 28 | + private String queryPlanningModelId; |
| 29 | + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); |
| 30 | + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); |
| 31 | + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); |
| 32 | + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; |
| 33 | + |
| 34 | + private final String bedrockClaudeModelConnectorEntity = "{\n" |
| 35 | + + " \"name\": \"Amazon Bedrock Claude 3.7-sonnet connector\",\n" |
| 36 | + + " \"description\": \"connector for base agent with tools\",\n" |
| 37 | + + " \"version\": 1,\n" |
| 38 | + + " \"protocol\": \"aws_sigv4\",\n" |
| 39 | + + " \"parameters\": {\n" |
| 40 | + + " \"region\": \"" |
| 41 | + + GITHUB_CI_AWS_REGION |
| 42 | + + "\",\n" |
| 43 | + + " \"service_name\": \"bedrock\",\n" |
| 44 | + + " \"model\": \"us.anthropic.claude-3-7-sonnet-20250219-v1:0\",\n" |
| 45 | + + " \"system_prompt\":\"please help answer the user question. \"\n" |
| 46 | + + " },\n" |
| 47 | + + " \"credential\": {\n" |
| 48 | + + " \"access_key\":\" " |
| 49 | + + AWS_ACCESS_KEY_ID |
| 50 | + + "\",\n" |
| 51 | + + " \"secret_key\": \"" |
| 52 | + + AWS_SECRET_ACCESS_KEY |
| 53 | + + "\",\n" |
| 54 | + + " \"session_token\": \"" |
| 55 | + + AWS_SESSION_TOKEN |
| 56 | + + "\"\n" |
| 57 | + + " },\n" |
| 58 | + + " \"actions\": [\n" |
| 59 | + + " {\n" |
| 60 | + + " \"action_type\": \"predict\",\n" |
| 61 | + + " \"method\": \"POST\",\n" |
| 62 | + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/converse\",\n" |
| 63 | + + " \"headers\": {\n" |
| 64 | + + " \"content-type\": \"application/json\"\n" |
| 65 | + + " },\n" |
| 66 | + + " \"request_body\": \"{ \\\"system\\\": [{\\\"text\\\": \\\"${parameters.system_prompt}\\\"}], \\\"messages\\\": [{\\\"role\\\":\\\"user\\\",\\\"content\\\":[{\\\"text\\\":\\\"${parameters.prompt}\\\"}]}]}\"\n" |
| 67 | + + " }\n" |
| 68 | + + " ]\n" |
| 69 | + + "}"; |
| 70 | + |
| 71 | + @Before |
| 72 | + public void setup() throws IOException, InterruptedException { |
| 73 | + ingestIrisIndexData(); |
| 74 | + queryPlanningModelId = registerQueryPlanningModel(); |
| 75 | + } |
| 76 | + |
| 77 | + @After |
| 78 | + public void deleteIndices() throws IOException { |
| 79 | + deleteIndex(IRIS_INDEX); |
| 80 | + } |
| 81 | + |
| 82 | + @Test |
| 83 | + public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException { |
| 84 | + String agentName = "Test_QueryPlanningAgent_DefaultPrompt"; |
| 85 | + String agentId = registerAgentWithQueryPlanningTool(agentName, queryPlanningModelId); |
| 86 | + assertNotNull(agentId); |
| 87 | + |
| 88 | + String query = "{\"parameters\": {\"query_text\": \"How many iris flowers of type setosa are there?\"}}"; |
| 89 | + Response response = executeAgent(agentId, query); |
| 90 | + String responseBody = TestHelper.httpEntityToString(response.getEntity()); |
| 91 | + |
| 92 | + Map<String, Object> responseMap = gson.fromJson(responseBody, Map.class); |
| 93 | + |
| 94 | + List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results"); |
| 95 | + Map<String, Object> firstResult = inferenceResults.get(0); |
| 96 | + List<Map<String, Object>> outputArray = (List<Map<String, Object>>) firstResult.get("output"); |
| 97 | + Map<String, Object> output = (Map<String, Object>) outputArray.get(0); |
| 98 | + String result = output.get("result").toString(); |
| 99 | + |
| 100 | + assertTrue(result.contains("query")); |
| 101 | + deleteAgent(agentId); |
| 102 | + } |
| 103 | + |
| 104 | + private String registerAgentWithQueryPlanningTool(String agentName, String modelId) throws IOException { |
| 105 | + MLToolSpec listIndexTool = MLToolSpec |
| 106 | + .builder() |
| 107 | + .type("ListIndexTool") |
| 108 | + .name("MyListIndexTool") |
| 109 | + .description("A tool for list indices") |
| 110 | + .parameters(Map.of("index", IRIS_INDEX, "question", "what fields are in the index?")) |
| 111 | + .includeOutputInAgentResponse(true) |
| 112 | + .build(); |
| 113 | + |
| 114 | + MLToolSpec queryPlanningTool = MLToolSpec |
| 115 | + .builder() |
| 116 | + .type("QueryPlanningTool") |
| 117 | + .name("MyQueryPlanningTool") |
| 118 | + .description("A tool for planning queries") |
| 119 | + .parameters(Map.of(MODEL_ID_FIELD, modelId)) |
| 120 | + .includeOutputInAgentResponse(true) |
| 121 | + .build(); |
| 122 | + |
| 123 | + MLAgent agent = MLAgent |
| 124 | + .builder() |
| 125 | + .name(agentName) |
| 126 | + .type("flow") |
| 127 | + .description("Test agent with QueryPlanningTool") |
| 128 | + .tools(List.of(listIndexTool, queryPlanningTool)) |
| 129 | + .build(); |
| 130 | + |
| 131 | + return registerAgent(agentName, agent); |
| 132 | + } |
| 133 | + |
| 134 | + private String registerQueryPlanningModel() throws IOException, InterruptedException { |
| 135 | + String bedrockClaudeModelName = "bedrock claude model " + randomAlphaOfLength(5); |
| 136 | + return registerRemoteModel(bedrockClaudeModelConnectorEntity, bedrockClaudeModelName, true); |
| 137 | + } |
| 138 | + |
| 139 | + private void ingestIrisIndexData() throws IOException { |
| 140 | + String bulkRequestBody = "{\"index\":{\"_index\":\"" |
| 141 | + + IRIS_INDEX |
| 142 | + + "\",\"_id\":\"1\"}}\n" |
| 143 | + + "{\"petal_length_in_cm\":1.4,\"petal_width_in_cm\":0.2,\"sepal_length_in_cm\":5.1,\"sepal_width_in_cm\":3.5,\"species\":\"setosa\"}\n" |
| 144 | + + "{\"index\":{\"_index\":\"" |
| 145 | + + IRIS_INDEX |
| 146 | + + "\",\"_id\":\"2\"}}\n" |
| 147 | + + "{\"petal_length_in_cm\":4.5,\"petal_width_in_cm\":1.5,\"sepal_length_in_cm\":6.4,\"sepal_width_in_cm\":2.9,\"species\":\"versicolor\"}\n"; |
| 148 | + TestHelper |
| 149 | + .makeRequest( |
| 150 | + client(), |
| 151 | + "POST", |
| 152 | + "/_bulk", |
| 153 | + null, |
| 154 | + new StringEntity(bulkRequestBody), |
| 155 | + List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json")) |
| 156 | + ); |
| 157 | + TestHelper.makeRequest(client(), "POST", "/" + IRIS_INDEX + "/_refresh", null, "", List.of()); |
| 158 | + } |
| 159 | + |
| 160 | + private String registerAgent(String agentName, MLAgent agent) throws IOException { |
| 161 | + Response response = TestHelper |
| 162 | + .makeRequest( |
| 163 | + client(), |
| 164 | + "POST", |
| 165 | + "/_plugins/_ml/agents/_register", |
| 166 | + null, |
| 167 | + new StringEntity(gson.toJson(agent)), |
| 168 | + List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json")) |
| 169 | + ); |
| 170 | + Map<String, String> responseMap = gson.fromJson(TestHelper.httpEntityToString(response.getEntity()), Map.class); |
| 171 | + return responseMap.get("agent_id"); |
| 172 | + } |
| 173 | + |
| 174 | + private Response executeAgent(String agentId, String query) throws IOException { |
| 175 | + return TestHelper |
| 176 | + .makeRequest( |
| 177 | + client(), |
| 178 | + "POST", |
| 179 | + "/_plugins/_ml/agents/" + agentId + "/_execute", |
| 180 | + null, |
| 181 | + new StringEntity(query), |
| 182 | + List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json")) |
| 183 | + ); |
| 184 | + } |
| 185 | + |
| 186 | + private void deleteAgent(String agentId) throws IOException { |
| 187 | + TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of()); |
| 188 | + } |
| 189 | + |
| 190 | + public String registerModel(String modelContent) throws IOException { |
| 191 | + Response response = TestHelper |
| 192 | + .makeRequest( |
| 193 | + client(), |
| 194 | + "POST", |
| 195 | + "/_plugins/_ml/models/_register", |
| 196 | + null, |
| 197 | + new StringEntity(modelContent), |
| 198 | + List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json")) |
| 199 | + ); |
| 200 | + Map<String, String> responseMap = gson.fromJson(TestHelper.httpEntityToString(response.getEntity()), Map.class); |
| 201 | + return responseMap.get("task_id"); |
| 202 | + } |
| 203 | + |
| 204 | +} |
0 commit comments