Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support one_to_one in ML Inference Search Response Processor #2801

Merged
merged 6 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add more tests
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
  • Loading branch information
mingshl committed Aug 19, 2024
commit f1aa03f8cb4f3aad925259803402230add7de541
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -164,13 +165,18 @@ public void processResponseAsync(
responseListener,
hits
);

AtomicBoolean isOneHitListenerFailed = new AtomicBoolean(false);
;
for (SearchHit hit : hits) {
SearchHit[] newHits = new SearchHit[1];
newHits[0] = hit;
SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response);
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener);
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed);
rewriteResponseDocuments(oneHitResponse, oneHitListener);
mingshl marked this conversation as resolved.
Show resolved Hide resolved
// if any OneHitListener failure, try stop the rest of the predictions
if (isOneHitListenerFailed.get()) {
break;
mingshl marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

Expand All @@ -189,12 +195,16 @@ public void processResponseAsync(
* onResponse and onFailure callbacks to a GroupedActionListener.
*
* @param combineResponseListener The GroupedActionListener to which the
* onResponse and onFailure callbacks will be
* delegated.
* onResponse and onFailure callbacks will be
* delegated.
* @param isOneHitListenerFailed
* @return An ActionListener that delegates its callbacks to the provided
* GroupedActionListener.
* GroupedActionListener.
*/
private static ActionListener<SearchResponse> getOneHitListener(GroupedActionListener<SearchResponse> combineResponseListener) {
private static ActionListener<SearchResponse> getOneHitListener(
GroupedActionListener<SearchResponse> combineResponseListener,
AtomicBoolean isOneHitListenerFailed
) {
ActionListener<SearchResponse> oneHitListener = new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
Expand All @@ -203,6 +213,8 @@ public void onResponse(SearchResponse response) {

@Override
public void onFailure(Exception e) {
// if any OneHitListener failure, try stop the rest of the predictions and return
isOneHitListenerFailed.compareAndSet(false, true);
combineResponseListener.onFailure(e);
}
};
Expand Down Expand Up @@ -239,11 +251,11 @@ public void onResponse(Collection<SearchResponse> responseMapCollection) {

@Override
public void onFailure(Exception e) {
// when one hit failed and ignoreFailure return original response
if (ignoreFailure) {
responseListener.onResponse(response);
} else {
responseListener.onFailure(e);
responseListener
.onFailure(new OpenSearchStatusException("Failed to process response: " + e.getMessage(), RestStatus.BAD_REQUEST));
}
}
}, hits.length);
Expand All @@ -265,21 +277,21 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
// hitCountInPredictions keeps track of the count of hit that have the required input fields for each round of prediction
Map<Integer, Integer> hitCountInPredictions = new HashMap<>();

ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListenerManyToOne(
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListener(
response,
responseListener,
processInputMap,
processOutputMap,
hitCountInPredictions
);

GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListenerManyToOne(
GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListener(
rewriteResponseListener,
inputMapSize
);
SearchHit[] hits = response.getHits().getHits();
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
processPredictionsManyToOne(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
}
}

Expand All @@ -293,7 +305,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @throws IOException if an I/O error occurs during the prediction process
*/
private void processPredictionsManyToOne(
private void processPredictions(
SearchHit[] hits,
List<Map<String, String>> processInputMap,
int inputMapIndex,
Expand Down Expand Up @@ -356,7 +368,6 @@ private void processPredictionsManyToOne(

// when not existed in the map, add into the modelInputParameters map
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);

}
}
}
Expand Down Expand Up @@ -430,7 +441,7 @@ private void updateModelInputParameters(Map<String, Object> modelInputParameters
* @param inputMapSize the size of the input map
* @return a grouped action listener for batch predictions
*/
private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListenerManyToOne(
private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListener(
ActionListener<Map<Integer, MLOutput>> rewriteResponseListener,
int inputMapSize
) {
Expand Down Expand Up @@ -462,7 +473,7 @@ public void onFailure(Exception e) {
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @return an action listener for rewriting the response with the inference results
*/
private ActionListener<Map<Integer, MLOutput>> createRewriteResponseListenerManyToOne(
private ActionListener<Map<Integer, MLOutput>> createRewriteResponseListener(
SearchResponse response,
ActionListener<SearchResponse> responseListener,
List<Map<String, String>> processInputMap,
Expand Down Expand Up @@ -494,7 +505,7 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
Map<String, String> outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap);

boolean isModelInputMissing = false;
if (processInputMap != null) {
if (processInputMap != null && !processInputMap.isEmpty()) {
isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
}
if (!isModelInputMissing) {
Expand Down
Loading
Loading