-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Backport 2.16] add initial MLInferenceSearchResponseProcessor (#2734)
* add initial MLInferenceSearchResponseProcessor (#2688) * add MLInferenceSearchResponseProcessor Signed-off-by: Mingshi Liu <mingshl@amazon.com> * add ITs Signed-off-by: Mingshi Liu <mingshl@amazon.com> * add code coverage Signed-off-by: Mingshi Liu <mingshl@amazon.com> * add many_to_one flag Signed-off-by: Mingshi Liu <mingshl@amazon.com> * avoid import * Signed-off-by: Mingshi Liu <mingshl@amazon.com> * remove extra hits Signed-off-by: Mingshi Liu <mingshl@amazon.com> * spotlessApply Signed-off-by: Mingshi Liu <mingshl@amazon.com> * remove extra brackets Signed-off-by: Mingshi Liu <mingshl@amazon.com> --------- Signed-off-by: Mingshi Liu <mingshl@amazon.com> (cherry picked from commit 01084b4) * fix http package Signed-off-by: Mingshi Liu <mingshl@amazon.com> --------- Signed-off-by: Mingshi Liu <mingshl@amazon.com> Co-authored-by: Mingshi Liu <mingshl@amazon.com>
- Loading branch information
1 parent
ca6bbe7
commit 23bacd3
Showing
10 changed files
with
2,871 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
674 changes: 674 additions & 0 deletions
674
plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
plugin/src/main/java/org/opensearch/ml/utils/MapUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.utils; | ||
|
||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
public class MapUtils { | ||
|
||
/** | ||
* Increments the counter for the given key in the specified version. | ||
* If the key doesn't exist, it initializes the counter to 0. | ||
* | ||
* @param version the version of the counter | ||
* @param key the key for which the counter needs to be incremented | ||
*/ | ||
public static void incrementCounter(Map<Integer, Map<String, Integer>> versionedCounters, int version, String key) { | ||
Map<String, Integer> counters = versionedCounters.computeIfAbsent(version, k -> new HashMap<>()); | ||
counters.put(key, counters.getOrDefault(key, -1) + 1); | ||
} | ||
|
||
/** | ||
* Retrieves the counter value for the given key in the specified version. | ||
* If the key doesn't exist, it returns 0. | ||
* | ||
* @param version the version of the counter | ||
* @param key the key for which the counter needs to be retrieved | ||
* @return the counter value for the given key | ||
*/ | ||
public static int getCounter(Map<Integer, Map<String, Integer>> versionedCounters, int version, String key) { | ||
Map<String, Integer> counters = versionedCounters.get(version); | ||
return counters != null ? counters.getOrDefault(key, -1) : 0; | ||
} | ||
|
||
/** | ||
* Increments the counter value for the given key in the provided counters map. | ||
* If the key does not exist in the map, it is added with an initial counter value of 0. | ||
* | ||
* @param counters A map that stores integer counters for each integer key. | ||
* @param key The integer key for which the counter needs to be incremented. | ||
*/ | ||
public static void incrementCounter(Map<Integer, Integer> counters, int key) { | ||
counters.put(key, counters.getOrDefault(key, 0) + 1); | ||
} | ||
|
||
public static int getCounter(Map<Integer, Integer> counters, int key) { | ||
return counters.getOrDefault(key, 0); | ||
} | ||
|
||
} |
85 changes: 85 additions & 0 deletions
85
plugin/src/main/java/org/opensearch/ml/utils/SearchResponseUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.utils; | ||
|
||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.action.search.SearchResponseSections; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.SearchHits; | ||
import org.opensearch.search.aggregations.InternalAggregations; | ||
import org.opensearch.search.internal.InternalSearchResponse; | ||
import org.opensearch.search.profile.SearchProfileShardResults; | ||
|
||
public class SearchResponseUtil { | ||
private SearchResponseUtil() {} | ||
|
||
/** | ||
* Construct a new {@link SearchResponse} based on an existing one, replacing just the {@link SearchHits}. | ||
* @param newHits new {@link SearchHits} | ||
* @param response the existing search response | ||
* @return a new search response where the {@link SearchHits} has been replaced | ||
*/ | ||
public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) { | ||
SearchResponseSections searchResponseSections; | ||
if (response.getAggregations() == null || response.getAggregations() instanceof InternalAggregations) { | ||
// We either have no aggregations, or we have Writeable InternalAggregations. | ||
// Either way, we can produce a Writeable InternalSearchResponse. | ||
searchResponseSections = new InternalSearchResponse( | ||
newHits, | ||
(InternalAggregations) response.getAggregations(), | ||
response.getSuggest(), | ||
new SearchProfileShardResults(response.getProfileResults()), | ||
response.isTimedOut(), | ||
response.isTerminatedEarly(), | ||
response.getNumReducePhases() | ||
); | ||
} else { | ||
// We have non-Writeable Aggregations, so the whole SearchResponseSections is non-Writeable. | ||
searchResponseSections = new SearchResponseSections( | ||
newHits, | ||
response.getAggregations(), | ||
response.getSuggest(), | ||
response.isTimedOut(), | ||
response.isTerminatedEarly(), | ||
new SearchProfileShardResults(response.getProfileResults()), | ||
response.getNumReducePhases() | ||
); | ||
} | ||
|
||
return new SearchResponse( | ||
searchResponseSections, | ||
response.getScrollId(), | ||
response.getTotalShards(), | ||
response.getSuccessfulShards(), | ||
response.getSkippedShards(), | ||
response.getTook().millis(), | ||
response.getShardFailures(), | ||
response.getClusters(), | ||
response.pointInTimeId() | ||
); | ||
} | ||
|
||
/** | ||
* Convenience method when only replacing the {@link SearchHit} array within the {@link SearchHits} in a {@link SearchResponse}. | ||
* @param newHits the new array of {@link SearchHit} elements. | ||
* @param response the search response to update | ||
* @return a {@link SearchResponse} where the underlying array of {@link SearchHit} within the {@link SearchHits} has been replaced. | ||
*/ | ||
public static SearchResponse replaceHits(SearchHit[] newHits, SearchResponse response) { | ||
if (response.getHits() == null) { | ||
throw new IllegalStateException("Response must have hits"); | ||
} | ||
SearchHits searchHits = new SearchHits( | ||
newHits, | ||
response.getHits().getTotalHits(), | ||
response.getHits().getMaxScore(), | ||
response.getHits().getSortFields(), | ||
response.getHits().getCollapseField(), | ||
response.getHits().getCollapseValues() | ||
); | ||
return replaceHits(searchHits, response); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.