|
22 | 22 | import static org.mockito.Mockito.mock;
|
23 | 23 | import static org.mockito.Mockito.verify;
|
24 | 24 | import static org.mockito.Mockito.when;
|
| 25 | +import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG; |
25 | 26 | import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor.IllegalArgumentMessage;
|
26 | 27 |
|
27 | 28 | import java.time.Instant;
|
@@ -461,4 +462,52 @@ public void testProcessResponseNullValueInteractions() throws Exception {
|
461 | 462 |
|
462 | 463 | SearchResponse res = processor.processResponse(request, response);
|
463 | 464 | }
|
| 465 | + |
| 466 | + public void testProcessResponseIllegalArgumentForNullParams() throws Exception { |
| 467 | + exceptionRule.expect(IllegalArgumentException.class); |
| 468 | + exceptionRule.expectMessage(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG); |
| 469 | + |
| 470 | + Client client = mock(Client.class); |
| 471 | + Map<String, Object> config = new HashMap<>(); |
| 472 | + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); |
| 473 | + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); |
| 474 | + |
| 475 | + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( |
| 476 | + client, |
| 477 | + alwaysOn |
| 478 | + ).create(null, "tag", "desc", true, config, null); |
| 479 | + |
| 480 | + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); |
| 481 | + when(memoryClient.getInteractions(any(), anyInt())) |
| 482 | + .thenReturn(List.of(new Interaction("0", Instant.now(), "1", null, null, null, null, null))); |
| 483 | + processor.setMemoryClient(memoryClient); |
| 484 | + |
| 485 | + SearchRequest request = new SearchRequest(); |
| 486 | + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); |
| 487 | + |
| 488 | + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); |
| 489 | + extBuilder.setParams(null); |
| 490 | + request.source(sourceBuilder); |
| 491 | + sourceBuilder.ext(List.of(extBuilder)); |
| 492 | + |
| 493 | + int numHits = 10; |
| 494 | + SearchHit[] hitsArray = new SearchHit[numHits]; |
| 495 | + for (int i = 0; i < numHits; i++) { |
| 496 | + XContentBuilder sourceContent = JsonXContent |
| 497 | + .contentBuilder() |
| 498 | + .startObject() |
| 499 | + .field("_id", String.valueOf(i)) |
| 500 | + .field("text", "passage" + i) |
| 501 | + .field("title", "This is the title for document " + i) |
| 502 | + .endObject(); |
| 503 | + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); |
| 504 | + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); |
| 505 | + } |
| 506 | + |
| 507 | + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); |
| 508 | + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); |
| 509 | + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); |
| 510 | + |
| 511 | + SearchResponse res = processor.processResponse(request, response); |
| 512 | + } |
464 | 513 | }
|
0 commit comments