Skip to content

Commit 56d5802

Browse files
ylwu-amzn and rbhavna
authored and committed
add status code to model tensor (opensearch-project#1443) (opensearch-project#1453)
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent 96b0f94 commit 56d5802

File tree

5 files changed

+64
-49
lines changed

5 files changed

+64
-49
lines changed

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
import lombok.Builder;
99
import lombok.Getter;
10+
import lombok.Setter;
1011
import org.opensearch.common.io.stream.BytesStreamOutput;
1112
import org.opensearch.core.common.bytes.BytesReference;
1213
import org.opensearch.core.common.io.stream.StreamInput;
@@ -24,7 +25,10 @@
2425
@Getter
2526
public class ModelTensors implements Writeable, ToXContentObject {
2627
public static final String OUTPUT_FIELD = "output";
28+
public static final String STATUS_CODE_FIELD = "status_code";
2729
private List<ModelTensor> mlModelTensors;
30+
@Setter
31+
private Integer statusCode;
2832

2933
@Builder
3034
public ModelTensors(List<ModelTensor> mlModelTensors) {
@@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4145
}
4246
builder.endArray();
4347
}
48+
if (statusCode != null) {
49+
builder.field(STATUS_CODE_FIELD, statusCode);
50+
}
4451
builder.endObject();
4552
return builder;
4653
}
@@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException {
5360
mlModelTensors.add(new ModelTensor(in));
5461
}
5562
}
63+
statusCode = in.readOptionalInt();
5664
}
5765

5866
@Override
@@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException {
6674
} else {
6775
out.writeBoolean(false);
6876
}
77+
out.writeOptionalInt(statusCode);
6978
}
7079

7180
public void filter(ModelResultFilter resultFilter) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
8686
HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> {
8787
return httpClient.prepareRequest(executeRequest).call();
8888
});
89+
int statusCode = response.httpResponse().statusCode();
8990

9091
AbortableInputStream body = null;
9192
if (response.responseBody().isPresent()) {
@@ -106,6 +107,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
106107
String modelResponse = responseBuilder.toString();
107108

108109
ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
110+
tensors.setStatusCode(statusCode);
109111
tensorOutputs.add(tensors);
110112
} catch (RuntimeException exception) {
111113
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ public HttpJsonConnectorExecutor(Connector connector) {
5454
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
5555
try {
5656
AtomicReference<String> responseRef = new AtomicReference<>("");
57+
AtomicReference<Integer> statusCodeRef = new AtomicReference<>();
5758

5859
HttpUriRequest request;
5960
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
@@ -98,12 +99,14 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
9899
String responseBody = EntityUtils.toString(responseEntity);
99100
EntityUtils.consume(responseEntity);
100101
responseRef.set(responseBody);
102+
statusCodeRef.set(response.getStatusLine().getStatusCode());
101103
}
102104
return null;
103105
});
104106
String modelResponse = responseRef.get();
105107

106108
ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
109+
tensors.setStatusCode(statusCodeRef.get());
107110
tensorOutputs.add(tensors);
108111
} catch (RuntimeException e) {
109112
log.error("Fail to execute http connector", e);

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java

Lines changed: 31 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,21 +5,12 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8-
import static org.mockito.ArgumentMatchers.any;
9-
import static org.mockito.Mockito.spy;
10-
import static org.mockito.Mockito.when;
11-
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
12-
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
13-
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
14-
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
15-
16-
import java.io.ByteArrayInputStream;
17-
import java.io.IOException;
18-
import java.io.InputStream;
19-
import java.util.Arrays;
20-
import java.util.Map;
21-
import java.util.Optional;
228

9+
import com.google.common.collect.ImmutableList;
10+
import com.google.common.collect.ImmutableMap;
11+
import org.apache.http.ProtocolVersion;
12+
import org.apache.http.StatusLine;
13+
import org.apache.http.message.BasicStatusLine;
2314
import org.junit.Assert;
2415
import org.junit.Before;
2516
import org.junit.Rule;
@@ -49,6 +40,23 @@
4940
import software.amazon.awssdk.http.ExecutableHttpRequest;
5041
import software.amazon.awssdk.http.HttpExecuteResponse;
5142
import software.amazon.awssdk.http.SdkHttpClient;
43+
import software.amazon.awssdk.http.SdkHttpResponse;
44+
45+
import java.io.ByteArrayInputStream;
46+
import java.io.IOException;
47+
import java.io.InputStream;
48+
import java.util.Arrays;
49+
import java.util.Map;
50+
import java.util.Optional;
51+
52+
import static org.mockito.ArgumentMatchers.any;
53+
import static org.mockito.Mockito.mock;
54+
import static org.mockito.Mockito.spy;
55+
import static org.mockito.Mockito.when;
56+
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
57+
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
58+
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
59+
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
5260

5361
public class AwsConnectorExecutorTest {
5462

@@ -101,6 +109,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
101109
exceptionRule.expectMessage("No response from model");
102110
when(response.responseBody()).thenReturn(Optional.empty());
103111
when(httpRequest.call()).thenReturn(response);
112+
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
113+
when(httpResponse.statusCode()).thenReturn(200);
114+
when(response.httpResponse()).thenReturn(httpResponse);
104115
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
105116

106117
ConnectorAction predictAction = ConnectorAction
@@ -135,6 +146,9 @@ public void executePredict_RemoteInferenceInput() throws IOException {
135146
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
136147
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
137148
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
149+
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
150+
when(httpResponse.statusCode()).thenReturn(200);
151+
when(response.httpResponse()).thenReturn(httpResponse);
138152
when(httpRequest.call()).thenReturn(response);
139153
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
140154

@@ -177,6 +191,9 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
177191
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
178192
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
179193
when(httpRequest.call()).thenReturn(response);
194+
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
195+
when(httpResponse.statusCode()).thenReturn(200);
196+
when(response.httpResponse()).thenReturn(httpResponse);
180197
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);
181198

182199
ConnectorAction predictAction = ConnectorAction

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java

Lines changed: 19 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,20 @@
1313
import java.util.Arrays;
1414

1515
import org.apache.http.HttpEntity;
16+
import org.apache.http.ProtocolVersion;
17+
import org.apache.http.StatusLine;
1618
import org.apache.http.client.methods.CloseableHttpResponse;
1719
import org.apache.http.entity.StringEntity;
1820
import org.apache.http.impl.client.CloseableHttpClient;
21+
import org.apache.http.message.BasicStatusLine;
1922
import org.junit.Assert;
2023
import org.junit.Before;
2124
import org.junit.Rule;
2225
import org.junit.Test;
2326
import org.junit.rules.ExpectedException;
2427
import org.mockito.Mock;
2528
import org.mockito.MockitoAnnotations;
29+
import org.opensearch.cluster.ClusterStateTaskConfig;
2630
import org.opensearch.ingest.TestTemplateService;
2731
import org.opensearch.ml.common.FunctionName;
2832
import org.opensearch.ml.common.connector.Connector;
@@ -99,6 +103,8 @@ public void executePredict_RemoteInferenceInput() throws IOException {
99103
when(httpClient.execute(any())).thenReturn(response);
100104
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
101105
when(response.getEntity()).thenReturn(entity);
106+
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
107+
when(response.getStatusLine()).thenReturn(statusLine);
102108
when(executor.getHttpClient()).thenReturn(httpClient);
103109
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
104110
ModelTensorOutput modelTensorOutput = executor
@@ -125,13 +131,9 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
125131
when(httpClient.execute(any())).thenReturn(response);
126132
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
127133
when(response.getEntity()).thenReturn(entity);
128-
Connector connector = HttpConnector
129-
.builder()
130-
.name("test connector")
131-
.version("1")
132-
.protocol("http")
133-
.actions(Arrays.asList(predictAction))
134-
.build();
134+
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
135+
when(response.getStatusLine()).thenReturn(statusLine);
136+
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
135137
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
136138
when(executor.getHttpClient()).thenReturn(httpClient);
137139
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
@@ -174,34 +176,16 @@ public void executePredict_TextDocsInput() throws IOException {
174176
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
175177
executor.setScriptService(scriptService);
176178
when(httpClient.execute(any())).thenReturn(response);
177-
String modelResponse = "{\n"
178-
+ " \"object\": \"list\",\n"
179-
+ " \"data\": [\n"
180-
+ " {\n"
181-
+ " \"object\": \"embedding\",\n"
182-
+ " \"index\": 0,\n"
183-
+ " \"embedding\": [\n"
184-
+ " -0.014555434,\n"
185-
+ " -0.002135904,\n"
186-
+ " 0.0035105038\n"
187-
+ " ]\n"
188-
+ " },\n"
189-
+ " {\n"
190-
+ " \"object\": \"embedding\",\n"
191-
+ " \"index\": 1,\n"
192-
+ " \"embedding\": [\n"
193-
+ " -0.014555434,\n"
194-
+ " -0.002135904,\n"
195-
+ " 0.0035105038\n"
196-
+ " ]\n"
197-
+ " }\n"
198-
+ " ],\n"
199-
+ " \"model\": \"text-embedding-ada-002-v2\",\n"
200-
+ " \"usage\": {\n"
201-
+ " \"prompt_tokens\": 5,\n"
202-
+ " \"total_tokens\": 5\n"
203-
+ " }\n"
204-
+ "}";
179+
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
180+
+ " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
181+
+ " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
182+
+ " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n"
183+
+ " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n"
184+
+ " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n"
185+
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
186+
+ " \"total_tokens\": 5\n" + " }\n" + "}";
187+
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
188+
when(response.getStatusLine()).thenReturn(statusLine);
205189
HttpEntity entity = new StringEntity(modelResponse);
206190
when(response.getEntity()).thenReturn(entity);
207191
when(executor.getHttpClient()).thenReturn(httpClient);

0 commit comments

Comments
 (0)