Skip to content

Commit

Permalink
add one document to one prediction support
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
  • Loading branch information
mingshl committed Aug 2, 2024
1 parent a436a94 commit b520539
Show file tree
Hide file tree
Showing 3 changed files with 456 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.utils.MapUtils;
import org.opensearch.ml.utils.SearchResponseUtil;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.PipelineProcessingContext;
Expand Down Expand Up @@ -125,9 +126,16 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
/**
* Processes the search response asynchronously by rewriting the documents with the inference results.
*
* @param request the search request
* @param response the search response
* @param responseContext the pipeline processing context
* At default, process many document in one prediction through rewriteResponseDocuments method.
* but when process inference one document for one inference,
* separate one N-hits searchResponse into N one-hit search response,
* execute the same rewriteResponseDocument method,
* after N one-hit search response with inference result gets back,
* combined N one-hit search response back into one N-hits searchResponse.
*
* @param request the search request
* @param response the search response
* @param responseContext the pipeline processing context
* @param responseListener the listener to be notified when the response is processed
*/
@Override
Expand All @@ -144,20 +152,106 @@ public void processResponseAsync(
responseListener.onResponse(response);
return;
}
rewriteResponseDocuments(response, responseListener);

// if many to one, run rewriteResponseDocuments
if (!oneToOne) {
rewriteResponseDocuments(response, responseListener);
} else {
// if one to one, make one hit search response and run rewriteResponseDocuments
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
response,
responseListener,
hits
);

for (SearchHit hit : hits) {
SearchHit[] newHits = new SearchHit[1];
newHits[0] = hit;
SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response);
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener);
rewriteResponseDocuments(oneHitResponse, oneHitListener);
}
}

} catch (Exception e) {
if (ignoreFailure) {
responseListener.onResponse(response);
} else {
responseListener.onFailure(e);
responseListener.onFailure(new RuntimeException(e.getMessage()));
}
}
}

/**
* Creates an ActionListener for a single SearchResponse that delegates its
* onResponse and onFailure callbacks to a GroupedActionListener.
*
* @param combineResponseListener The GroupedActionListener to which the
* onResponse and onFailure callbacks will be
* delegated.
* @return An ActionListener that delegates its callbacks to the provided
* GroupedActionListener.
*/
private static ActionListener<SearchResponse> getOneHitListener(GroupedActionListener<SearchResponse> combineResponseListener) {
ActionListener<SearchResponse> oneHitListener = new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
combineResponseListener.onResponse(response);
}

@Override
public void onFailure(Exception e) {
combineResponseListener.onFailure(e);
}
};
return oneHitListener;
}

/**
* Creates a GroupedActionListener that combines the SearchResponses from individual hits
* and constructs a new SearchResponse with the combined hits.
*
* @param response The original SearchResponse containing the hits to be processed.
* @param responseListener The ActionListener to be notified with the combined SearchResponse.
* @param hits The array of SearchHits to be processed.
* @return A GroupedActionListener that combines the SearchResponses and constructs a new SearchResponse.
*/
private GroupedActionListener<SearchResponse> getCombineResponseGroupedActionListener(
SearchResponse response,
ActionListener<SearchResponse> responseListener,
SearchHit[] hits
) {
GroupedActionListener<SearchResponse> combineResponseListener = new GroupedActionListener<>(new ActionListener<>() {
@Override
public void onResponse(Collection<SearchResponse> responseMapCollection) {
SearchHit[] combinedHits = new SearchHit[hits.length];
int i = 0;
for (SearchResponse OneHitResponseAfterInference : responseMapCollection) {
SearchHit[] hitsAfterInference = OneHitResponseAfterInference.getHits().getHits();
combinedHits[i] = hitsAfterInference[0];
i++;
}
SearchResponse oneToOneInferenceSearchResponse = SearchResponseUtil.replaceHits(combinedHits, response);
responseListener.onResponse(oneToOneInferenceSearchResponse);
}

@Override
public void onFailure(Exception e) {
// when one hit failed and ignoreFailure return original response
if (ignoreFailure) {
responseListener.onResponse(response);
} else {
responseListener.onFailure(e);
}
}
}, hits.length);
return combineResponseListener;
}

/**
* Rewrite the documents in the search response with the inference results.
*
* @param response the search response
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @throws IOException if an I/O error occurs during the rewriting process
*/
Expand All @@ -168,27 +262,23 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se

// hitCountInPredictions keeps track of the count of hit that have the required input fields for each round of prediction
Map<Integer, Integer> hitCountInPredictions = new HashMap<>();
if (!oneToOne) {
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListenerManyToOne(
response,
responseListener,
processInputMap,
processOutputMap,
hitCountInPredictions
);

GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListenerManyToOne(
rewriteResponseListener,
inputMapSize
);
SearchHit[] hits = response.getHits().getHits();
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
processPredictionsManyToOne(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
}
} else {
responseListener.onFailure(new IllegalArgumentException("one to one prediction is not supported yet."));
}
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListenerManyToOne(
response,
responseListener,
processInputMap,
processOutputMap,
hitCountInPredictions
);

GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListenerManyToOne(
rewriteResponseListener,
inputMapSize
);
SearchHit[] hits = response.getHits().getHits();
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
processPredictionsManyToOne(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
}
}

/**
Expand Down Expand Up @@ -242,7 +332,7 @@ private void processPredictionsManyToOne(
Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName);
if (documentValue != null) {
// when not existed in the map, add into the modelInputParameters map
updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue);
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
}
}
} else { // when document does not contain the documentFieldName, skip when ignoreMissing
Expand All @@ -263,7 +353,7 @@ private void processPredictionsManyToOne(
Object documentValue = entry.getValue();

// when not existed in the map, add into the modelInputParameters map
updateModelInputParametersManyToOne(modelInputParameters, modelInputFieldName, documentValue);
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);

}
}
Expand Down Expand Up @@ -306,18 +396,28 @@ public void onFailure(Exception e) {
});
}

private void updateModelInputParametersManyToOne(
Map<String, Object> modelInputParameters,
String modelInputFieldName,
Object documentValue
) {
if (!modelInputParameters.containsKey(modelInputFieldName)) {
List<Object> documentValueList = new ArrayList<>();
documentValueList.add(documentValue);
modelInputParameters.put(modelInputFieldName, documentValueList);
/**
* Updates the model input parameters map with the given document value.
* If the setting is one-to-one,
* simply put the document value in the map
* If the setting is many-to-one,
* create a new list and add the document value
* @param modelInputParameters The map containing the model input parameters.
* @param modelInputFieldName The name of the model input field.
* @param documentValue The value from the document that needs to be added to the model input parameters.
*/
private void updateModelInputParameters(Map<String, Object> modelInputParameters, String modelInputFieldName, Object documentValue) {
if (!this.oneToOne) {
if (!modelInputParameters.containsKey(modelInputFieldName)) {
List<Object> documentValueList = new ArrayList<>();
documentValueList.add(documentValue);
modelInputParameters.put(modelInputFieldName, documentValueList);
} else {
List<Object> valueList = ((List) modelInputParameters.get(modelInputFieldName));
valueList.add(documentValue);
}
} else {
List<Object> valueList = ((List) modelInputParameters.get(modelInputFieldName));
valueList.add(documentValue);
modelInputParameters.put(modelInputFieldName, documentValue);
}
}

Expand Down Expand Up @@ -353,11 +453,11 @@ public void onFailure(Exception e) {
/**
* Creates an action listener for rewriting the response with the inference results.
*
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @param processInputMap the list of input mappings
* @param processOutputMap the list of output mappings
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @param processInputMap the list of input mappings
* @param processOutputMap the list of output mappings
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @return an action listener for rewriting the response with the inference results
*/
private ActionListener<Map<Integer, MLOutput>> createRewriteResponseListenerManyToOne(
Expand Down Expand Up @@ -499,10 +599,10 @@ private boolean checkIsModelInputMissing(Map<String, Object> document, Map<Strin
* <p>If the processOutputMap is not null and not empty, the mapping at the specified mappingIndex
* is returned.
*
* @param mappingIndex the index of the mapping to retrieve from the processOutputMap
* @param mappingIndex the index of the mapping to retrieve from the processOutputMap
* @param processOutputMap the list of output mappings, can be null or empty
* @return a Map containing the output mapping, either the default mapping or the mapping at the
* specified index
* specified index
*/
private static Map<String, String> getDefaultOutputMapping(Integer mappingIndex, List<Map<String, String>> processOutputMap) {
Map<String, String> outputMapping;
Expand All @@ -524,11 +624,11 @@ private static Map<String, String> getDefaultOutputMapping(Integer mappingIndex,
* <p>If the processInputMap is not null and not empty, the mapping at the specified mappingIndex
* is returned.
*
* @param sourceAsMap the source map containing the input data
* @param mappingIndex the index of the mapping to retrieve from the processInputMap
* @param sourceAsMap the source map containing the input data
* @param mappingIndex the index of the mapping to retrieve from the processInputMap
* @param processInputMap the list of input mappings, can be null or empty
* @return a Map containing the input mapping, either the mapping extracted from sourceAsMap or
* the mapping at the specified index
* the mapping at the specified index
*/
private static Map<String, String> getDefaultInputMapping(
Map<String, Object> sourceAsMap,
Expand Down
Loading

0 comments on commit b520539

Please sign in to comment.