Skip to content

Commit

Permalink
Fix UT failures
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Oct 27, 2023
1 parent f71f190 commit 347fbfc
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor{
@Setter @Getter
private ScriptService scriptService;

public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) {
this.connector = (AwsConnector) connector;
this.httpClient = httpClient;
}

public AwsConnectorExecutor(Connector connector) {
this.connector = (AwsConnector) connector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor {

private CloseableHttpClient httpClient;

public HttpJsonConnectorExecutor(Connector connector, CloseableHttpClient httpClient) {
this(connector);
this.httpClient = httpClient;
}

public HttpJsonConnectorExecutor(Connector connector) {
this.connector = (HttpConnector)connector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));

AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}
Expand Down Expand Up @@ -139,8 +138,7 @@ public void executePredict_RemoteInferenceInput_InvalidToken() throws IOExceptio
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));

AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}
Expand All @@ -167,8 +165,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));

AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Expand Down Expand Up @@ -201,8 +198,7 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));

AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
Expand Down Expand Up @@ -112,7 +112,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Expand All @@ -137,7 +137,7 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException {
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}
Expand All @@ -159,7 +159,7 @@ public void executePredict_TextDocsInput() throws IOException {
.requestBody("{\"input\": ${parameters.input}}")
.build();
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public class RemoteModelTest {
RemoteModel remoteModel;
Encryptor encryptor;

private Map<String, Object> params = Map.of(RemoteModel.CONNECTION_TIMEOUT, 1000, RemoteModel.READ_TIMEOUT, 1000, RemoteModel.MAX_CONNECTIONS, 30);

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
Expand Down Expand Up @@ -71,7 +73,7 @@ public void predict_ModelDeployed_WrongInput() {
exceptionRule.expectMessage("Wrong input type");
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel.initModel(mlModel, params, encryptor);
remoteModel.predict(mlInput);
}

Expand All @@ -82,14 +84,14 @@ public void initModel_RuntimeException() {
Connector connector = createConnector(null);
when(mlModel.getConnector()).thenReturn(connector);
doThrow(new IllegalArgumentException("Tag mismatch!")).when(encryptor).decrypt(any());
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel.initModel(mlModel, params, encryptor);
}

@Test
public void initModel_NullHeader() {
Connector connector = createConnector(null);
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel.initModel(mlModel, params, encryptor);
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
Assert.assertNull(decryptedHeaders);
}
Expand All @@ -98,7 +100,7 @@ public void initModel_NullHeader() {
public void initModel_WithHeader() {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel.initModel(mlModel, params, encryptor);
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
Assert.assertNotNull(executor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,8 @@ public void deployModel(
connectionTimeoutInMillis,
READ_TIMEOUT,
readTimeoutInMillis,
MAX_CONNECTIONS, maxConnections
MAX_CONNECTIONS,
maxConnections
);
// deploy remote model with internal connector or model trained by built-in algorithm like kmeans
if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,24 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED =
GenerativeQAProcessorConstants.RAG_PIPELINE_FEATURE_ENABLED;

public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND =
Setting.intSetting("plugins.ml_commons.http_client.connection_timeout.in_millisecond", 1000, 1, Setting.Property.NodeScope, Setting.Property.Final);
public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND = Setting
.intSetting(
"plugins.ml_commons.http_client.connection_timeout.in_millisecond",
1000,
1,
Setting.Property.NodeScope,
Setting.Property.Final
);

public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND =
Setting.intSetting("plugins.ml_commons.http_client.read_timeout.in_millisecond", 3000, 1, Setting.Property.NodeScope, Setting.Property.Final);
public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND = Setting
.intSetting(
"plugins.ml_commons.http_client.read_timeout.in_millisecond",
3000,
1,
Setting.Property.NodeScope,
Setting.Property.Final
);

public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_MAX_CONNECTIONS =
Setting.intSetting("plugins.ml_commons.http_client.max_connections", 20, 20, Setting.Property.NodeScope, Setting.Property.Final);
public static final Setting<Integer> ML_COMMONS_HTTP_CLIENT_MAX_CONNECTIONS = Setting
.intSetting("plugins.ml_commons.http_client.max_connections", 20, 20, Setting.Property.NodeScope, Setting.Property.Final);
}

0 comments on commit 347fbfc

Please sign in to comment.