Skip to content

Commit

Permalink
add status code to model tensor (#1443) (#1454)
Browse files Browse the repository at this point in the history
* add status code to model tensor



* fix ut



---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Oct 6, 2023
1 parent 5f92939 commit 156aed7
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -24,7 +25,10 @@
@Getter
public class ModelTensors implements Writeable, ToXContentObject {
public static final String OUTPUT_FIELD = "output";
public static final String STATUS_CODE_FIELD = "status_code";
private List<ModelTensor> mlModelTensors;
@Setter
private Integer statusCode;

@Builder
public ModelTensors(List<ModelTensor> mlModelTensors) {
Expand All @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.endArray();
}
if (statusCode != null) {
builder.field(STATUS_CODE_FIELD, statusCode);
}
builder.endObject();
return builder;
}
Expand All @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException {
mlModelTensors.add(new ModelTensor(in));
}
}
statusCode = in.readOptionalInt();
}

@Override
Expand All @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalInt(statusCode);
}

public void filter(ModelResultFilter resultFilter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> {
return httpClient.prepareRequest(executeRequest).call();
});
int statusCode = response.httpResponse().statusCode();

AbortableInputStream body = null;
if (response.responseBody().isPresent()) {
Expand All @@ -102,6 +103,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
String modelResponse = responseBuilder.toString();

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public HttpJsonConnectorExecutor(Connector connector) {
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
try {
AtomicReference<String> responseRef = new AtomicReference<>("");
AtomicReference<Integer> statusCodeRef = new AtomicReference<>();

HttpUriRequest request;
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
Expand Down Expand Up @@ -97,12 +98,14 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
String responseBody = EntityUtils.toString(responseEntity);
EntityUtils.consume(responseEntity);
responseRef.set(responseBody);
statusCodeRef.set(response.getStatusLine().getStatusCode());
}
return null;
});
String modelResponse = responseRef.get();

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCodeRef.get());
tensorOutputs.add(tensors);
} catch (RuntimeException e) {
log.error("Fail to execute http connector", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
Expand All @@ -32,6 +35,7 @@
import software.amazon.awssdk.http.ExecutableHttpRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpResponse;

import java.io.ByteArrayInputStream;
import java.io.IOException;
Expand All @@ -41,6 +45,7 @@
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
Expand Down Expand Up @@ -92,6 +97,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
exceptionRule.expectMessage("No response from model");
when(response.responseBody()).thenReturn(Optional.empty());
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
Expand All @@ -116,6 +124,9 @@ public void executePredict_RemoteInferenceInput() throws IOException {
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpRequest.call()).thenReturn(response);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

Expand Down Expand Up @@ -147,6 +158,9 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@

import com.google.common.collect.ImmutableMap;
import org.apache.http.HttpEntity;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
Expand Down Expand Up @@ -87,6 +91,8 @@ public void executePredict_RemoteInferenceInput() throws IOException {
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Expand All @@ -107,6 +113,8 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
Expand Down Expand Up @@ -146,6 +154,8 @@ public void executePredict_TextDocsInput() throws IOException {
+ " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n" + " }\n" + "}";
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
Expand Down

0 comments on commit 156aed7

Please sign in to comment.