Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to 2.x] add status code to model tensor (#1443) #1453

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -24,7 +25,10 @@
@Getter
public class ModelTensors implements Writeable, ToXContentObject {
public static final String OUTPUT_FIELD = "output";
public static final String STATUS_CODE_FIELD = "status_code";
private List<ModelTensor> mlModelTensors;
@Setter
private Integer statusCode;

@Builder
public ModelTensors(List<ModelTensor> mlModelTensors) {
Expand All @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.endArray();
}
if (statusCode != null) {
builder.field(STATUS_CODE_FIELD, statusCode);
}
builder.endObject();
return builder;
}
Expand All @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException {
mlModelTensors.add(new ModelTensor(in));
}
}
statusCode = in.readOptionalInt();
}

@Override
Expand All @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalInt(statusCode);
}

public void filter(ModelResultFilter resultFilter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> {
return httpClient.prepareRequest(executeRequest).call();
});
int statusCode = response.httpResponse().statusCode();

AbortableInputStream body = null;
if (response.responseBody().isPresent()) {
Expand All @@ -102,6 +103,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
String modelResponse = responseBuilder.toString();

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public HttpJsonConnectorExecutor(Connector connector) {
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
try {
AtomicReference<String> responseRef = new AtomicReference<>("");
AtomicReference<Integer> statusCodeRef = new AtomicReference<>();

HttpUriRequest request;
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
Expand Down Expand Up @@ -97,12 +98,14 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
String responseBody = EntityUtils.toString(responseEntity);
EntityUtils.consume(responseEntity);
responseRef.set(responseBody);
statusCodeRef.set(response.getStatusLine().getStatusCode());
}
return null;
});
String modelResponse = responseRef.get();

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCodeRef.get());
tensorOutputs.add(tensors);
} catch (RuntimeException e) {
log.error("Fail to execute http connector", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
Expand All @@ -32,6 +35,7 @@
import software.amazon.awssdk.http.ExecutableHttpRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpResponse;

import java.io.ByteArrayInputStream;
import java.io.IOException;
Expand All @@ -41,6 +45,7 @@
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
Expand Down Expand Up @@ -92,6 +97,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
exceptionRule.expectMessage("No response from model");
when(response.responseBody()).thenReturn(Optional.empty());
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
Expand All @@ -116,6 +124,9 @@ public void executePredict_RemoteInferenceInput() throws IOException {
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpRequest.call()).thenReturn(response);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

Expand Down Expand Up @@ -147,6 +158,9 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@

import com.google.common.collect.ImmutableMap;
import org.apache.http.HttpEntity;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
Expand Down Expand Up @@ -87,6 +91,8 @@ public void executePredict_RemoteInferenceInput() throws IOException {
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Expand All @@ -107,6 +113,8 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
Expand Down Expand Up @@ -146,6 +154,8 @@ public void executePredict_TextDocsInput() throws IOException {
+ " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n" + " }\n" + "}";
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
Expand Down