Skip to content

Commit e7a24eb

Browse files
add feature flag for agentic search
Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent 5aea15f commit e7a24eb

File tree

5 files changed

+95
-2
lines changed

5 files changed

+95
-2
lines changed

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ private MLCommonsSettings() {}
216216
public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting
217217
.boolSetting("plugins.ml_commons.memory_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
218218

219+
public static final Setting<Boolean> ML_COMMONS_AGENTIC_SEARCH_ENABLED = Setting
220+
.boolSetting("plugins.ml_commons.agentic_search_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
221+
public static final String ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE =
222+
"The Agentic Search feature is not enabled. To enable, please update the setting " + ML_COMMONS_AGENTIC_SEARCH_ENABLED.getKey();
223+
219224
public static final Setting<Boolean> ML_COMMONS_MCP_CONNECTOR_ENABLED = Setting
220225
.boolSetting("plugins.ml_commons.mcp_connector_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
221226
public static final String ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE =

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1515
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
1616
import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD;
17+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
18+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
1719
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE;
1820
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
1921
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
@@ -52,6 +54,7 @@
5254
import org.opensearch.ml.common.MLTaskType;
5355
import org.opensearch.ml.common.agent.MLAgent;
5456
import org.opensearch.ml.common.agent.MLMemorySpec;
57+
import org.opensearch.ml.common.agent.MLToolSpec;
5558
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5659
import org.opensearch.ml.common.input.Input;
5760
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
@@ -68,6 +71,7 @@
6871
import org.opensearch.ml.engine.encryptor.Encryptor;
6972
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
7073
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
74+
import org.opensearch.ml.engine.tools.QueryPlanningTool;
7175
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
7276
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
7377
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;
@@ -109,6 +113,7 @@ public class MLAgentExecutor implements Executable, SettingsChangeListener {
109113
private volatile Boolean isMultiTenancyEnabled;
110114
private Encryptor encryptor;
111115
private static volatile boolean mcpConnectorIsEnabled;
116+
private static volatile boolean agenticSearchIsEnabled;
112117

113118
public MLAgentExecutor(
114119
Client client,
@@ -132,6 +137,8 @@ public MLAgentExecutor(
132137
this.encryptor = encryptor;
133138
this.mcpConnectorIsEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(clusterService.getSettings());
134139
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> mcpConnectorIsEnabled = it);
140+
this.agenticSearchIsEnabled = ML_COMMONS_AGENTIC_SEARCH_ENABLED.get(clusterService.getSettings());
141+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> agenticSearchIsEnabled = it);
135142
}
136143

137144
@Override
@@ -394,6 +401,18 @@ private void executeAgent(
394401
listener.onFailure(new OpenSearchException(ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE));
395402
return;
396403
}
404+
List<MLToolSpec> tools = mlAgent.getTools();
405+
for (MLToolSpec tool : tools) {
406+
if (tool.getType().equals(QueryPlanningTool.TYPE)) {
407+
if (!agenticSearchIsEnabled) {
408+
listener.onFailure(new OpenSearchException(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE));
409+
return;
410+
} else {
411+
log.info("Searching for tool {}", tool.getName());
412+
}
413+
}
414+
}
415+
397416
MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent);
398417
// If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists
399418
if (isAsync) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import static org.mockito.Mockito.when;
1212
import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE;
1313
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
14+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
15+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
1416
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
1517
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MEMORY_ID;
1618
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;
@@ -36,6 +38,7 @@
3638
import org.mockito.Mock;
3739
import org.mockito.Mockito;
3840
import org.mockito.MockitoAnnotations;
41+
import org.opensearch.OpenSearchException;
3942
import org.opensearch.ResourceNotFoundException;
4043
import org.opensearch.Version;
4144
import org.opensearch.action.get.GetRequest;
@@ -171,7 +174,8 @@ public void setup() {
171174
when(client.threadPool()).thenReturn(threadPool);
172175
when(threadPool.getThreadContext()).thenReturn(threadContext);
173176
when(this.clusterService.getSettings()).thenReturn(settings);
174-
when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED)));
177+
when(this.clusterService.getClusterSettings())
178+
.thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED, ML_COMMONS_AGENTIC_SEARCH_ENABLED)));
175179

176180
settings = Settings.builder().build();
177181
mlAgentExecutor = Mockito
@@ -774,6 +778,67 @@ public void test_AsyncMode_IndexTask_failure() throws IOException {
774778
Assert.assertNotNull(exceptionCaptor.getValue());
775779
}
776780

781+
@Test
782+
public void test_query_planning_requires_agentic_search_enabled() throws IOException {
783+
// Create an MLAgent with QueryPlanningTool
784+
MLAgent mlAgentWithQueryPlanning = new MLAgent(
785+
"test",
786+
MLAgentType.FLOW.name(),
787+
"test",
788+
new LLMSpec("test_model", Map.of("test_key", "test_value")),
789+
List
790+
.of(
791+
new MLToolSpec(
792+
"QueryPlanningTool",
793+
"QueryPlanningTool",
794+
"QueryPlanningTool",
795+
Collections.emptyMap(),
796+
Collections.emptyMap(),
797+
false,
798+
Collections.emptyMap(),
799+
null,
800+
null
801+
)
802+
),
803+
Map.of("test", "test"),
804+
new MLMemorySpec("memoryType", "123", 0),
805+
Instant.EPOCH,
806+
Instant.EPOCH,
807+
"test",
808+
false,
809+
null
810+
);
811+
812+
// Create GetResponse with the MLAgent that has QueryPlanningTool
813+
XContentBuilder content = mlAgentWithQueryPlanning.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
814+
BytesReference bytesReference = BytesReference.bytes(content);
815+
GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null);
816+
GetResponse agentGetResponse = new GetResponse(getResult);
817+
818+
// Create a new MLAgentExecutor with agentic search disabled
819+
MLAgentExecutor mlAgentExecutorWithDisabledSearch = Mockito
820+
.spy(new MLAgentExecutor(client, sdkClient, settings, clusterService, xContentRegistry, toolFactories, memoryMap, false, null));
821+
822+
// Mock the agent get response
823+
Mockito.doAnswer(invocation -> {
824+
ActionListener<GetResponse> listener = invocation.getArgument(1);
825+
listener.onResponse(agentGetResponse);
826+
return null;
827+
}).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class));
828+
829+
// Mock the agent runner
830+
Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithDisabledSearch).getAgentRunner(Mockito.any());
831+
832+
// Execute the agent
833+
mlAgentExecutorWithDisabledSearch.execute(getAgentMLInput(), agentActionListener);
834+
835+
// Verify that the execution fails with the correct error message
836+
Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture());
837+
Exception exception = exceptionCaptor.getValue();
838+
Assert.assertTrue(exception instanceof OpenSearchException);
839+
Assert.assertEquals(exception.getMessage(), ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE);
840+
}
841+
777842
private AgentMLInput getAgentMLInput() {
778843
Map<String, String> params = new HashMap<>();
779844
params.put(MLAgentExecutor.MEMORY_ID, "memoryId");

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1160,7 +1160,8 @@ public List<Setting<?>> getSettings() {
11601160
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
11611161
MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED,
11621162
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
1163-
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
1163+
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
1164+
MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED
11641165
);
11651166
return settings;
11661167
}

plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.rest;
77

8+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
89
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
910

1011
import java.io.IOException;
@@ -86,6 +87,8 @@ public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
8687
assertNotNull(agentId);
8788

8889
String query = "{\"parameters\": {\"query_text\": \"How many iris flowers of type setosa are there?\"}}";
90+
// enable agentic search
91+
updateClusterSettings(ML_COMMONS_AGENTIC_SEARCH_ENABLED.getKey(), true);
8992
Response response = executeAgent(agentId, query);
9093
String responseBody = TestHelper.httpEntityToString(response.getEntity());
9194

0 commit comments

Comments
 (0)