Skip to content

Commit 6468464

Browse files
Approximation framework to support numeric search_after queries (#18896)
* Initial commit for search_after queries Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Test with increment and decrement with search_after Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * search_after queries Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Support framework for search_after queries Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Fix gradle precommit issues Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Add comments Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Attempt to fix the test Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Attempt to fix the test Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Update CHANGELOG.md Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Update CHANGELOG.md and fetch upstream Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Update tests validating with lucene searchAfter Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Update tests validating with lucene searchAfter Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Update tests validating with lucene searchAfter Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Update code with comments Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Add encode tests Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Fix conflicts Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Fix spotless Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Code clean up Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Upstream fetch Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Upstream Fetch Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Add clamps Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> * Upstream Fetch to resolve conflicts Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com> --------- Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com>
1 parent c5d26f7 commit 6468464

File tree

9 files changed

+550
-12
lines changed

9 files changed

+550
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
5353
- [Star-Tree] Add search support for ip field type ([#18671](https://github.com/opensearch-project/OpenSearch/pull/18671))
5454
- [Derived Source] Add integration of derived source feature across various paths like get/search/recovery ([#18565](https://github.com/opensearch-project/OpenSearch/pull/18565))
5555
- Supporting Scripted Metric Aggregation when reducing aggregations in InternalValueCount and InternalAvg ([#18411](https://github.com/opensearch-project/OpenSearch/pull/18411))
56+
- Support `search_after` numeric queries with Approximation Framework ([#18896](https://github.com/opensearch-project/OpenSearch/pull/18896))
5657

5758
### Changed
5859
- Update Subject interface to use CheckedRunnable ([#18570](https://github.com/opensearch-project/OpenSearch/issues/18570))

modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,22 @@ public byte[] encodePoint(Number value) {
217217
return point;
218218
}
219219

220+
@Override
221+
public byte[] encodePoint(Object value, boolean roundUp) {
222+
double doubleValue = parse(value);
223+
long scaledValue = Math.round(scale(doubleValue));
224+
if (roundUp) {
225+
if (scaledValue < Long.MAX_VALUE) {
226+
scaledValue = scaledValue + 1;
227+
}
228+
} else {
229+
if (scaledValue > Long.MIN_VALUE) {
230+
scaledValue = scaledValue - 1;
231+
}
232+
}
233+
return encodePoint(scaledValue);
234+
}
235+
220236
public double getScalingFactor() {
221237
return scalingFactor;
222238
}

server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,23 @@ public byte[] encodePoint(Number value) {
619619
return point;
620620
}
621621

622+
@Override
623+
public byte[] encodePoint(Object value, boolean roundUp) {
624+
// Always parse with roundUp=false to get consistent date math
625+
// The roundUp flag passed to parseToLong only affects date math rounding; the exclusive-bound shift is applied separately below
626+
long timestamp = parseToLong(value, false, null, null, null);
627+
if (roundUp) {
628+
if (timestamp < Long.MAX_VALUE) {
629+
timestamp = timestamp + 1;
630+
}
631+
} else {
632+
if (timestamp > Long.MIN_VALUE) {
633+
timestamp = timestamp - 1;
634+
}
635+
}
636+
return encodePoint(timestamp);
637+
}
638+
622639
@Override
623640
public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
624641
failIfNotIndexedAndNoDocValues();

server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,17 @@ public byte[] encodePoint(Number value) {
283283
return point;
284284
}
285285

286+
@Override
287+
public byte[] encodePoint(Object value, boolean roundUp) {
288+
Float numericValue = parse(value, true);
289+
if (roundUp) {
290+
numericValue = HalfFloatPoint.nextUp(numericValue);
291+
} else {
292+
numericValue = HalfFloatPoint.nextDown(numericValue);
293+
}
294+
return encodePoint(numericValue);
295+
}
296+
286297
@Override
287298
public double toDoubleValue(long value) {
288299
return HalfFloatPoint.sortableShortToHalfFloat((short) value);
@@ -459,6 +470,17 @@ public byte[] encodePoint(Number value) {
459470
return point;
460471
}
461472

473+
@Override
474+
public byte[] encodePoint(Object value, boolean roundUp) {
475+
Float numericValue = parse(value, true);
476+
if (roundUp) {
477+
numericValue = FloatPoint.nextUp(numericValue);
478+
} else {
479+
numericValue = FloatPoint.nextDown(numericValue);
480+
}
481+
return encodePoint(numericValue);
482+
}
483+
462484
@Override
463485
public double toDoubleValue(long value) {
464486
return NumericUtils.sortableIntToFloat((int) value);
@@ -626,6 +648,17 @@ public byte[] encodePoint(Number value) {
626648
return point;
627649
}
628650

651+
@Override
652+
public byte[] encodePoint(Object value, boolean roundUp) {
653+
Double numericValue = parse(value, true);
654+
if (roundUp) {
655+
numericValue = DoublePoint.nextUp(numericValue);
656+
} else {
657+
numericValue = DoublePoint.nextDown(numericValue);
658+
}
659+
return encodePoint(numericValue);
660+
}
661+
629662
@Override
630663
public double toDoubleValue(long value) {
631664
return NumericUtils.sortableLongToDouble(value);
@@ -789,6 +822,23 @@ public byte[] encodePoint(Number value) {
789822
return point;
790823
}
791824

825+
@Override
826+
public byte[] encodePoint(Object value, boolean roundUp) {
827+
Byte numericValue = parse(value, true);
828+
if (roundUp) {
829+
// ASC: exclusive lower bound
830+
if (numericValue < Byte.MAX_VALUE) {
831+
numericValue = (byte) (numericValue + 1);
832+
}
833+
} else {
834+
// DESC: exclusive upper bound
835+
if (numericValue > Byte.MIN_VALUE) {
836+
numericValue = (byte) (numericValue - 1);
837+
}
838+
}
839+
return encodePoint(numericValue);
840+
}
841+
792842
@Override
793843
public double toDoubleValue(long value) {
794844
return objectToDouble(value);
@@ -873,6 +923,22 @@ public byte[] encodePoint(Number value) {
873923
return point;
874924
}
875925

926+
@Override
927+
public byte[] encodePoint(Object value, boolean roundUp) {
928+
Short numericValue = parse(value, true);
929+
if (roundUp) {
930+
// ASC: exclusive lower bound
931+
if (numericValue < Short.MAX_VALUE) {
932+
numericValue = (short) (numericValue + 1);
933+
}
934+
} else {
935+
if (numericValue > Short.MIN_VALUE) {
936+
numericValue = (short) (numericValue - 1);
937+
}
938+
}
939+
return encodePoint(numericValue);
940+
}
941+
876942
@Override
877943
public double toDoubleValue(long value) {
878944
return (double) value;
@@ -953,6 +1019,23 @@ public byte[] encodePoint(Number value) {
9531019
return point;
9541020
}
9551021

1022+
@Override
1023+
public byte[] encodePoint(Object value, boolean roundUp) {
1024+
Integer numericValue = parse(value, true);
1025+
// Always apply exclusivity
1026+
if (roundUp) {
1027+
if (numericValue < Integer.MAX_VALUE) {
1028+
numericValue = numericValue + 1;
1029+
}
1030+
} else {
1031+
if (numericValue > Integer.MIN_VALUE) {
1032+
numericValue = numericValue - 1;
1033+
}
1034+
}
1035+
1036+
return encodePoint(numericValue);
1037+
}
1038+
9561039
@Override
9571040
public double toDoubleValue(long value) {
9581041
return (double) value;
@@ -1139,6 +1222,23 @@ public byte[] encodePoint(Number value) {
11391222
return point;
11401223
}
11411224

1225+
@Override
1226+
public byte[] encodePoint(Object value, boolean roundUp) {
1227+
Long numericValue = parse(value, true);
1228+
if (roundUp) {
1229+
// ASC: exclusive lower bound
1230+
if (numericValue < Long.MAX_VALUE) {
1231+
numericValue = numericValue + 1;
1232+
}
1233+
} else {
1234+
// DESC: exclusive upper bound
1235+
if (numericValue > Long.MIN_VALUE) {
1236+
numericValue = numericValue - 1;
1237+
}
1238+
}
1239+
return encodePoint(numericValue);
1240+
}
1241+
11421242
@Override
11431243
public double toDoubleValue(long value) {
11441244
return (double) value;
@@ -1281,6 +1381,22 @@ public byte[] encodePoint(Number value) {
12811381
return point;
12821382
}
12831383

1384+
@Override
1385+
public byte[] encodePoint(Object value, boolean roundUp) {
1386+
BigInteger numericValue = parse(value, true);
1387+
if (roundUp) {
1388+
if (numericValue.compareTo(Numbers.MAX_UNSIGNED_LONG_VALUE) < 0) {
1389+
numericValue = numericValue.add(BigInteger.ONE);
1390+
}
1391+
} else {
1392+
// DESC: exclusive upper bound
1393+
if (numericValue.compareTo(Numbers.MIN_UNSIGNED_LONG_VALUE) > 0) {
1394+
numericValue = numericValue.subtract(BigInteger.ONE);
1395+
}
1396+
}
1397+
return encodePoint(numericValue);
1398+
}
1399+
12841400
@Override
12851401
public double toDoubleValue(long value) {
12861402
return Numbers.unsignedLongToDouble(value);
@@ -1851,6 +1967,11 @@ public byte[] encodePoint(Number value) {
18511967
return type.encodePoint(value);
18521968
}
18531969

1970+
@Override
1971+
public byte[] encodePoint(Object value, boolean roundUp) {
1972+
return type.encodePoint(value, roundUp);
1973+
}
1974+
18541975
@Override
18551976
public double toDoubleValue(long value) {
18561977
return type.toDoubleValue(value);

server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,12 @@
1313
*/
1414
public interface NumericPointEncoder {
1515
byte[] encodePoint(Number value);
16+
17+
/**
18+
* Encodes an Object value to byte array for Approximation Framework search_after optimization.
19+
* @param value the search_after value as Object
20+
* @param roundUp whether to round up (for lower bounds) or down (for upper bounds)
21+
* @return encoded byte array
22+
*/
23+
byte[] encodePoint(Object value, boolean roundUp);
1624
}

server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import org.apache.lucene.util.ArrayUtil;
3232
import org.apache.lucene.util.DocIdSetBuilder;
3333
import org.apache.lucene.util.IntsRef;
34+
import org.opensearch.index.mapper.MappedFieldType;
35+
import org.opensearch.index.mapper.NumericPointEncoder;
3436
import org.opensearch.search.internal.SearchContext;
3537
import org.opensearch.search.sort.FieldSortBuilder;
3638
import org.opensearch.search.sort.SortOrder;
@@ -52,10 +54,9 @@ public class ApproximatePointRangeQuery extends ApproximateQuery {
5254
public static final Function<byte[], String> UNSIGNED_LONG_FORMAT = bytes -> BigIntegerPoint.decodeDimension(bytes, 0).toString();
5355

5456
private int size;
55-
5657
private SortOrder sortOrder;
57-
58-
public final PointRangeQuery pointRangeQuery;
58+
public PointRangeQuery pointRangeQuery;
59+
private final Function<byte[], String> valueToString;
5960

6061
public ApproximatePointRangeQuery(
6162
String field,
@@ -78,6 +79,7 @@ protected ApproximatePointRangeQuery(
7879
) {
7980
this.size = size;
8081
this.sortOrder = sortOrder;
82+
this.valueToString = valueToString;
8183
this.pointRangeQuery = new PointRangeQuery(field, lowerPoint, upperPoint, numDims) {
8284
@Override
8385
protected String toString(int dimension, byte[] value) {
@@ -114,12 +116,12 @@ public void visit(QueryVisitor visitor) {
114116

115117
@Override
116118
public final ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
119+
final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(pointRangeQuery.getBytesPerDim());
120+
117121
Weight pointRangeQueryWeight = pointRangeQuery.createWeight(searcher, scoreMode, boost);
118122

119123
return new ConstantScoreWeight(this, boost) {
120124

121-
private final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(pointRangeQuery.getBytesPerDim());
122-
123125
// we pull this from PointRangeQuery since it is final
124126
private boolean matches(byte[] packedValue) {
125127
for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) {
@@ -138,7 +140,6 @@ private boolean matches(byte[] packedValue) {
138140

139141
// we pull this from PointRangeQuery since it is final
140142
private PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) {
141-
142143
boolean crosses = false;
143144

144145
for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) {
@@ -352,6 +353,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
352353
if (checkValidPointValues(values) == false) {
353354
return null;
354355
}
356+
// values.size() is the total number of indexed points; for single-valued fields this is approximately the number of documents
355357
if (size > values.size()) {
356358
return pointRangeQueryWeight.scorerSupplier(context);
357359
} else {
@@ -423,6 +425,19 @@ public boolean isCacheable(LeafReaderContext ctx) {
423425
};
424426
}
425427

428+
private byte[] computeEffectiveBound(SearchContext context, boolean isLowerBound) {
429+
byte[] originalBound = isLowerBound ? pointRangeQuery.getLowerPoint() : pointRangeQuery.getUpperPoint();
430+
boolean isAscending = sortOrder == null || sortOrder.equals(SortOrder.ASC);
431+
if ((isLowerBound && isAscending) || (isLowerBound == false && isAscending == false)) {
432+
Object searchAfterValue = context.request().source().searchAfter()[0];
433+
MappedFieldType fieldType = context.getQueryShardContext().fieldMapper(pointRangeQuery.getField());
434+
if (fieldType instanceof NumericPointEncoder encoder) {
435+
return encoder.encodePoint(searchAfterValue, isLowerBound);
436+
}
437+
}
438+
return originalBound;
439+
}
440+
426441
@Override
427442
public boolean canApproximate(SearchContext context) {
428443
if (context == null) {
@@ -435,7 +450,6 @@ public boolean canApproximate(SearchContext context) {
435450
if (context.trackTotalHitsUpTo() == SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
436451
return false;
437452
}
438-
439453
// size 0 could be set for caching
440454
if (context.from() + context.size() == 0) {
441455
this.setSize(SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO);
@@ -459,12 +473,24 @@ public boolean canApproximate(SearchContext context) {
459473
// Cannot sort documents missing this field.
460474
return false;
461475
}
476+
this.setSortOrder(primarySortField.order());
462477
if (context.request().source().searchAfter() != null) {
463-
// TODO: We *could* optimize searchAfter, especially when this is the only sort field, but existing pruning is pretty
464-
// good.
465-
return false;
478+
byte[] lower;
479+
byte[] upper;
480+
if (sortOrder == SortOrder.ASC) {
481+
lower = computeEffectiveBound(context, true);
482+
upper = pointRangeQuery.getUpperPoint();
483+
} else {
484+
lower = pointRangeQuery.getLowerPoint();
485+
upper = computeEffectiveBound(context, false);
486+
}
487+
this.pointRangeQuery = new PointRangeQuery(pointRangeQuery.getField(), lower, upper, pointRangeQuery.getNumDims()) {
488+
@Override
489+
protected String toString(int dimension, byte[] value) {
490+
return valueToString.apply(value);
491+
}
492+
};
466493
}
467-
this.setSortOrder(primarySortField.order());
468494
}
469495
return context.request().source().terminateAfter() == SearchContext.DEFAULT_TERMINATE_AFTER;
470496
}

0 commit comments

Comments
 (0)