|
13 | 13 | import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; |
14 | 14 | import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; |
15 | 15 | import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; |
| 16 | +import static org.opensearch.ml.settings.MLCommonsSettings.REKOGNITION_TRUST_ENDPOINT_REGEX; |
16 | 17 | import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; |
17 | 18 | import static org.opensearch.ml.utils.TestHelper.clusterSetting; |
18 | 19 |
|
@@ -118,7 +119,13 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { |
118 | 119 | private ArgumentCaptor<PutDataObjectRequest> putDataObjectRequestArgumentCaptor; |
119 | 120 |
|
120 | 121 | private static final List<String> TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList |
121 | | - .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); |
| 122 | + .of( |
| 123 | + "^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", |
| 124 | + "^https://api\\.openai\\.com/.*$", |
| 125 | + "^https://api\\.cohere\\.ai/.*$", |
| 126 | + REKOGNITION_TRUST_ENDPOINT_REGEX, |
| 127 | + "^https://api\\.deepseek\\.com/.*$" |
| 128 | + ); |
122 | 129 |
|
123 | 130 | @Before |
124 | 131 | public void setup() { |
@@ -539,4 +546,117 @@ public void test_execute_URL_notMatchingExpression_exception() { |
539 | 546 | argumentCaptor.getValue().getMessage() |
540 | 547 | ); |
541 | 548 | } |
| 549 | + |
| 550 | + public void test_connector_creation_success_deepseek() { |
| 551 | + TransportCreateConnectorAction action = new TransportCreateConnectorAction( |
| 552 | + transportService, |
| 553 | + actionFilters, |
| 554 | + mlIndicesHandler, |
| 555 | + client, |
| 556 | + sdkClient, |
| 557 | + mlEngine, |
| 558 | + connectorAccessControlHelper, |
| 559 | + settings, |
| 560 | + clusterService, |
| 561 | + mlModelManager, |
| 562 | + mlFeatureEnabledSetting |
| 563 | + ); |
| 564 | + doAnswer(invocation -> { |
| 565 | + ActionListener<Boolean> listener = invocation.getArgument(0); |
| 566 | + listener.onResponse(true); |
| 567 | + return null; |
| 568 | + }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); |
| 569 | + doAnswer(invocation -> { |
| 570 | + ActionListener<IndexResponse> listener = invocation.getArgument(1); |
| 571 | + listener.onResponse(indexResponse); |
| 572 | + return null; |
| 573 | + }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); |
| 574 | + List<ConnectorAction> actions = new ArrayList<>(); |
| 575 | + actions |
| 576 | + .add( |
| 577 | + ConnectorAction |
| 578 | + .builder() |
| 579 | + .actionType(ConnectorAction.ActionType.PREDICT) |
| 580 | + .method("POST") |
| 581 | + .url("https://api.deepseek.com/v1/chat/completions") |
| 582 | + .build() |
| 583 | + ); |
| 584 | + Map<String, String> credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); |
| 585 | + MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput |
| 586 | + .builder() |
| 587 | + .name(randomAlphaOfLength(5)) |
| 588 | + .description(randomAlphaOfLength(10)) |
| 589 | + .version("1") |
| 590 | + .protocol(ConnectorProtocols.HTTP) |
| 591 | + .credential(credential) |
| 592 | + .actions(actions) |
| 593 | + .build(); |
| 594 | + MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput); |
| 595 | + action.doExecute(task, request, actionListener); |
| 596 | + verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); |
| 597 | + } |
| 598 | + |
| 599 | + public void test_connector_creation_success_rekognition() { |
| 600 | + TransportCreateConnectorAction action = new TransportCreateConnectorAction( |
| 601 | + transportService, |
| 602 | + actionFilters, |
| 603 | + mlIndicesHandler, |
| 604 | + client, |
| 605 | + sdkClient, |
| 606 | + mlEngine, |
| 607 | + connectorAccessControlHelper, |
| 608 | + settings, |
| 609 | + clusterService, |
| 610 | + mlModelManager, |
| 611 | + mlFeatureEnabledSetting |
| 612 | + ); |
| 613 | + |
| 614 | + doAnswer(invocation -> { |
| 615 | + ActionListener<Boolean> listener = invocation.getArgument(0); |
| 616 | + listener.onResponse(true); |
| 617 | + return null; |
| 618 | + }).when(mlIndicesHandler).initMLConnectorIndex(isA(ActionListener.class)); |
| 619 | + |
| 620 | + doAnswer(invocation -> { |
| 621 | + ActionListener<IndexResponse> listener = invocation.getArgument(1); |
| 622 | + listener.onResponse(indexResponse); |
| 623 | + return null; |
| 624 | + }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); |
| 625 | + |
| 626 | + List<ConnectorAction> actions = new ArrayList<>(); |
| 627 | + actions |
| 628 | + .add( |
| 629 | + ConnectorAction |
| 630 | + .builder() |
| 631 | + .actionType(ConnectorAction.ActionType.PREDICT) |
| 632 | + .method("POST") |
| 633 | + .url("https://rekognition.test-region-1.amazonaws.com") |
| 634 | + .build() |
| 635 | + ); |
| 636 | + actions |
| 637 | + .add( |
| 638 | + ConnectorAction |
| 639 | + .builder() |
| 640 | + .actionType(ConnectorAction.ActionType.PREDICT) |
| 641 | + .method("POST") |
| 642 | + .url("https://rekognition-fips.test-region-1.amazonaws.com") |
| 643 | + .build() |
| 644 | + ); |
| 645 | + |
| 646 | + Map<String, String> credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); |
| 647 | + MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput |
| 648 | + .builder() |
| 649 | + .name(randomAlphaOfLength(5)) |
| 650 | + .description(randomAlphaOfLength(10)) |
| 651 | + .version("1") |
| 652 | + .protocol(ConnectorProtocols.HTTP) |
| 653 | + .credential(credential) |
| 654 | + .actions(actions) |
| 655 | + .build(); |
| 656 | + |
| 657 | + MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput); |
| 658 | + |
| 659 | + action.doExecute(task, request, actionListener); |
| 660 | + verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); |
| 661 | + } |
542 | 662 | } |
0 commit comments