1616import static org .mockito .Mockito .mock ;
1717import static org .mockito .Mockito .verify ;
1818import static org .opensearch .ml .engine .tools .QueryPlanningTool .DEFAULT_DESCRIPTION ;
19+ import static org .opensearch .ml .engine .tools .QueryPlanningTool .INDEX_MAPPING_FIELD ;
1920import static org .opensearch .ml .engine .tools .QueryPlanningTool .MODEL_ID_FIELD ;
21+ import static org .opensearch .ml .engine .tools .QueryPlanningTool .QUERY_FIELDS_FIELD ;
2022import static org .opensearch .ml .engine .tools .QueryPlanningTool .SYSTEM_PROMPT_FIELD ;
2123
2224import java .util .Collections ;
@@ -59,7 +61,7 @@ public void setup() {
5961 MLModelTool .Factory .getInstance ().init (client );
6062 factory = new QueryPlanningTool .Factory ();
6163 validParams = new HashMap <>();
62- validParams .put ("prompt" , "test prompt" );
64+ validParams .put (SYSTEM_PROMPT_FIELD , "test prompt" );
6365 emptyParams = Collections .emptyMap ();
6466 }
6567
@@ -85,7 +87,7 @@ public void testRun() throws ExecutionException, InterruptedException {
8587 ActionListener <String > listener = ActionListener .wrap (future ::complete , future ::completeExceptionally );
8688 // test try to update the prompt
8789 validParams
88- .put ("prompt" , "You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}" );
90+ .put (SYSTEM_PROMPT_FIELD , "You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}" );
8991 validParams .put ("query_text" , "help me find some books related to wind" );
9092 tool .run (validParams , listener );
9193
@@ -203,7 +205,7 @@ public void testRunWithNoPrompt() {
203205 ArgumentCaptor <Map <String , String >> captor = ArgumentCaptor .forClass (Map .class );
204206 doAnswer (invocation -> {
205207 Map <String , String > params = invocation .getArgument (0 );
206- assertNotNull (params .get ("prompt" ));
208+ assertNotNull (params .get (SYSTEM_PROMPT_FIELD ));
207209 return null ;
208210 }).when (queryGenerationTool ).run (captor .capture (), any ());
209211 }
@@ -274,8 +276,8 @@ public void testAllParameterProcessing() {
274276 QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
275277 Map <String , String > parameters = new HashMap <>();
276278 parameters .put ("query_text" , "test query" );
277- parameters .put ("index_mapping" , "{\" properties\" :{\" title\" :{\" type\" :\" text\" }}}" );
278- parameters .put ("query_fields" , "[\" title\" , \" content\" ]" );
279+ parameters .put (INDEX_MAPPING_FIELD , "{\" properties\" :{\" title\" :{\" type\" :\" text\" }}}" );
280+ parameters .put (QUERY_FIELDS_FIELD , "[\" title\" , \" content\" ]" );
279281 // No system_prompt - should use default
280282
281283 @ SuppressWarnings ("unchecked" )
@@ -296,12 +298,12 @@ public void testAllParameterProcessing() {
296298
297299 // All parameters should be processed
298300 assertTrue (capturedParams .containsKey ("query_text" ));
299- assertTrue (capturedParams .containsKey ("index_mapping" ));
300- assertTrue (capturedParams .containsKey ("query_fields" ));
301+ assertTrue (capturedParams .containsKey (INDEX_MAPPING_FIELD ));
302+ assertTrue (capturedParams .containsKey (QUERY_FIELDS_FIELD ));
301303 assertTrue (capturedParams .containsKey (SYSTEM_PROMPT_FIELD ));
302304
303305 // Processed parameters should be JSON strings
304- assertTrue (capturedParams .get ("index_mapping" ).startsWith ("\" " ));
305- assertTrue (capturedParams .get ("query_fields" ).startsWith ("\" " ));
306+ assertTrue (capturedParams .get (INDEX_MAPPING_FIELD ).startsWith ("\" " ));
307+ assertTrue (capturedParams .get (QUERY_FIELDS_FIELD ).startsWith ("\" " ));
306308 }
307309}
0 commit comments