Skip to content

Commit

Permalink
fix cherrypick conflict
Browse files Browse the repository at this point in the history
Signed-off-by: Sicheng Song <sicheng.song@outlook.com>
  • Loading branch information
b4sjoo committed Feb 7, 2024
1 parent 55540c8 commit 6af41f0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
package org.opensearch.ml.engine.algorithms.remote;

import com.google.common.collect.ImmutableMap;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -147,38 +144,4 @@ public void executePredict_RemoteInferenceInput() throws IOException {
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key"));
}

@Test
public void executePredict_TextDocsInferenceInput() throws IOException {
String jsonString = "{\"key\":\"value\"}";
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT)
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.script.ScriptService;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;

import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -100,7 +101,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
}

@Test
public void executePredict_TextDocsInput_NoPreprocessFunction() {
public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input.");
ConnectorAction predictAction = ConnectorAction.builder()
Expand Down Expand Up @@ -154,7 +155,6 @@ public void executePredict_TextDocsInput() throws IOException {
+ " \"total_tokens\": 5\n" + " }\n" + "}";
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
Expand Down

0 comments on commit 6af41f0

Please sign in to comment.