From 9ba967fb24f8195f021f94cec5168c342ca50fa9 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 23 Aug 2023 00:10:19 -0700 Subject: [PATCH] Fixing the backward incompatible changes coming from core in ScoreScript class Signed-off-by: Navneet Verma --- .../knn/plugin/script/KNNScoreScript.java | 20 ++++---- .../plugin/script/KNNScoreScriptFactory.java | 8 +++- .../knn/plugin/script/KNNScoringSpace.java | 47 ++++++++++++------- 3 files changed, 47 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java index f190a3e1db..555a333d04 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScript.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.index.fielddata.ScriptDocValues; @@ -32,9 +33,9 @@ public KNNScoreScript( String field, BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext + LeafReaderContext leafContext, IndexSearcher searcher ) { - super(params, lookup, leafContext); + super(params, lookup, searcher, leafContext); this.queryValue = queryValue; this.field = field; this.scoringMethod = scoringMethod; @@ -51,9 +52,10 @@ public LongType( String field, BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext + LeafReaderContext leafContext, + IndexSearcher searcher ) { - super(params, queryValue, field, scoringMethod, lookup, leafContext); + super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher); } /** @@ -84,9 +86,10 @@ public BigIntegerType( String field, BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext + LeafReaderContext leafContext, + IndexSearcher searcher ) { - super(params, queryValue, field, scoringMethod, lookup, leafContext); + super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher); } /** @@ -118,9 +121,10 @@ public KNNVectorType( String field, BiFunction scoringMethod, SearchLookup lookup, - LeafReaderContext leafContext + LeafReaderContext leafContext, + IndexSearcher searcher ) throws IOException { - super(params, queryValue, field, scoringMethod, lookup, leafContext); + super(params, queryValue, field, scoringMethod, lookup, leafContext, searcher); } /** diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java index b686a20f04..63b367b2d8 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoreScriptFactory.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import org.apache.lucene.search.IndexSearcher; import org.opensearch.knn.plugin.stats.KNNCounter; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.script.ScoreScript; @@ -21,13 +22,16 @@ public class KNNScoreScriptFactory implements ScoreScript.LeafFactory { private Object query; private KNNScoringSpace knnScoringSpace; - public KNNScoreScriptFactory(Map params, SearchLookup lookup) { + private IndexSearcher searcher; + + public KNNScoreScriptFactory(Map params, SearchLookup lookup, IndexSearcher searcher) { KNNCounter.SCRIPT_QUERY_REQUESTS.increment(); this.params = params; this.lookup = lookup; this.field = getValue(params, "field").toString(); this.similaritySpace = getValue(params, "space_type").toString(); this.query = getValue(params, "query_value"); + this.searcher = searcher; this.knnScoringSpace = KNNScoringSpaceFactory.create( this.similaritySpace, @@ -60,6 +64,6 @@ public boolean needs_score() { */ @Override public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return knnScoringSpace.getScoreScript(params, field, lookup, ctx); + return knnScoringSpace.getScoreScript(params, field, lookup, ctx, this.searcher); } } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 16bf6e204e..585069710e 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -5,6 +5,8 @@ package org.opensearch.knn.plugin.script; +import org.apache.lucene.search.IndexSearcher; +import org.opensearch.core.index.Index; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNWeight; import org.apache.lucene.index.LeafReaderContext; @@ -29,14 +31,16 @@ public interface KNNScoringSpace { /** * Return the correct scoring script for a given query. The scoring script * - * @param params Map of parameters - * @param field Fieldname - * @param lookup SearchLookup - * @param ctx ctx LeafReaderContext to be used for scoring documents + * @param params Map of parameters + * @param field Fieldname + * @param lookup SearchLookup + * @param ctx ctx LeafReaderContext to be used for scoring documents + * @param searcher IndexSearcher * @return ScoreScript for this query * @throws IOException throws IOException if ScoreScript cannot be constructed */ - ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) throws IOException; + ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx, + IndexSearcher searcher) throws IOException; class L2 implements KNNScoringSpace { @@ -62,9 +66,10 @@ public L2(Object query, MappedFieldType fieldType) { this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, + ctx, searcher); } } @@ -94,9 +99,10 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, + ctx, searcher); } } @@ -127,7 +133,8 @@ public HammingBit(Object query, MappedFieldType fieldType) { } @SuppressWarnings("unchecked") - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, + LeafReaderContext ctx, IndexSearcher searcher) throws IOException { if (this.processedQuery instanceof Long) { return new KNNScoreScript.LongType( @@ -136,7 +143,8 @@ public ScoreScript getScoreScript(Map params, String field, Sear field, (BiFunction) this.scoringMethod, lookup, - ctx + ctx, + searcher ); } @@ -146,7 +154,7 @@ public ScoreScript getScoreScript(Map params, String field, Sear field, (BiFunction) this.scoringMethod, lookup, - ctx + ctx, searcher ); } } @@ -175,9 +183,10 @@ public L1(Object query, MappedFieldType fieldType) { this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, + ctx, searcher); } } @@ -205,9 +214,10 @@ public LInf(Object query, MappedFieldType fieldType) { this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); } - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, + ctx, searcher); } } @@ -238,9 +248,10 @@ public InnerProd(Object query, MappedFieldType fieldType) { } @Override - public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx) + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) throws IOException { - return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx); + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, + ctx, searcher); } } }