|
8 | 8 | import static org.junit.Assert.assertEquals; |
9 | 9 | import static org.junit.Assert.assertNotEquals; |
10 | 10 | import static org.junit.Assert.assertNotNull; |
| 11 | +import static org.junit.Assert.assertTrue; |
11 | 12 | import static org.junit.Assert.fail; |
| 13 | +import static org.mockito.ArgumentMatchers.any; |
| 14 | +import static org.mockito.Mockito.mock; |
| 15 | +import static org.mockito.Mockito.verify; |
| 16 | +import static org.mockito.Mockito.when; |
| 17 | +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SDK_CLIENT; |
| 18 | +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SETTINGS; |
12 | 19 | import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame; |
13 | 20 | import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame; |
14 | 21 | import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; |
|
19 | 26 | import java.util.Arrays; |
20 | 27 | import java.util.Collections; |
21 | 28 | import java.util.List; |
| 29 | +import java.util.Locale; |
22 | 30 | import java.util.Map; |
23 | 31 | import java.util.UUID; |
| 32 | +import java.util.concurrent.CompletableFuture; |
24 | 33 |
|
25 | 34 | import org.junit.Assert; |
26 | 35 | import org.junit.Before; |
27 | 36 | import org.junit.Rule; |
28 | 37 | import org.junit.Test; |
29 | 38 | import org.junit.rules.ExpectedException; |
| 39 | +import org.mockito.ArgumentCaptor; |
30 | 40 | import org.mockito.MockedStatic; |
31 | 41 | import org.opensearch.common.settings.Settings; |
32 | 42 | import org.opensearch.common.xcontent.XContentType; |
|
37 | 47 | import org.opensearch.core.xcontent.XContentParser; |
38 | 48 | import org.opensearch.ml.common.FunctionName; |
39 | 49 | import org.opensearch.ml.common.MLModel; |
| 50 | +import org.opensearch.ml.common.connector.AwsConnector; |
40 | 51 | import org.opensearch.ml.common.connector.HttpConnector; |
41 | 52 | import org.opensearch.ml.common.dataframe.ColumnMeta; |
42 | 53 | import org.opensearch.ml.common.dataframe.DataFrame; |
|
57 | 68 | import org.opensearch.ml.engine.algorithms.regression.LinearRegression; |
58 | 69 | import org.opensearch.ml.engine.encryptor.Encryptor; |
59 | 70 | import org.opensearch.ml.engine.encryptor.EncryptorImpl; |
| 71 | +import org.opensearch.remote.metadata.client.SdkClient; |
60 | 72 | import org.opensearch.search.SearchModule; |
61 | 73 |
|
| 74 | +import software.amazon.awssdk.utils.ImmutableMap; |
| 75 | + |
62 | 76 | // TODO: refactor MLEngineClassLoader's static functions to avoid mockStatic |
63 | 77 | public class MLEngineTest extends MLStaticMockBase { |
64 | 78 | @Rule |
@@ -523,4 +537,74 @@ public void testGetConnectorCredentialWithoutRegion() throws IOException { |
523 | 537 | assertEquals("test_key_value", decryptedCredential.get("key")); |
524 | 538 | assertEquals(null, decryptedCredential.get("region")); |
525 | 539 | } |
| 540 | + |
| 541 | + @Test |
| 542 | + public void testDeploy_withPredictableActionListener_successful() throws IOException { |
| 543 | + String encryptedAccessKey = mlEngine.encrypt("access-key", null); |
| 544 | + String encryptedSecretKey = mlEngine.encrypt("secret-key", null); |
| 545 | + String testConnector = String.format(Locale.ROOT, """ |
| 546 | + { |
| 547 | + "name": "sagemaker: t2ppl", |
| 548 | + "description": "t2ppl model", |
| 549 | + "version": 1, |
| 550 | + "protocol": "aws_sigv4", |
| 551 | + "credential": { |
| 552 | + "access_key": "%s", |
| 553 | + "secret_key": "%s" |
| 554 | + }, |
| 555 | + "parameters": { |
| 556 | + "region": "us-east-1", |
| 557 | + "service_name": "sagemaker", |
| 558 | + "input_type": "search_document" |
| 559 | + }, |
| 560 | + "actions": [ |
| 561 | + { |
| 562 | + "action_type": "predict", |
| 563 | + "method": "POST", |
| 564 | + "headers": { |
| 565 | + "content-type": "application/json", |
| 566 | + "x-amz-content-sha256": "required" |
| 567 | + }, |
| 568 | + "url": "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/my-endpoint/invocations", |
| 569 | + "request_body": "{\\"prompt\\":\\"${parameters.prompt}\\"}" |
| 570 | + } |
| 571 | + ] |
| 572 | + } |
| 573 | + """, encryptedAccessKey, encryptedSecretKey); |
| 574 | + |
| 575 | + XContentParser parser = XContentType.JSON |
| 576 | + .xContent() |
| 577 | + .createParser( |
| 578 | + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), |
| 579 | + null, |
| 580 | + testConnector |
| 581 | + ); |
| 582 | + parser.nextToken(); |
| 583 | + |
| 584 | + MLModel model = mock(MLModel.class); |
| 585 | + AwsConnector connector = new AwsConnector("aws_sigv4", parser); |
| 586 | + when(model.getAlgorithm()).thenReturn(FunctionName.REMOTE); |
| 587 | + when(model.getConnector()).thenReturn(connector); |
| 588 | + ActionListener<Predictable> actionListener = mock(ActionListener.class); |
| 589 | + SdkClient sdkClient = mock(SdkClient.class); |
| 590 | + when(sdkClient.isGlobalResource(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); |
| 591 | + Map<String, Object> params = ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, Settings.EMPTY); |
| 592 | + mlEngine.deploy(model, params, actionListener); |
| 593 | + verify(actionListener).onResponse(any(Predictable.class)); |
| 594 | + } |
| 595 | + |
| 596 | + @Test |
| 597 | + public void testDeploy_withPredictableActionListener_exceptional() { |
| 598 | + MLModel model = mock(MLModel.class); |
| 599 | + when(model.getAlgorithm()).thenReturn(FunctionName.REMOTE); |
| 600 | + when(model.getConnector()).thenThrow(new RuntimeException("Runtime error")); |
| 601 | + ActionListener<Predictable> actionListener = mock(ActionListener.class); |
| 602 | + SdkClient sdkClient = mock(SdkClient.class); |
| 603 | + when(sdkClient.isGlobalResource(any(), any())).thenReturn(CompletableFuture.completedFuture(false)); |
| 604 | + Map<String, Object> params = ImmutableMap.of(SDK_CLIENT, sdkClient, SETTINGS, Settings.EMPTY); |
| 605 | + mlEngine.deploy(model, params, actionListener); |
| 606 | + ArgumentCaptor<RuntimeException> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); |
| 607 | + verify(actionListener).onFailure(argumentCaptor.capture()); |
| 608 | + assertTrue(argumentCaptor.getValue().getMessage().contains("Runtime error")); |
| 609 | + } |
526 | 610 | } |
0 commit comments