Skip to content

Commit 0c3955d

Browse files
Prevent qp tool create when not enabled
Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent 74d496e commit 0c3955d

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55

66
package org.opensearch.ml.engine.tools;
77

8+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
9+
810
import java.util.List;
911
import java.util.Map;
1012

1113
import org.apache.commons.text.StringSubstitutor;
14+
import org.opensearch.OpenSearchException;
1215
import org.opensearch.core.action.ActionListener;
16+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
1317
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
1418
import org.opensearch.ml.common.spi.tools.WithModelTool;
1519
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
@@ -108,6 +112,7 @@ public boolean validate(Map<String, String> parameters) {
108112
public static class Factory implements WithModelTool.Factory<QueryPlanningTool> {
109113
private Client client;
110114
private static volatile Factory INSTANCE;
115+
private static MLFeatureEnabledSetting mlFeatureEnabledSetting;
111116

112117
public static Factory getInstance() {
113118
if (INSTANCE != null) {
@@ -122,13 +127,18 @@ public static Factory getInstance() {
122127
}
123128
}
124129

125-
public void init(Client client) {
130+
public void init(Client client, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
126131
this.client = client;
132+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
127133
}
128134

129135
@Override
130136
public QueryPlanningTool create(Map<String, Object> map) {
131137

138+
if (!mlFeatureEnabledSetting.isAgenticSearchEnabled()) {
139+
throw new OpenSearchException(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE);
140+
}
141+
132142
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
133143

134144
String type = (String) map.get(GENERATION_TYPE_FIELD);

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import static org.mockito.ArgumentMatchers.any;
1515
import static org.mockito.Mockito.doAnswer;
1616
import static org.mockito.Mockito.mock;
17+
import static org.mockito.Mockito.when;
18+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
1719
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
1820
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
1921

@@ -31,7 +33,9 @@
3133
import org.mockito.ArgumentCaptor;
3234
import org.mockito.Mock;
3335
import org.mockito.MockitoAnnotations;
36+
import org.opensearch.OpenSearchException;
3437
import org.opensearch.core.action.ActionListener;
38+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3539
import org.opensearch.ml.common.spi.tools.Tool;
3640
import org.opensearch.transport.client.Client;
3741

@@ -46,6 +50,9 @@ public class QueryPlanningToolTests {
4650
@Mock
4751
private MLModelTool queryGenerationTool;
4852

53+
@Mock
54+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
55+
4956
private Map<String, String> validParams;
5057
private Map<String, String> emptyParams;
5158

@@ -55,7 +62,14 @@ public class QueryPlanningToolTests {
5562
public void setup() {
5663
MockitoAnnotations.openMocks(this);
5764
MLModelTool.Factory.getInstance().init(client);
58-
factory = new QueryPlanningTool.Factory();
65+
66+
// Mock the MLFeatureEnabledSetting to return true for agentic search
67+
when(mlFeatureEnabledSetting.isAgenticSearchEnabled()).thenReturn(true);
68+
69+
// Initialize the factory with mocked dependencies
70+
factory = QueryPlanningTool.Factory.getInstance();
71+
factory.init(client, mlFeatureEnabledSetting);
72+
5973
validParams = new HashMap<>();
6074
validParams.put("prompt", "test prompt");
6175
emptyParams = Collections.emptyMap();
@@ -269,4 +283,16 @@ public void testFactoryCreateWithInvalidType() {
269283
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(map));
270284
assertEquals("Invalid generation type: invalid. The current supported types are llmGenerated.", exception.getMessage());
271285
}
286+
287+
@Test
288+
public void testFactoryCreateWhenAgenticSearchDisabled() {
289+
// Mock the MLFeatureEnabledSetting to return false for agentic search
290+
when(mlFeatureEnabledSetting.isAgenticSearchEnabled()).thenReturn(false);
291+
292+
Map<String, Object> map = new HashMap<>();
293+
map.put(QueryPlanningTool.MODEL_ID_FIELD, "modelId");
294+
295+
Exception exception = assertThrows(OpenSearchException.class, () -> factory.create(map));
296+
assertEquals(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE, exception.getMessage());
297+
}
272298
}

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ public Collection<Object> createComponents(
730730
SearchIndexTool.Factory.getInstance().init(client, xContentRegistry);
731731
VisualizationsTool.Factory.getInstance().init(client);
732732
ConnectorTool.Factory.getInstance().init(client);
733-
QueryPlanningTool.Factory.getInstance().init(client);
733+
QueryPlanningTool.Factory.getInstance().init(client, mlFeatureEnabledSetting);
734734

735735
toolFactories.put(MLModelTool.TYPE, MLModelTool.Factory.getInstance());
736736
toolFactories.put(McpSseTool.TYPE, McpSseTool.Factory.getInstance());
@@ -1179,7 +1179,9 @@ public List<Setting<?>> getSettings() {
11791179
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
11801180
MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED,
11811181
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
1182-
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED
1182+
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
1183+
MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED,
1184+
MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED
11831185
);
11841186
return settings;
11851187
}

0 commit comments

Comments
 (0)