-
Notifications
You must be signed in to change notification settings - Fork 65
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding support for generic re-ranker interface and opensearch ml re-r…
…anker for improving search relavancy. (#494) * Add rerank processor interfaces Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * add cross-encoder specific logic and factory Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * add unittests Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * add integration test Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * use string.format() instead of concatenation Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * rename generateScoringContext to generateRerankingContext Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * add name change in test too. whoops Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * start refactoring with contextSaourceFetchers Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * refactor to use contextSourceFetchers to get context Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * rename CrossEncoder to TextSimilarity Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * add query_context layer to search ext Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * add javadocs Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * update to new asyncProcessResponse api Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * rename reranktype to ML_OPENSEARCH Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * improve error messages for bad rerank type config Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * simplify configuration/factory logic Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * improve handling for non-flat-string context fields Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * rename TextSimilarity files to MLOpenSearch files Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * apply spotless after rebase Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * update changelog Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * after rebase Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * Address pr comments and fix XContent in search ext Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * move contextSourceFetchers to their own subdirectory Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * Apply suggestions from code review Co-authored-by: Martin Gaievski <gaievski@amazon.com> Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * CR changes Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * finish CR comments and fix broken unittest Signed-off-by: HenryL27 <hmlindeman@yahoo.com> * fix unittest names Signed-off-by: HenryL27 <hmlindeman@yahoo.com> --------- Signed-off-by: HenryL27 <hmlindeman@yahoo.com> Co-authored-by: Martin Gaievski <gaievski@amazon.com>
- Loading branch information
1 parent
bd1dfbf
commit 3a7903f
Showing
20 changed files
with
1,833 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
132 changes: 132 additions & 0 deletions
132
src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.factory; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.Set; | ||
import java.util.StringJoiner; | ||
|
||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; | ||
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; | ||
import org.opensearch.neuralsearch.processor.rerank.RerankType; | ||
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchResponseProcessor; | ||
|
||
import com.google.common.collect.Sets; | ||
|
||
import lombok.AllArgsConstructor; | ||
|
||
/** | ||
* Factory for rerank processors. Must: | ||
* - Instantiate the right kind of rerank processor | ||
* - Instantiate the appropriate context source fetchers | ||
*/ | ||
@AllArgsConstructor | ||
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> { | ||
|
||
public static final String RERANK_PROCESSOR_TYPE = "rerank"; | ||
public static final String CONTEXT_CONFIG_FIELD = "context"; | ||
|
||
private final MLCommonsClientAccessor clientAccessor; | ||
|
||
@Override | ||
public SearchResponseProcessor create( | ||
final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, | ||
final String tag, | ||
final String description, | ||
final boolean ignoreFailure, | ||
final Map<String, Object> config, | ||
final Processor.PipelineContext pipelineContext | ||
) { | ||
RerankType type = findRerankType(config); | ||
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type); | ||
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag); | ||
switch (type) { | ||
case ML_OPENSEARCH: | ||
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel()); | ||
String modelId = ConfigurationUtils.readStringProperty( | ||
RERANK_PROCESSOR_TYPE, | ||
tag, | ||
rerankerConfig, | ||
MLOpenSearchRerankProcessor.MODEL_ID_FIELD | ||
); | ||
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); | ||
default: | ||
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel())); | ||
} | ||
} | ||
|
||
private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException { | ||
// Set of rerank type labels in the config | ||
Set<String> rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet()); | ||
// A rerank type must be provided | ||
if (rerankTypes.size() == 0) { | ||
StringJoiner msgBuilder = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]"); | ||
for (RerankType t : RerankType.values()) { | ||
msgBuilder.add(t.getLabel()); | ||
} | ||
throw new IllegalArgumentException(msgBuilder.toString()); | ||
} | ||
// Only one rerank type may be provided | ||
if (rerankTypes.size() > 1) { | ||
StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted."); | ||
rerankTypes.forEach(rt -> msgBuilder.add(rt)); | ||
throw new IllegalArgumentException(msgBuilder.toString()); | ||
} | ||
return RerankType.from(rerankTypes.iterator().next()); | ||
} | ||
|
||
/** | ||
* Factory class for context fetchers. Constructs a list of context fetchers | ||
* specified in the pipeline config (and maybe the query context fetcher) | ||
*/ | ||
private static class ContextFetcherFactory { | ||
|
||
/** | ||
* Map rerank types to whether they should include the query context source fetcher | ||
* @param type the constructing RerankType | ||
* @return does this RerankType depend on the QueryContextSourceFetcher? | ||
*/ | ||
public static boolean shouldIncludeQueryContextFetcher(RerankType type) { | ||
return type == RerankType.ML_OPENSEARCH; | ||
} | ||
|
||
/** | ||
* Create necessary queryContextFetchers for this processor | ||
* @param config processor config object. Look for "context" field to find fetchers | ||
* @param includeQueryContextFetcher should I include the queryContextFetcher? | ||
* @return list of contextFetchers for the processor to use | ||
*/ | ||
public static List<ContextSourceFetcher> createFetchers( | ||
Map<String, Object> config, | ||
boolean includeQueryContextFetcher, | ||
String tag | ||
) { | ||
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD); | ||
List<ContextSourceFetcher> fetchers = new ArrayList<>(); | ||
for (String key : contextConfig.keySet()) { | ||
Object cfg = contextConfig.get(key); | ||
switch (key) { | ||
case DocumentContextSourceFetcher.NAME: | ||
fetchers.add(DocumentContextSourceFetcher.create(cfg)); | ||
break; | ||
default: | ||
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key)); | ||
} | ||
} | ||
if (includeQueryContextFetcher) { | ||
fetchers.add(new QueryContextSourceFetcher()); | ||
} | ||
return fetchers; | ||
} | ||
} | ||
} |
83 changes: 83 additions & 0 deletions
83
src/main/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.processor.rerank; | ||
|
||
import java.util.List; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; | ||
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; | ||
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher; | ||
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher; | ||
|
||
/** | ||
* Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore | ||
*/ | ||
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor { | ||
|
||
public static final String MODEL_ID_FIELD = "model_id"; | ||
|
||
protected final String modelId; | ||
|
||
protected final MLCommonsClientAccessor mlCommonsClientAccessor; | ||
|
||
/** | ||
* Constructor | ||
* @param description | ||
* @param tag | ||
* @param ignoreFailure | ||
* @param modelId id of TEXT_SIMILARITY model | ||
* @param contextSourceFetchers | ||
* @param mlCommonsClientAccessor | ||
*/ | ||
public MLOpenSearchRerankProcessor( | ||
final String description, | ||
final String tag, | ||
final boolean ignoreFailure, | ||
final String modelId, | ||
final List<ContextSourceFetcher> contextSourceFetchers, | ||
final MLCommonsClientAccessor mlCommonsClientAccessor | ||
) { | ||
super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers); | ||
this.modelId = modelId; | ||
this.mlCommonsClientAccessor = mlCommonsClientAccessor; | ||
} | ||
|
||
@Override | ||
public void rescoreSearchResponse( | ||
final SearchResponse response, | ||
final Map<String, Object> rerankingContext, | ||
final ActionListener<List<Float>> listener | ||
) { | ||
Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD); | ||
if (!(ctxObj instanceof List<?>)) { | ||
listener.onFailure( | ||
new IllegalStateException( | ||
String.format( | ||
Locale.ROOT, | ||
"No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?", | ||
RerankProcessorFactory.CONTEXT_CONFIG_FIELD, | ||
DocumentContextSourceFetcher.NAME | ||
) | ||
) | ||
); | ||
return; | ||
} | ||
List<?> ctxList = (List<?>) ctxObj; | ||
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList()); | ||
mlCommonsClientAccessor.inferenceSimilarity( | ||
modelId, | ||
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD), | ||
contexts, | ||
listener | ||
); | ||
} | ||
|
||
} |
Oops, something went wrong.