Skip to content

Commit a9151f4

Browse files
authored
Add ml-commons passthrough post process function (opensearch-project#4111)
* Add ml commons passthrough post process function Signed-off-by: Andy Qin <qinandy@amazon.com> * Apply spotless Signed-off-by: Andy Qin <qinandy@amazon.com> * Add more comments and refactor Signed-off-by: Andy Qin <qinandy@amazon.com> --------- Signed-off-by: Andy Qin <qinandy@amazon.com>
1 parent d077f31 commit a9151f4

File tree

3 files changed

+392
-0
lines changed

3 files changed

+392
-0
lines changed

common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.opensearch.ml.common.connector.functions.postprocess.BedrockRerankPostProcessFunction;
1616
import org.opensearch.ml.common.connector.functions.postprocess.CohereRerankPostProcessFunction;
1717
import org.opensearch.ml.common.connector.functions.postprocess.EmbeddingPostProcessFunction;
18+
import org.opensearch.ml.common.connector.functions.postprocess.RemoteMlCommonsPassthroughPostProcessFunction;
1819
import org.opensearch.ml.common.output.model.MLResultDataType;
1920
import org.opensearch.ml.common.output.model.ModelTensor;
2021

@@ -35,6 +36,8 @@ public class MLPostProcessFunction {
3536
public static final String BEDROCK_RERANK = "connector.post_process.bedrock.rerank";
3637
public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding";
3738
public static final String DEFAULT_RERANK = "connector.post_process.default.rerank";
39+
// ML commons passthrough unwraps a remote ml-commons response and reconstructs model tensors directly based on remote inference
40+
public static final String ML_COMMONS_PASSTHROUGH = "connector.post_process.mlcommons.passthrough";
3841

3942
private static final Map<String, String> JSON_PATH_EXPRESSION = new HashMap<>();
4043

@@ -46,6 +49,8 @@ public class MLPostProcessFunction {
4649
BedrockBatchJobArnPostProcessFunction batchJobArnPostProcessFunction = new BedrockBatchJobArnPostProcessFunction();
4750
CohereRerankPostProcessFunction cohereRerankPostProcessFunction = new CohereRerankPostProcessFunction();
4851
BedrockRerankPostProcessFunction bedrockRerankPostProcessFunction = new BedrockRerankPostProcessFunction();
52+
RemoteMlCommonsPassthroughPostProcessFunction remoteMlCommonsPassthroughPostProcessFunction =
53+
new RemoteMlCommonsPassthroughPostProcessFunction();
4954
JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding");
5055
JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings");
5156
JSON_PATH_EXPRESSION.put(COHERE_V2_EMBEDDING_FLOAT32, "$.embeddings.float");
@@ -61,6 +66,7 @@ public class MLPostProcessFunction {
6166
JSON_PATH_EXPRESSION.put(COHERE_RERANK, "$.results");
6267
JSON_PATH_EXPRESSION.put(BEDROCK_RERANK, "$.results");
6368
JSON_PATH_EXPRESSION.put(DEFAULT_RERANK, "$[*]");
69+
JSON_PATH_EXPRESSION.put(ML_COMMONS_PASSTHROUGH, "$"); // Get the entire response
6470
POST_PROCESS_FUNCTIONS.put(OPENAI_EMBEDDING, embeddingPostProcessFunction);
6571
POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, embeddingPostProcessFunction);
6672
POST_PROCESS_FUNCTIONS.put(COHERE_V2_EMBEDDING_FLOAT32, embeddingPostProcessFunction);
@@ -76,6 +82,7 @@ public class MLPostProcessFunction {
7682
POST_PROCESS_FUNCTIONS.put(COHERE_RERANK, cohereRerankPostProcessFunction);
7783
POST_PROCESS_FUNCTIONS.put(BEDROCK_RERANK, bedrockRerankPostProcessFunction);
7884
POST_PROCESS_FUNCTIONS.put(DEFAULT_RERANK, cohereRerankPostProcessFunction);
85+
POST_PROCESS_FUNCTIONS.put(ML_COMMONS_PASSTHROUGH, remoteMlCommonsPassthroughPostProcessFunction);
7986
}
8087

8188
public static String getResponseFilter(String postProcessFunction) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.postprocess;
7+
8+
import static org.opensearch.ml.common.output.model.ModelTensors.OUTPUT_FIELD;
9+
10+
import java.lang.reflect.Array;
11+
import java.util.ArrayList;
12+
import java.util.Arrays;
13+
import java.util.List;
14+
import java.util.Map;
15+
16+
import org.opensearch.ml.common.output.model.MLResultDataType;
17+
import org.opensearch.ml.common.output.model.ModelTensor;
18+
19+
/**
20+
* A post-processing function for calling a remote ml commons instance that preserves the original neural sparse response structure
21+
* to avoid double-wrapping when receiving responses from another ML-Commons instance.
22+
*/
23+
public class RemoteMlCommonsPassthroughPostProcessFunction extends ConnectorPostProcessFunction<Map<String, Object>> {
24+
@Override
25+
public void validate(Object input) {
26+
if (!(input instanceof Map) && !(input instanceof List)) {
27+
throw new IllegalArgumentException("Post process function input must be a Map or List");
28+
}
29+
}
30+
31+
/**
32+
* Example unwrapped response:
33+
* {
34+
* "inference_results": [
35+
* {
36+
* "output": [
37+
* {
38+
* "name": "output",
39+
* "dataAsMap": {
40+
* "inference_results": [
41+
* {
42+
* "output": [
43+
* {
44+
* "name": "output",
45+
* "dataAsMap": {
46+
* "response": [
47+
* {
48+
* "increasingly": 0.028670792,
49+
* "achievements": 0.4906937,
50+
* ...
51+
* }
52+
* ]
53+
* }
54+
* }
55+
* ],
56+
* "status_code": 200.0
57+
* }
58+
* ]
59+
* }
60+
* }
61+
* ],
62+
* "status_code": 200
63+
* }
64+
* ]
65+
* }
66+
*
67+
* Example unwrapped response:
68+
*
69+
* {
70+
* "inference_results": [
71+
* {
72+
* "output": [
73+
* {
74+
* "name": "output",
75+
* "dataAsMap": {
76+
* "response": [
77+
* {
78+
* "increasingly": 0.028670792,
79+
* "achievements": 0.4906937,
80+
* ...
81+
* }
82+
* ]
83+
* }
84+
* },
85+
* ],
86+
* "status_code": 200
87+
* }
88+
* ]
89+
* }
90+
*
91+
* @param mlCommonsResponse raw remote ml commons response
92+
* @param dataType the datatype of the result, not used since datatype is set based on the response body
93+
* @return a list of model tensors representing the inner model tensors
94+
*/
95+
@Override
96+
public List<ModelTensor> process(Map<String, Object> mlCommonsResponse, MLResultDataType dataType) {
97+
// Check if this is an ML-Commons response with inference_results
98+
if (mlCommonsResponse.containsKey("inference_results") && mlCommonsResponse.get("inference_results") instanceof List) {
99+
List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) mlCommonsResponse.get("inference_results");
100+
101+
List<ModelTensor> modelTensors = new ArrayList<>();
102+
for (Map<String, Object> result : inferenceResults) {
103+
// Extract the output field which contains the ModelTensor data
104+
if (result.containsKey("output") && result.get("output") instanceof List) {
105+
List<Map<String, Object>> outputs = (List<Map<String, Object>>) result.get("output");
106+
for (Map<String, Object> output : outputs) {
107+
// This inner map should represent a model tensor, so we try to parse and instantiate a new one.
108+
ModelTensor modelTensor = createModelTensorFromMap(output);
109+
if (modelTensor != null) {
110+
modelTensors.add(modelTensor);
111+
}
112+
}
113+
}
114+
}
115+
116+
return modelTensors;
117+
}
118+
119+
// Fallback for non-ML-Commons responses
120+
ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(mlCommonsResponse).build();
121+
122+
return List.of(tensor);
123+
}
124+
125+
/**
126+
* Creates a ModelTensor from a Map<String, Object> representation based on the API format
127+
* of the /_predict API
128+
*/
129+
private ModelTensor createModelTensorFromMap(Map<String, Object> map) {
130+
if (map == null || map.isEmpty()) {
131+
return null;
132+
}
133+
134+
// Get name. If name is null or not a String, default to OUTPUT_FIELD
135+
Object uncastedName = map.get(ModelTensor.NAME_FIELD);
136+
String name = uncastedName instanceof String castedName ? castedName : OUTPUT_FIELD;
137+
String result = (String) map.get(ModelTensor.RESULT_FIELD);
138+
139+
// Handle data as map
140+
Map<String, Object> dataAsMap = (Map<String, Object>) map.get(ModelTensor.DATA_AS_MAP_FIELD);
141+
142+
// Handle data type. For certain models like neural sparse and non-dense remote models, this field
143+
// is not populated and left as null instead, which is still valid
144+
MLResultDataType dataType = null;
145+
if (map.containsKey(ModelTensor.DATA_TYPE_FIELD)) {
146+
Object dataTypeObj = map.get(ModelTensor.DATA_TYPE_FIELD);
147+
if (dataTypeObj instanceof String) {
148+
try {
149+
dataType = MLResultDataType.valueOf((String) dataTypeObj);
150+
} catch (IllegalArgumentException e) {
151+
// Invalid data type, leave as null in case inner data is still useful to be parsed in the future
152+
}
153+
}
154+
}
155+
156+
// Handle shape. For certain models like neural sparse and non-dense, null is valid since inference result
157+
// is stored in dataAsMap, not data/shape field
158+
long[] shape = null;
159+
if (map.containsKey(ModelTensor.SHAPE_FIELD)) {
160+
Number[] numbers = processNumericalArray(map, ModelTensor.SHAPE_FIELD, Number.class);
161+
if (numbers != null) {
162+
shape = Arrays.stream(numbers).mapToLong(Number::longValue).toArray();
163+
}
164+
}
165+
166+
// Handle shape. For certain models like neural sparse and non-dense, null is valid since inference result
167+
// is stored in dataAsMap, not data/shape field
168+
Number[] data = null;
169+
if (map.containsKey(ModelTensor.DATA_FIELD)) {
170+
data = processNumericalArray(map, ModelTensor.DATA_FIELD, Number.class);
171+
}
172+
173+
// For now, we skip handling byte buffer since it's not needed for neural sparse and dense model use cases.
174+
175+
return ModelTensor.builder().name(name).dataType(dataType).shape(shape).data(data).result(result).dataAsMap(dataAsMap).build();
176+
}
177+
178+
private static <T> T[] processNumericalArray(Map<String, Object> map, String key, Class<T> type) {
179+
Object obj = map.get(key);
180+
if (obj instanceof List<?> list) {
181+
T[] array = (T[]) Array.newInstance(type, list.size());
182+
for (int i = 0; i < list.size(); i++) {
183+
Object item = list.get(i);
184+
if (type.isInstance(item)) {
185+
array[i] = type.cast(item);
186+
}
187+
}
188+
return array;
189+
}
190+
return null;
191+
}
192+
}

0 commit comments

Comments
 (0)