Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
Expand All @@ -28,15 +26,13 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner;
import org.opensearch.ml.engine.function_calling.FunctionCallingFactory;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.engine.tools.QueryPlanningTool;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.PutDataObjectRequest;
Expand Down Expand Up @@ -91,9 +87,6 @@ private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse
return;
}

// Update QueryPlanningTool to include model_id if missing
List<MLToolSpec> updatedTools = processQueryPlannerTools(agent);

String llmInterface = (agent.getParameters() != null) ? agent.getParameters().get(LLM_INTERFACE) : null;
if (llmInterface != null) {
if (llmInterface.trim().isEmpty()) {
Expand All @@ -111,7 +104,7 @@ private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse

Instant now = Instant.now();
boolean isHiddenAgent = RestActionUtils.isSuperAdminUser(clusterService, client);
MLAgent mlAgent = agent.toBuilder().tools(updatedTools).createdTime(now).lastUpdateTime(now).isHidden(isHiddenAgent).build();
MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).isHidden(isHiddenAgent).build();
String tenantId = agent.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, listener)) {
return;
Expand All @@ -131,25 +124,6 @@ private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse
}
}

private List<MLToolSpec> processQueryPlannerTools(MLAgent agent) {
List<MLToolSpec> tools = agent.getTools();
List<MLToolSpec> updatedTools = tools;
if (tools != null) {
// Update QueryPlanningTool with model_id if missing and LLM exists
if (agent.getLlm() != null && agent.getLlm().getModelId() != null && !agent.getLlm().getModelId().isBlank()) {
updatedTools = tools.stream().map(tool -> {
if (tool.getType().equals(QueryPlanningTool.TYPE)) {
Map<String, String> params = tool.getParameters() != null ? new HashMap<>(tool.getParameters()) : new HashMap<>();
params.putIfAbsent("model_id", agent.getLlm().getModelId());
return tool.toBuilder().parameters(params).build();
}
return tool;
}).collect(Collectors.toList());
}
}
return updatedTools;
}

private void createConversationAgent(MLAgent planExecuteReflectAgent, String tenantId, ActionListener<String> listener) {
Instant now = Instant.now();
boolean isHiddenAgent = RestActionUtils.isSuperAdminUser(clusterService, client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand All @@ -31,7 +30,6 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchException;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -45,12 +43,10 @@
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.engine.tools.QueryPlanningTool;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
import org.opensearch.tasks.Task;
Expand Down Expand Up @@ -374,113 +370,6 @@ public void test_execute_registerAgent_PlanExecuteAndReflect_WithExecutorAgentId
verify(client, times(1)).index(any(), any());
}

@Test
public void test_execute_registerAgent_QueryPlanningTool_addsModelId_whenMissing() {
// Create QueryPlanningTool without model_id parameter
MLToolSpec queryPlanningTool = new MLToolSpec(
QueryPlanningTool.TYPE,
"QueryPlanningTool",
"QueryPlanningTool",
Collections.emptyMap(), // No parameters
Collections.emptyMap(),
false,
Collections.emptyMap(),
null,
null
);

MLAgent mlAgent = MLAgent
.builder()
.name("agent")
.type(MLAgentType.CONVERSATIONAL.name())
.description("description")
.llm(new LLMSpec("test_model_id", new HashMap<>()))
.tools(List.of(queryPlanningTool))
.build();

MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class);
when(request.getMlAgent()).thenReturn(mlAgent);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLAgentIndex(any());

doAnswer(invocation -> {
ActionListener<IndexResponse> al = invocation.getArgument(1);
al.onResponse(indexResponse);
return null;
}).when(client).index(any(), any());

transportRegisterAgentAction.doExecute(task, request, actionListener);

// Verify that the agent was indexed with updated tools containing model_id
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
verify(client).index(indexRequestCaptor.capture(), any());

IndexRequest indexRequest = indexRequestCaptor.getValue();
assertNotNull(indexRequest);
String source = indexRequest.source().utf8ToString();
assertTrue("Agent source should contain model_id", source.contains("\"model_id\":\"test_model_id\""));
}

@Test
public void test_execute_registerAgent_QueryPlanningTool_preservesExistingModelId() {
// Create QueryPlanningTool with existing model_id parameter
Map<String, String> existingParams = new HashMap<>();
existingParams.put("model_id", "existing_model_id");
existingParams.put("other_param", "other_value");

MLToolSpec queryPlanningTool = new MLToolSpec(
QueryPlanningTool.TYPE,
"QueryPlanningTool",
"QueryPlanningTool",
existingParams,
Collections.emptyMap(),
false,
Collections.emptyMap(),
null,
null
);

MLAgent mlAgent = MLAgent
.builder()
.name("agent")
.type(MLAgentType.CONVERSATIONAL.name())
.description("description")
.llm(new LLMSpec("new_model_id", new HashMap<>()))
.tools(List.of(queryPlanningTool))
.build();

MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class);
when(request.getMlAgent()).thenReturn(mlAgent);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLAgentIndex(any());

doAnswer(invocation -> {
ActionListener<IndexResponse> al = invocation.getArgument(1);
al.onResponse(indexResponse);
return null;
}).when(client).index(any(), any());

transportRegisterAgentAction.doExecute(task, request, actionListener);

// Verify that the agent was indexed with tools preserving existing model_id
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
verify(client).index(indexRequestCaptor.capture(), any());

IndexRequest indexRequest = indexRequestCaptor.getValue();
assertNotNull(indexRequest);
String source = indexRequest.source().utf8ToString();
assertTrue("Agent source should contain existing model_id", source.contains("\"model_id\":\"existing_model_id\""));
assertTrue("Agent source should contain other_param", source.contains("\"other_param\":\"other_value\""));
}

@Test
public void test_execute_registerAgent_MCPConnectorDisabled() {
// Create an MLAgent with MCP connectors in parameters
Expand Down