Skip to content

Commit

Permalink
add status code to model tensor (#1443) (#1453)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Nov 20, 2023
1 parent 89ea8d0 commit 3495e8d
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -24,7 +25,10 @@
@Getter
public class ModelTensors implements Writeable, ToXContentObject {
public static final String OUTPUT_FIELD = "output";
public static final String STATUS_CODE_FIELD = "status_code";
private List<ModelTensor> mlModelTensors;
@Setter
private Integer statusCode;

@Builder
public ModelTensors(List<ModelTensor> mlModelTensors) {
Expand All @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.endArray();
}
if (statusCode != null) {
builder.field(STATUS_CODE_FIELD, statusCode);
}
builder.endObject();
return builder;
}
Expand All @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException {
mlModelTensors.add(new ModelTensor(in));
}
}
statusCode = in.readOptionalInt();
}

@Override
Expand All @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalInt(statusCode);
}

public void filter(ModelResultFilter resultFilter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> {
return httpClient.prepareRequest(executeRequest).call();
});
int statusCode = response.httpResponse().statusCode();

AbortableInputStream body = null;
if (response.responseBody().isPresent()) {
Expand All @@ -106,6 +107,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
String modelResponse = responseBuilder.toString();

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public HttpJsonConnectorExecutor(Connector connector) {
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
try {
AtomicReference<String> responseRef = new AtomicReference<>("");
AtomicReference<Integer> statusCodeRef = new AtomicReference<>();

HttpUriRequest request;
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
Expand Down Expand Up @@ -98,12 +99,14 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
String responseBody = EntityUtils.toString(responseEntity);
EntityUtils.consume(responseEntity);
responseRef.set(responseBody);
statusCodeRef.set(response.getStatusLine().getStatusCode());
}
return null;
});
String modelResponse = responseRef.get();

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCodeRef.get());
tensorOutputs.add(tensors);
} catch (RuntimeException e) {
log.error("Fail to execute http connector", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,12 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -49,6 +40,23 @@
import software.amazon.awssdk.http.ExecutableHttpRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpResponse;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;

public class AwsConnectorExecutorTest {

Expand Down Expand Up @@ -101,6 +109,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
exceptionRule.expectMessage("No response from model");
when(response.responseBody()).thenReturn(Optional.empty());
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction
Expand Down Expand Up @@ -135,6 +146,9 @@ public void executePredict_RemoteInferenceInput() throws IOException {
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpRequest.call()).thenReturn(response);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

Expand Down Expand Up @@ -177,6 +191,9 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
import java.util.Arrays;

import org.apache.http.HttpEntity;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
Expand Down Expand Up @@ -99,6 +103,8 @@ public void executePredict_RemoteInferenceInput() throws IOException {
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
ModelTensorOutput modelTensorOutput = executor
Expand All @@ -125,13 +131,9 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
Connector connector = HttpConnector
.builder()
.name("test connector")
.version("1")
.protocol("http")
.actions(Arrays.asList(predictAction))
.build();
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
Expand Down Expand Up @@ -174,34 +176,16 @@ public void executePredict_TextDocsInput() throws IOException {
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
String modelResponse = "{\n"
+ " \"object\": \"list\",\n"
+ " \"data\": [\n"
+ " {\n"
+ " \"object\": \"embedding\",\n"
+ " \"index\": 0,\n"
+ " \"embedding\": [\n"
+ " -0.014555434,\n"
+ " -0.002135904,\n"
+ " 0.0035105038\n"
+ " ]\n"
+ " },\n"
+ " {\n"
+ " \"object\": \"embedding\",\n"
+ " \"index\": 1,\n"
+ " \"embedding\": [\n"
+ " -0.014555434,\n"
+ " -0.002135904,\n"
+ " 0.0035105038\n"
+ " ]\n"
+ " }\n"
+ " ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n"
+ " \"usage\": {\n"
+ " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n"
+ " }\n"
+ "}";
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
+ " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
+ " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
+ " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n"
+ " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n"
+ " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n" + " }\n" + "}";
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
Expand Down

0 comments on commit 3495e8d

Please sign in to comment.