99import static org .junit .Assert .assertFalse ;
1010import static org .junit .Assert .assertNotNull ;
1111import static org .junit .Assert .assertNull ;
12+ import static org .junit .Assert .assertThrows ;
1213import static org .junit .Assert .assertTrue ;
1314import static org .mockito .ArgumentMatchers .any ;
1415import static org .mockito .Mockito .doAnswer ;
16+ import static org .mockito .Mockito .mock ;
1517import static org .opensearch .ml .engine .tools .QueryPlanningTool .DEFAULT_DESCRIPTION ;
1618import static org .opensearch .ml .engine .tools .QueryPlanningTool .MODEL_ID_FIELD ;
1719
2628import org .junit .Rule ;
2729import org .junit .Test ;
2830import org .junit .rules .ExpectedException ;
31+ import org .mockito .ArgumentCaptor ;
2932import org .mockito .Mock ;
3033import org .mockito .MockitoAnnotations ;
3134import org .opensearch .core .action .ActionListener ;
@@ -46,11 +49,13 @@ public class QueryPlanningToolTests {
4649 private Map <String , String > validParams ;
4750 private Map <String , String > emptyParams ;
4851
52+ private QueryPlanningTool .Factory factory ;
53+
4954 @ Before
5055 public void setup () {
5156 MockitoAnnotations .openMocks (this );
5257 MLModelTool .Factory .getInstance ().init (client );
53- QueryPlanningTool .Factory . getInstance (). init ( client );
58+ factory = new QueryPlanningTool .Factory ( );
5459 validParams = new HashMap <>();
5560 validParams .put ("prompt" , "test prompt" );
5661 emptyParams = Collections .emptyMap ();
@@ -73,7 +78,7 @@ public void testRun() throws ExecutionException, InterruptedException {
7378 return null ;
7479 }).when (queryGenerationTool ).run (any (), any ());
7580
76- QueryPlanningTool tool = new QueryPlanningTool (client , "test_model_id" , "llmGenerated" , queryGenerationTool );
81+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
7782 final CompletableFuture <String > future = new CompletableFuture <>();
7883 ActionListener <String > listener = ActionListener .wrap (future ::complete , future ::completeExceptionally );
7984 // test try to update the prompt
@@ -97,7 +102,7 @@ public void testRun_PredictionReturnsList_ThrowsIllegalArgumentException() throw
97102 return null ;
98103 }).when (queryGenerationTool ).run (any (), any ());
99104
100- QueryPlanningTool tool = new QueryPlanningTool (client , "test_model_id" , "llmGenerated" , queryGenerationTool );
105+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
101106 final CompletableFuture <String > future = new CompletableFuture <>();
102107 ActionListener <String > listener = ActionListener .wrap (future ::complete , future ::completeExceptionally );
103108 validParams .put ("query_text" , "help me find some books related to wind" );
@@ -114,7 +119,7 @@ public void testRun_PredictionReturnsNull_ReturnDefaultQuery() throws ExecutionE
114119 return null ;
115120 }).when (queryGenerationTool ).run (any (), any ());
116121
117- QueryPlanningTool tool = new QueryPlanningTool (client , "test_model_id" , "llmGenerated" , queryGenerationTool );
122+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
118123 final CompletableFuture <String > future = new CompletableFuture <>();
119124 ActionListener <String > listener = ActionListener .wrap (future ::complete , future ::completeExceptionally );
120125 validParams .put ("query_text" , "help me find some books related to wind" );
@@ -132,7 +137,25 @@ public void testRun_PredictionReturnsEmpty_ReturnDefaultQuery() throws Execution
132137 return null ;
133138 }).when (queryGenerationTool ).run (any (), any ());
134139
135- QueryPlanningTool tool = new QueryPlanningTool (client , "test_model_id" , "llmGenerated" , queryGenerationTool );
140+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
141+ final CompletableFuture <String > future = new CompletableFuture <>();
142+ ActionListener <String > listener = ActionListener .wrap (future ::complete , future ::completeExceptionally );
143+ validParams .put ("query_text" , "help me find some books related to wind" );
144+ tool .run (validParams , listener );
145+ String multiMatchQueryString =
146+ "{ \" query\" : { \" multi_match\" : { \" query\" : \" help me find some books related to wind\" , \" fields\" : [\" *\" ] } } }" ;
147+ assertEquals (multiMatchQueryString , future .get ());
148+ }
149+
150+ @ Test
151+ public void testRun_PredictionReturnsNullString_ReturnDefaultQuery () throws ExecutionException , InterruptedException {
152+ doAnswer (invocation -> {
153+ ActionListener <String > listener = invocation .getArgument (1 );
154+ listener .onResponse ("null" );
155+ return null ;
156+ }).when (queryGenerationTool ).run (any (), any ());
157+
158+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
136159 final CompletableFuture <String > future = new CompletableFuture <>();
137160 ActionListener <String > listener = ActionListener .wrap (future ::complete , future ::completeExceptionally );
138161 validParams .put ("query_text" , "help me find some books related to wind" );
@@ -168,4 +191,82 @@ public void testFactoryGetAllModelKeys() {
168191 @ Rule
169192 public ExpectedException thrown = ExpectedException .none ();
170193
194+ @ Test
195+ public void testRunWithNoPrompt () {
196+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
197+ Map <String , String > parameters = new HashMap <>();
198+ parameters .put ("query_text" , "some query" );
199+ @ SuppressWarnings ("unchecked" )
200+ ActionListener <String > listener = mock (ActionListener .class );
201+
202+ tool .run (parameters , listener );
203+
204+ ArgumentCaptor <Map <String , String >> captor = ArgumentCaptor .forClass (Map .class );
205+ doAnswer (invocation -> {
206+ Map <String , String > params = invocation .getArgument (0 );
207+ assertNotNull (params .get ("prompt" ));
208+ return null ;
209+ }).when (queryGenerationTool ).run (captor .capture (), any ());
210+ }
211+
212+ @ Test
213+ public void testRunWithInvalidParameters () {
214+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
215+ @ SuppressWarnings ("unchecked" )
216+ ActionListener <String > listener = mock (ActionListener .class );
217+
218+ tool .run (Collections .emptyMap (), listener );
219+
220+ ArgumentCaptor <Exception > captor = ArgumentCaptor .forClass (Exception .class );
221+ org .mockito .Mockito .verify (listener ).onFailure (captor .capture ());
222+ assertEquals ("Empty parameters for QueryPlanningTool: {}" , captor .getValue ().getMessage ());
223+ }
224+
225+ @ Test
226+ public void testRunModelReturnsNull () {
227+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
228+ Map <String , String > parameters = new HashMap <>();
229+ parameters .put ("query_text" , "some query" );
230+ @ SuppressWarnings ("unchecked" )
231+ ActionListener <String > listener = mock (ActionListener .class );
232+
233+ doAnswer (invocation -> {
234+ ActionListener <String > modelListener = invocation .getArgument (1 );
235+ modelListener .onResponse (null );
236+ return null ;
237+ }).when (queryGenerationTool ).run (any (), any ());
238+
239+ tool .run (parameters , listener );
240+
241+ ArgumentCaptor <String > captor = ArgumentCaptor .forClass (String .class );
242+ org .mockito .Mockito .verify (listener ).onResponse (captor .capture ());
243+ assertNotNull (captor .getValue ());
244+ }
245+
246+ @ Test
247+ public void testSetName () {
248+ QueryPlanningTool tool = new QueryPlanningTool ("llmGenerated" , queryGenerationTool );
249+ tool .setName ("NewName" );
250+ assertEquals ("NewName" , tool .getName ());
251+ }
252+
253+ @ Test
254+ public void testFactoryCreateWithEmptyType () {
255+ Map <String , Object > map = new HashMap <>();
256+ map .put (QueryPlanningTool .MODEL_ID_FIELD , "modelId" );
257+ Tool tool = factory .create (map );
258+ assertEquals (QueryPlanningTool .TYPE , tool .getName ());
259+ assertEquals ("llmGenerated" , ((QueryPlanningTool ) tool ).getGenerationType ());
260+ assertNotNull (tool );
261+ }
262+
263+ @ Test
264+ public void testFactoryCreateWithInvalidType () {
265+ Map <String , Object > map = new HashMap <>();
266+ map .put ("generation_type" , "invalid" );
267+ map .put (QueryPlanningTool .MODEL_ID_FIELD , "modelId" );
268+
269+ Exception exception = assertThrows (IllegalArgumentException .class , () -> factory .create (map ));
270+ assertEquals ("Invalid generation type: invalid. The current supported types are llmGenerated." , exception .getMessage ());
271+ }
171272}
0 commit comments