Skip to content

Commit b6f4931

Browse files
add test coverage
Signed-off-by: rithin-pullela-aws <rithinp@amazon.com>
1 parent e63a8b9 commit b6f4931

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ public class QueryPlanningTool implements WithModelTool {
3535
public static final String MODEL_ID_FIELD = "model_id";
3636
private final MLModelTool queryGenerationTool;
3737
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
38+
public static final String INDEX_MAPPING_FIELD = "index_mapping";
39+
public static final String QUERY_FIELDS_FIELD = "query_fields";
3840
private static final String GENERATION_TYPE_FIELD = "generation_type";
3941
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
4042
@Getter
@@ -66,11 +68,11 @@ public <T> void run(Map<String, String> parameters, ActionListener<T> listener)
6668
if (!parameters.containsKey(SYSTEM_PROMPT_FIELD)) {
6769
parameters.put(SYSTEM_PROMPT_FIELD, DEFAULT_SYSTEM_PROMPT);
6870
}
69-
if (parameters.containsKey("index_mapping")) {
70-
parameters.put("index_mapping", gson.toJson(parameters.get("index_mapping")));
71+
if (parameters.containsKey(INDEX_MAPPING_FIELD)) {
72+
parameters.put(INDEX_MAPPING_FIELD, gson.toJson(parameters.get(INDEX_MAPPING_FIELD)));
7173
}
72-
if (parameters.containsKey("query_fields")) {
73-
parameters.put("query_fields", gson.toJson(parameters.get("query_fields")));
74+
if (parameters.containsKey(QUERY_FIELDS_FIELD)) {
75+
parameters.put(QUERY_FIELDS_FIELD, gson.toJson(parameters.get(QUERY_FIELDS_FIELD)));
7476
}
7577
ActionListener<T> modelListener = ActionListener.wrap(r -> {
7678
try {

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import static org.mockito.Mockito.mock;
1717
import static org.mockito.Mockito.verify;
1818
import static org.opensearch.ml.engine.tools.QueryPlanningTool.DEFAULT_DESCRIPTION;
19+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.INDEX_MAPPING_FIELD;
1920
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
21+
import static org.opensearch.ml.engine.tools.QueryPlanningTool.QUERY_FIELDS_FIELD;
2022
import static org.opensearch.ml.engine.tools.QueryPlanningTool.SYSTEM_PROMPT_FIELD;
2123

2224
import java.util.Collections;
@@ -59,7 +61,7 @@ public void setup() {
5961
MLModelTool.Factory.getInstance().init(client);
6062
factory = new QueryPlanningTool.Factory();
6163
validParams = new HashMap<>();
62-
validParams.put("prompt", "test prompt");
64+
validParams.put(SYSTEM_PROMPT_FIELD, "test prompt");
6365
emptyParams = Collections.emptyMap();
6466
}
6567

@@ -85,7 +87,7 @@ public void testRun() throws ExecutionException, InterruptedException {
8587
ActionListener<String> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
8688
// test try to update the prompt
8789
validParams
88-
.put("prompt", "You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}");
90+
.put(SYSTEM_PROMPT_FIELD, "You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}");
8991
validParams.put("query_text", "help me find some books related to wind");
9092
tool.run(validParams, listener);
9193

@@ -203,7 +205,7 @@ public void testRunWithNoPrompt() {
203205
ArgumentCaptor<Map<String, String>> captor = ArgumentCaptor.forClass(Map.class);
204206
doAnswer(invocation -> {
205207
Map<String, String> params = invocation.getArgument(0);
206-
assertNotNull(params.get("prompt"));
208+
assertNotNull(params.get(SYSTEM_PROMPT_FIELD));
207209
return null;
208210
}).when(queryGenerationTool).run(captor.capture(), any());
209211
}
@@ -274,8 +276,8 @@ public void testAllParameterProcessing() {
274276
QueryPlanningTool tool = new QueryPlanningTool("llmGenerated", queryGenerationTool);
275277
Map<String, String> parameters = new HashMap<>();
276278
parameters.put("query_text", "test query");
277-
parameters.put("index_mapping", "{\"properties\":{\"title\":{\"type\":\"text\"}}}");
278-
parameters.put("query_fields", "[\"title\", \"content\"]");
279+
parameters.put(INDEX_MAPPING_FIELD, "{\"properties\":{\"title\":{\"type\":\"text\"}}}");
280+
parameters.put(QUERY_FIELDS_FIELD, "[\"title\", \"content\"]");
279281
// No system_prompt - should use default
280282

281283
@SuppressWarnings("unchecked")
@@ -296,12 +298,12 @@ public void testAllParameterProcessing() {
296298

297299
// All parameters should be processed
298300
assertTrue(capturedParams.containsKey("query_text"));
299-
assertTrue(capturedParams.containsKey("index_mapping"));
300-
assertTrue(capturedParams.containsKey("query_fields"));
301+
assertTrue(capturedParams.containsKey(INDEX_MAPPING_FIELD));
302+
assertTrue(capturedParams.containsKey(QUERY_FIELDS_FIELD));
301303
assertTrue(capturedParams.containsKey(SYSTEM_PROMPT_FIELD));
302304

303305
// Processed parameters should be JSON strings
304-
assertTrue(capturedParams.get("index_mapping").startsWith("\""));
305-
assertTrue(capturedParams.get("query_fields").startsWith("\""));
306+
assertTrue(capturedParams.get(INDEX_MAPPING_FIELD).startsWith("\""));
307+
assertTrue(capturedParams.get(QUERY_FIELDS_FIELD).startsWith("\""));
306308
}
307309
}

0 commit comments

Comments
 (0)