Skip to content

Commit d6d5b21

Browse files
committed
support inline agent
Signed-off-by: Hailong Cui <ihailong@amazon.com>
1 parent b1795a2 commit d6d5b21

File tree

4 files changed

+323
-10
lines changed

4 files changed

+323
-10
lines changed

common/src/test/java/org/opensearch/ml/common/input/execute/agent/AgentMLInputTests.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,27 @@
66
package org.opensearch.ml.common.input.execute.agent;
77

88
import static org.junit.Assert.assertEquals;
9-
import static org.junit.Assert.assertFalse;
109
import static org.junit.Assert.assertNotNull;
1110
import static org.junit.Assert.assertNull;
1211
import static org.junit.Assert.assertTrue;
13-
import static org.mockito.ArgumentMatchers.any;
1412
import static org.mockito.Mockito.mock;
1513
import static org.mockito.Mockito.never;
1614
import static org.mockito.Mockito.verify;
1715
import static org.mockito.Mockito.when;
1816

1917
import java.io.IOException;
20-
import java.util.Collections;
2118
import java.util.HashMap;
2219
import java.util.Map;
2320

2421
import org.junit.Test;
2522
import org.opensearch.Version;
26-
import org.opensearch.common.io.stream.BytesStreamOutput;
27-
import org.opensearch.common.settings.Settings;
28-
import org.opensearch.common.xcontent.XContentType;
2923
import org.opensearch.core.common.io.stream.StreamInput;
3024
import org.opensearch.core.common.io.stream.StreamOutput;
31-
import org.opensearch.core.xcontent.NamedXContentRegistry;
3225
import org.opensearch.core.xcontent.XContentParser;
3326
import org.opensearch.ml.common.FunctionName;
3427
import org.opensearch.ml.common.agent.MLAgent;
3528
import org.opensearch.ml.common.dataset.MLInputDataset;
3629
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
37-
import org.opensearch.search.SearchModule;
3830

3931
public class AgentMLInputTests {
4032

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,14 @@ public void execute(Input input, ActionListener<Output> listener) {
163163
List<ModelTensor> modelTensors = new ArrayList<>();
164164
outputs.add(ModelTensors.builder().mlModelTensors(modelTensors).build());
165165

166-
if (mlAgentInline != null) {
166+
boolean agentIdNotSet = agentId == null || agentId.isEmpty();
167+
168+
if (agentIdNotSet && mlAgentInline != null) {
167169
prepareAndExecute(listener, mlAgentInline, inputDataSet, tenantId, isAsync, outputs, modelTensors);
168170
return;
169171
}
170172
// as we support inline agent, agent id could be empty
171-
if (agentId == null || agentId.isEmpty()) {
173+
if (agentIdNotSet) {
172174
listener.onFailure(new IllegalArgumentException("Agent id is required."));
173175
return;
174176
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,4 +818,116 @@ public GetResponse prepareMLAgent(String agentId, boolean isHidden, String tenan
818818
return new GetResponse(getResult);
819819
}
820820

821+
@Test
822+
public void test_InlineAgent_HappyCase_ReturnsResult() {
823+
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build();
824+
Mockito.doAnswer(invocation -> {
825+
ActionListener<ModelTensor> listener = invocation.getArgument(2);
826+
listener.onResponse(modelTensor);
827+
return null;
828+
}).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any());
829+
830+
Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any());
831+
832+
MLAgent inlineAgent = MLAgent.builder().name("inline_test_agent").type("flow").description("Inline agent for testing").build();
833+
834+
Map<String, String> params = new HashMap<>();
835+
params.put(MLAgentExecutor.MEMORY_ID, "memoryId");
836+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
837+
AgentMLInput agentMLInput = new AgentMLInput(null, null, FunctionName.AGENT, dataset, false, inlineAgent);
838+
839+
mlAgentExecutor.execute(agentMLInput, agentActionListener);
840+
841+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
842+
ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue();
843+
Assert.assertEquals(1, output.getMlModelOutputs().size());
844+
Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size());
845+
Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0));
846+
}
847+
848+
@Test
849+
public void test_InlineAgent_NullAgentId_WithoutInlineAgent_ThrowsException() {
850+
Map<String, String> params = new HashMap<>();
851+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
852+
AgentMLInput agentMLInput = new AgentMLInput(null, null, FunctionName.AGENT, dataset);
853+
854+
mlAgentExecutor.execute(agentMLInput, agentActionListener);
855+
856+
Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture());
857+
Exception exception = exceptionCaptor.getValue();
858+
Assert.assertTrue(exception instanceof IllegalArgumentException);
859+
Assert.assertEquals("Agent id is required.", exception.getMessage());
860+
}
861+
862+
@Test
863+
public void test_InlineAgent_EmptyAgentId_WithoutInlineAgent_ThrowsException() {
864+
Map<String, String> params = new HashMap<>();
865+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
866+
AgentMLInput agentMLInput = new AgentMLInput("", null, FunctionName.AGENT, dataset);
867+
868+
mlAgentExecutor.execute(agentMLInput, agentActionListener);
869+
870+
Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture());
871+
Exception exception = exceptionCaptor.getValue();
872+
Assert.assertTrue(exception instanceof IllegalArgumentException);
873+
Assert.assertEquals("Agent id is required.", exception.getMessage());
874+
}
875+
876+
@Test
877+
public void test_InlineAgent_AsyncMode_ReturnsTaskId() {
878+
ModelTensor modelTensor = ModelTensor.builder().name("response").result("test").build();
879+
880+
Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any());
881+
882+
MLAgent inlineAgent = MLAgent.builder().name("inline_test_agent").type("flow").build();
883+
884+
Map<String, String> params = new HashMap<>();
885+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
886+
AgentMLInput agentMLInput = new AgentMLInput(null, null, FunctionName.AGENT, dataset, true, inlineAgent);
887+
888+
agentMLInput.setIsAsync(true);
889+
890+
indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task_id", 1, 0, 2, true);
891+
doAnswer(invocation -> {
892+
ActionListener<IndexResponse> listener = invocation.getArgument(1);
893+
listener.onResponse(indexResponse);
894+
return null;
895+
}).when(client).index(any(), any());
896+
897+
mlAgentExecutor.execute(agentMLInput, agentActionListener);
898+
899+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
900+
MLTaskOutput result = (MLTaskOutput) objectCaptor.getValue();
901+
902+
Assert.assertEquals("task_id", result.getTaskId());
903+
Assert.assertEquals("RUNNING", result.getStatus());
904+
}
905+
906+
@Test
907+
public void test_InlineAgent_WithChatAgentType_ReturnsResult() {
908+
ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build();
909+
Mockito.doAnswer(invocation -> {
910+
ActionListener<ModelTensor> listener = invocation.getArgument(2);
911+
listener.onResponse(modelTensor);
912+
return null;
913+
}).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any());
914+
915+
Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any());
916+
917+
LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
918+
MLAgent inlineAgent = MLAgent.builder().name("inline_chat_agent").type(MLAgentType.CONVERSATIONAL.name()).llm(llmSpec).build();
919+
920+
Map<String, String> params = new HashMap<>();
921+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
922+
AgentMLInput agentMLInput = new AgentMLInput(null, null, FunctionName.AGENT, dataset, false, inlineAgent);
923+
924+
mlAgentExecutor.execute(agentMLInput, agentActionListener);
925+
926+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
927+
ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue();
928+
Assert.assertEquals(1, output.getMlModelOutputs().size());
929+
Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size());
930+
Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0));
931+
}
932+
821933
}
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.rest;
7+
8+
import java.io.IOException;
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.Map;
12+
13+
import org.apache.hc.core5.http.HttpHeaders;
14+
import org.apache.hc.core5.http.message.BasicHeader;
15+
import org.junit.After;
16+
import org.junit.Before;
17+
import org.junit.Test;
18+
import org.opensearch.client.Response;
19+
import org.opensearch.common.settings.Settings;
20+
import org.opensearch.core.rest.RestStatus;
21+
import org.opensearch.ml.utils.TestHelper;
22+
23+
import com.google.common.collect.ImmutableList;
24+
25+
public class RestMLInlineFlowAgentIT extends MLCommonsRestTestCase {
26+
27+
private final String TEST_INDEX_NAME = "test_index";
28+
private final String TEST_INDEX_NAME2 = "test_index2";
29+
30+
private static final String INLINE_AGENT_TEMPLATE = """
31+
{
32+
"name": "test agent",
33+
"type": "flow",
34+
"description": "Inline flow agent with list index tool",
35+
"tools": [
36+
{
37+
"type": "ListIndexTool",
38+
"name": "list_indices",
39+
"description": "Tool to list all indices",
40+
"parameters": {}
41+
}
42+
]
43+
}""";
44+
45+
@Before
46+
public void setup() throws IOException, InterruptedException {
47+
createIndex(TEST_INDEX_NAME, Settings.EMPTY);
48+
createIndex(TEST_INDEX_NAME2, Settings.EMPTY);
49+
50+
List<String> dataList = new ArrayList<>();
51+
dataList.add("{\"name\":\"John Doe\",\"age\":30,\"city\":\"New York\",\"description\":\"Software Engineer\"}");
52+
dataList.add("{\"name\":\"Jane Smith\",\"age\":25,\"city\":\"Los Angeles\",\"description\":\"Data Scientist\"}");
53+
dataList.add("{\"name\":\"Bob Johnson\",\"age\":35,\"city\":\"Chicago\",\"description\":\"DevOps Engineer\"}");
54+
dataList.add("{\"name\":\"Alice Brown\",\"age\":28,\"city\":\"Seattle\",\"description\":\"Product Manager\"}");
55+
56+
for (String data : dataList) {
57+
ingestData(TEST_INDEX_NAME, data);
58+
}
59+
60+
Thread.sleep(1000);
61+
}
62+
63+
@After
64+
public void deleteIndex() throws IOException {
65+
deleteIndex(TEST_INDEX_NAME);
66+
}
67+
68+
@Test
69+
public void testInlineFlowAgentWithListIndexTool() throws IOException, InterruptedException {
70+
// Test inline flow agent with listIndexTool
71+
String requestBody = """
72+
{
73+
"agent": %s,
74+
"parameters": {}
75+
}""".formatted(INLINE_AGENT_TEMPLATE);
76+
77+
Response response = TestHelper
78+
.makeRequest(client(), "POST", "/_plugins/_ml/agents/_execute", null, TestHelper.toHttpEntity(requestBody), List.of());
79+
80+
assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode());
81+
82+
Map<String, Object> responseMap = parseResponseToMap(response);
83+
validateResponseStructure(responseMap);
84+
85+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
86+
assertNotNull(inferenceResults);
87+
assertTrue(inferenceResults.size() > 0);
88+
89+
Map<String, Object> result = inferenceResults.get(0);
90+
List<Map<String, Object>> output = (List<Map<String, Object>>) result.get("output");
91+
assertNotNull(output);
92+
assertTrue(output.size() > 0);
93+
94+
Map<String, Object> outputData = output.get(0);
95+
assertNotNull(outputData.get("result"));
96+
97+
String resultString = (String) outputData.get("result");
98+
assertTrue(resultString.contains(TEST_INDEX_NAME));
99+
assertTrue(resultString.contains(TEST_INDEX_NAME2));
100+
}
101+
102+
@Test
103+
public void testInlineFlowAgentWithListIndexToolAndUserInput() throws IOException, InterruptedException {
104+
// Test inline flow agent with listIndexTool and user input parameters
105+
String requestBody = """
106+
{
107+
"agent": %s,
108+
"parameters" : {
109+
"index": "test_index"
110+
}
111+
}""".formatted(INLINE_AGENT_TEMPLATE);
112+
113+
Response response = TestHelper
114+
.makeRequest(client(), "POST", "/_plugins/_ml/agents/_execute", null, TestHelper.toHttpEntity(requestBody), List.of());
115+
116+
assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode());
117+
118+
Map<String, Object> responseMap = parseResponseToMap(response);
119+
validateResponseStructure(responseMap);
120+
121+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
122+
assertNotNull(inferenceResults);
123+
assertTrue(inferenceResults.size() > 0);
124+
125+
Map<String, Object> result = inferenceResults.get(0);
126+
List<Map<String, Object>> output = (List<Map<String, Object>>) result.get("output");
127+
assertNotNull(output);
128+
assertTrue(output.size() > 0);
129+
130+
Map<String, Object> outputData = output.get(0);
131+
assertNotNull(outputData.get("result"));
132+
133+
String resultString = (String) outputData.get("result");
134+
assertTrue(resultString.contains("test_index"));
135+
}
136+
137+
@Test
138+
public void testInlineFlowAgentWithMultipleTools() throws IOException, InterruptedException {
139+
// Test inline flow agent with multiple tools including listIndexTool
140+
String inlineAgent = """
141+
{
142+
"name": "agent_with_multi_tools",
143+
"type": "flow",
144+
"description": "Inline flow agent with multiple tools",
145+
"tools": [
146+
{
147+
"type": "ListIndexTool",
148+
"name": "list_indices",
149+
"description": "Tool to list all indices",
150+
"parameters": {}
151+
},
152+
{
153+
"type": "SearchIndexTool",
154+
"name": "search_test_index",
155+
"description": "Tool to search test index",
156+
"parameters": {
157+
"index": "%s",
158+
"query": {
159+
"query": {
160+
"match_all": {}
161+
}
162+
}
163+
}
164+
}
165+
]
166+
}""".formatted(TEST_INDEX_NAME);
167+
168+
String requestBody = """
169+
{
170+
"agent": %s,
171+
"parameters": {}
172+
}""".formatted(inlineAgent);
173+
174+
Response response = TestHelper
175+
.makeRequest(
176+
client(),
177+
"POST",
178+
"/_plugins/_ml/agents/_execute",
179+
null,
180+
TestHelper.toHttpEntity(requestBody),
181+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
182+
);
183+
184+
assertEquals(RestStatus.OK.getStatus(), response.getStatusLine().getStatusCode());
185+
186+
Map<String, Object> responseMap = parseResponseToMap(response);
187+
validateResponseStructure(responseMap);
188+
189+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
190+
assertNotNull(inferenceResults);
191+
assertTrue(inferenceResults.size() > 0);
192+
System.out.println(gson.toJson(inferenceResults));
193+
// search index tool has been executed
194+
assertEquals("search_test_index", ((List<Map<String, Object>>) inferenceResults.getFirst().get("output")).getFirst().get("name"));
195+
}
196+
197+
private void validateResponseStructure(Map<String, Object> responseMap) {
198+
assertNotNull(responseMap);
199+
assertTrue(responseMap.containsKey("inference_results"));
200+
}
201+
202+
private void ingestData(String indexName, String data) throws IOException {
203+
Response response = TestHelper
204+
.makeRequest(client(), "POST", "/" + indexName + "/_doc", null, TestHelper.toHttpEntity(data), List.of());
205+
assertEquals(RestStatus.CREATED.getStatus(), response.getStatusLine().getStatusCode());
206+
}
207+
}

0 commit comments

Comments
 (0)