Skip to content

Commit cf64838

Browse files
Add feature flag for agentic search
Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent fb71867 commit cf64838

File tree

7 files changed

+156
-3
lines changed

7 files changed

+156
-3
lines changed

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ private MLCommonsSettings() {}
216216
public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting
217217
.boolSetting("plugins.ml_commons.memory_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
218218

219+
public static final Setting<Boolean> ML_COMMONS_AGENTIC_SEARCH_ENABLED = Setting
220+
.boolSetting("plugins.ml_commons.agentic_search_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
221+
public static final String ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE =
222+
"The Agentic Search feature is not enabled. To enable, please update the setting " + ML_COMMONS_AGENTIC_SEARCH_ENABLED.getKey();
223+
219224
public static final Setting<Boolean> ML_COMMONS_MCP_CONNECTOR_ENABLED = Setting
220225
.boolSetting("plugins.ml_commons.mcp_connector_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
221226
public static final String ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE =

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1515
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
1616
import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD;
17+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
18+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
1719
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE;
1820
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
1921
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
@@ -52,6 +54,7 @@
5254
import org.opensearch.ml.common.MLTaskType;
5355
import org.opensearch.ml.common.agent.MLAgent;
5456
import org.opensearch.ml.common.agent.MLMemorySpec;
57+
import org.opensearch.ml.common.agent.MLToolSpec;
5558
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5659
import org.opensearch.ml.common.input.Input;
5760
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
@@ -68,6 +71,7 @@
6871
import org.opensearch.ml.engine.encryptor.Encryptor;
6972
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
7073
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
74+
import org.opensearch.ml.engine.tools.QueryPlanningTool;
7175
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
7276
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
7377
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;
@@ -109,6 +113,7 @@ public class MLAgentExecutor implements Executable, SettingsChangeListener {
109113
private volatile Boolean isMultiTenancyEnabled;
110114
private Encryptor encryptor;
111115
private static volatile boolean mcpConnectorIsEnabled;
116+
private static volatile boolean agenticSearchIsEnabled;
112117

113118
public MLAgentExecutor(
114119
Client client,
@@ -132,6 +137,8 @@ public MLAgentExecutor(
132137
this.encryptor = encryptor;
133138
this.mcpConnectorIsEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(clusterService.getSettings());
134139
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> mcpConnectorIsEnabled = it);
140+
this.agenticSearchIsEnabled = ML_COMMONS_AGENTIC_SEARCH_ENABLED.get(clusterService.getSettings());
141+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> agenticSearchIsEnabled = it);
135142
}
136143

137144
@Override
@@ -394,6 +401,18 @@ private void executeAgent(
394401
listener.onFailure(new OpenSearchException(ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE));
395402
return;
396403
}
404+
List<MLToolSpec> tools = mlAgent.getTools();
405+
for (MLToolSpec tool : tools) {
406+
if (tool.getType().equals(QueryPlanningTool.TYPE)) {
407+
if (!agenticSearchIsEnabled) {
408+
listener.onFailure(new OpenSearchException(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE));
409+
return;
410+
} else {
411+
log.info("Searching for tool {}", tool.getName());
412+
}
413+
}
414+
}
415+
397416
MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent);
398417
// If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists
399418
if (isAsync) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import static org.mockito.Mockito.when;
1212
import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE;
1313
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
14+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
15+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
1416
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
1517
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MEMORY_ID;
1618
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;
@@ -36,6 +38,7 @@
3638
import org.mockito.Mock;
3739
import org.mockito.Mockito;
3840
import org.mockito.MockitoAnnotations;
41+
import org.opensearch.OpenSearchException;
3942
import org.opensearch.ResourceNotFoundException;
4043
import org.opensearch.Version;
4144
import org.opensearch.action.get.GetRequest;
@@ -171,7 +174,8 @@ public void setup() {
171174
when(client.threadPool()).thenReturn(threadPool);
172175
when(threadPool.getThreadContext()).thenReturn(threadContext);
173176
when(this.clusterService.getSettings()).thenReturn(settings);
174-
when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED)));
177+
when(this.clusterService.getClusterSettings())
178+
.thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED, ML_COMMONS_AGENTIC_SEARCH_ENABLED)));
175179

176180
settings = Settings.builder().build();
177181
mlAgentExecutor = Mockito
@@ -774,6 +778,67 @@ public void test_AsyncMode_IndexTask_failure() throws IOException {
774778
Assert.assertNotNull(exceptionCaptor.getValue());
775779
}
776780

781+
@Test
782+
public void test_query_planning_requires_agentic_search_enabled() throws IOException {
783+
// Create an MLAgent with QueryPlanningTool
784+
MLAgent mlAgentWithQueryPlanning = new MLAgent(
785+
"test",
786+
MLAgentType.FLOW.name(),
787+
"test",
788+
new LLMSpec("test_model", Map.of("test_key", "test_value")),
789+
List
790+
.of(
791+
new MLToolSpec(
792+
"QueryPlanningTool",
793+
"QueryPlanningTool",
794+
"QueryPlanningTool",
795+
Collections.emptyMap(),
796+
Collections.emptyMap(),
797+
false,
798+
Collections.emptyMap(),
799+
null,
800+
null
801+
)
802+
),
803+
Map.of("test", "test"),
804+
new MLMemorySpec("memoryType", "123", 0),
805+
Instant.EPOCH,
806+
Instant.EPOCH,
807+
"test",
808+
false,
809+
null
810+
);
811+
812+
// Create GetResponse with the MLAgent that has QueryPlanningTool
813+
XContentBuilder content = mlAgentWithQueryPlanning.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
814+
BytesReference bytesReference = BytesReference.bytes(content);
815+
GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null);
816+
GetResponse agentGetResponse = new GetResponse(getResult);
817+
818+
// Create a new MLAgentExecutor with agentic search disabled
819+
MLAgentExecutor mlAgentExecutorWithDisabledSearch = Mockito
820+
.spy(new MLAgentExecutor(client, sdkClient, settings, clusterService, xContentRegistry, toolFactories, memoryMap, false, null));
821+
822+
// Mock the agent get response
823+
Mockito.doAnswer(invocation -> {
824+
ActionListener<GetResponse> listener = invocation.getArgument(1);
825+
listener.onResponse(agentGetResponse);
826+
return null;
827+
}).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class));
828+
829+
// Mock the agent runner
830+
Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithDisabledSearch).getAgentRunner(Mockito.any());
831+
832+
// Execute the agent
833+
mlAgentExecutorWithDisabledSearch.execute(getAgentMLInput(), agentActionListener);
834+
835+
// Verify that the execution fails with the correct error message
836+
Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture());
837+
Exception exception = exceptionCaptor.getValue();
838+
Assert.assertTrue(exception instanceof OpenSearchException);
839+
Assert.assertEquals(exception.getMessage(), ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE);
840+
}
841+
777842
private AgentMLInput getAgentMLInput() {
778843
Map<String, String> params = new HashMap<>();
779844
params.put(MLAgentExecutor.MEMORY_ID, "memoryId");

plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77

88
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
99
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
10+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
11+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
1012
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE;
1113
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
1214

1315
import java.time.Instant;
1416
import java.util.HashMap;
17+
import java.util.List;
1518
import java.util.Map;
1619

1720
import org.opensearch.OpenSearchException;
@@ -26,12 +29,14 @@
2629
import org.opensearch.core.action.ActionListener;
2730
import org.opensearch.ml.common.MLAgentType;
2831
import org.opensearch.ml.common.agent.MLAgent;
32+
import org.opensearch.ml.common.agent.MLToolSpec;
2933
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3034
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
3135
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
3236
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
3337
import org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner;
3438
import org.opensearch.ml.engine.indices.MLIndicesHandler;
39+
import org.opensearch.ml.engine.tools.QueryPlanningTool;
3540
import org.opensearch.ml.utils.RestActionUtils;
3641
import org.opensearch.ml.utils.TenantAwareHelper;
3742
import org.opensearch.remote.metadata.client.PutDataObjectRequest;
@@ -52,6 +57,7 @@ public class TransportRegisterAgentAction extends HandledTransportAction<ActionR
5257

5358
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
5459
private volatile boolean mcpConnectorIsEnabled;
60+
private volatile boolean agenticSearchIsEnabled;
5561

5662
@Inject
5763
public TransportRegisterAgentAction(
@@ -71,6 +77,8 @@ public TransportRegisterAgentAction(
7177
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
7278
this.mcpConnectorIsEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(clusterService.getSettings());
7379
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> mcpConnectorIsEnabled = it);
80+
this.agenticSearchIsEnabled = ML_COMMONS_AGENTIC_SEARCH_ENABLED.get(clusterService.getSettings());
81+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> agenticSearchIsEnabled = it);
7482
}
7583

7684
@Override
@@ -88,6 +96,19 @@ private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse
8896
listener.onFailure(new OpenSearchException(ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE));
8997
return;
9098
}
99+
100+
List<MLToolSpec> tools = agent.getTools();
101+
for (MLToolSpec tool : tools) {
102+
if (tool.getType().equals(QueryPlanningTool.TYPE)) {
103+
if (!agenticSearchIsEnabled) {
104+
listener.onFailure(new OpenSearchException(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE));
105+
return;
106+
} else {
107+
log.info("Searching for tool {}", tool.getName());
108+
}
109+
}
110+
}
111+
91112
Instant now = Instant.now();
92113
boolean isHiddenAgent = RestActionUtils.isSuperAdminUser(clusterService, client);
93114
MLAgent mlAgent = agent.toBuilder().createdTime(now).lastUpdateTime(now).isHidden(isHiddenAgent).build();

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1160,7 +1160,8 @@ public List<Setting<?>> getSettings() {
11601160
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
11611161
MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED,
11621162
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
1163-
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
1163+
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
1164+
MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED
11641165
);
11651166
return settings;
11661167
}

plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@
1212
import static org.mockito.Mockito.verify;
1313
import static org.mockito.Mockito.when;
1414
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
15+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
16+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
1517
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
1618
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.EXECUTOR_AGENT_ID_FIELD;
1719

1820
import java.io.IOException;
1921
import java.util.Collections;
2022
import java.util.HashMap;
23+
import java.util.List;
2124
import java.util.Map;
2225
import java.util.Set;
2326

@@ -40,6 +43,7 @@
4043
import org.opensearch.ml.common.MLAgentType;
4144
import org.opensearch.ml.common.agent.LLMSpec;
4245
import org.opensearch.ml.common.agent.MLAgent;
46+
import org.opensearch.ml.common.agent.MLToolSpec;
4347
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
4448
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
4549
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
@@ -101,7 +105,8 @@ public void setup() throws IOException {
101105
when(client.threadPool()).thenReturn(threadPool);
102106
when(threadPool.getThreadContext()).thenReturn(threadContext);
103107
when(clusterService.getSettings()).thenReturn(settings);
104-
when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED)));
108+
when(this.clusterService.getClusterSettings())
109+
.thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED, ML_COMMONS_AGENTIC_SEARCH_ENABLED)));
105110
transportRegisterAgentAction = new TransportRegisterAgentAction(
106111
transportService,
107112
actionFilters,
@@ -366,4 +371,38 @@ public void test_execute_registerAgent_PlanExecuteAndReflect_WithExecutorAgentId
366371

367372
verify(client, times(1)).index(any(), any());
368373
}
374+
375+
@Test
376+
public void test_execute_registerAgent_QueryPlanningTool_Validation() {
377+
// Create an MLAgent with QueryPlanningTool
378+
MLToolSpec queryPlanningTool = new MLToolSpec(
379+
"QueryPlanningTool",
380+
"QueryPlanningTool",
381+
"QueryPlanningTool",
382+
Collections.emptyMap(),
383+
Collections.emptyMap(),
384+
false,
385+
Collections.emptyMap(),
386+
null,
387+
null
388+
);
389+
390+
MLAgent mlAgent = MLAgent
391+
.builder()
392+
.name("agent")
393+
.type(MLAgentType.CONVERSATIONAL.name())
394+
.description("description")
395+
.llm(new LLMSpec("model_id", new HashMap<>()))
396+
.tools(List.of(queryPlanningTool))
397+
.build();
398+
399+
MLRegisterAgentRequest request = mock(MLRegisterAgentRequest.class);
400+
when(request.getMlAgent()).thenReturn(mlAgent);
401+
402+
transportRegisterAgentAction.doExecute(task, request, actionListener);
403+
404+
ArgumentCaptor<OpenSearchException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchException.class);
405+
verify(actionListener).onFailure(argumentCaptor.capture());
406+
assertEquals(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE, argumentCaptor.getValue().getMessage());
407+
}
369408
}

plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.rest;
77

8+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
89
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
910

1011
import java.io.IOException;
@@ -91,6 +92,8 @@ public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
9192
assertNotNull(agentId);
9293

9394
String query = "{\"parameters\": {\"query_text\": \"How many iris flowers of type setosa are there?\"}}";
95+
// enable agentic search
96+
updateClusterSettings(ML_COMMONS_AGENTIC_SEARCH_ENABLED.getKey(), true);
9497
Response response = executeAgent(agentId, query);
9598
String responseBody = TestHelper.httpEntityToString(response.getEntity());
9699

0 commit comments

Comments
 (0)