Convert more PriorityQueues to use Comparator #14761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

merged 5 commits on Jun 18, 2025
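
This change continues replacing `PriorityQueue` subclasses that override `lessThan` with queues built via `PriorityQueue.usingComparator`. The sketch below illustrates the general before/after pattern on a hypothetical `Item` record (not part of this PR); it assumes, as the conversions here do, that `lessThan(a, b) == true` corresponds to `compare(a, b) < 0`:

```java
import java.util.Comparator;
import org.apache.lucene.util.PriorityQueue;

class QueueConversionSketch {

  // Hypothetical element type, only for illustration.
  record Item(long key, int docID) {}

  // Before: an anonymous PriorityQueue subclass defines the order via lessThan.
  static PriorityQueue<Item> before(int maxSize) {
    return new PriorityQueue<Item>(maxSize) {
      @Override
      protected boolean lessThan(Item a, Item b) {
        if (a.key() != b.key()) {
          return a.key() < b.key();
        }
        return a.docID() < b.docID(); // tie-break by smaller docID
      }
    };
  }

  // After: the same order expressed as a Comparator; lessThan(a, b) == true
  // corresponds to compare(a, b) < 0.
  static PriorityQueue<Item> after(int maxSize) {
    return PriorityQueue.usingComparator(
        maxSize, Comparator.comparingLong(Item::key).thenComparingInt(Item::docID));
  }
}
```
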
@@ -20,6 +20,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.IntFunction;
import org.apache.lucene.codecs.CodecUtil;
@@ -453,29 +454,14 @@ public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
}
}

private static class BKDMergeQueue extends PriorityQueue<MergeReader> {
private final int bytesPerDim;

public BKDMergeQueue(int bytesPerDim, int maxSize) {
super(maxSize);
this.bytesPerDim = bytesPerDim;
}

@Override
public boolean lessThan(MergeReader a, MergeReader b) {
assert a != b;

int cmp =
Arrays.compareUnsigned(a.packedValue, 0, bytesPerDim, b.packedValue, 0, bytesPerDim);
if (cmp < 0) {
return true;
} else if (cmp > 0) {
return false;
}

// Tie break by sorting smaller docIDs earlier:
return a.docID < b.docID;
}
private static Comparator<MergeReader> mergeComparator(int bytesPerDim) {
return ((Comparator<MergeReader>)
(a, b) -> {
assert a != b;
return Arrays.compareUnsigned(
a.packedValue, 0, bytesPerDim, b.packedValue, 0, bytesPerDim);
})
.thenComparingInt(mr -> mr.docID);
}

/**
@@ -642,7 +628,8 @@ public long merge(IndexOutput out, List<MergeState.DocMap> docMaps, List<PointVa
throws IOException {
assert docMaps == null || readers.size() == docMaps.size();

BKDMergeQueue queue = new BKDMergeQueue(config.bytesPerDim(), readers.size());
PriorityQueue<MergeReader> queue =
PriorityQueue.usingComparator(readers.size(), mergeComparator(config.bytesPerDim()));

for (int i = 0; i < readers.size(); i++) {
PointValues pointValues = readers.get(i);
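
In the BKD merge code, the new `mergeComparator` preserves the ordering of the removed `BKDMergeQueue`: packed values are compared as unsigned bytes over the first `bytesPerDim` bytes, with ties broken by smaller docID. A self-contained sketch of that ordering follows, using a hypothetical `Reader` record in place of `MergeReader` (only the two fields the comparator reads are modeled):

```java
import java.util.Arrays;
import java.util.Comparator;

class BkdMergeOrderSketch {

  // Hypothetical stand-in for MergeReader; only packedValue and docID matter here.
  record Reader(byte[] packedValue, int docID) {}

  static Comparator<Reader> mergeComparator(int bytesPerDim) {
    return ((Comparator<Reader>)
            (a, b) ->
                Arrays.compareUnsigned(
                    a.packedValue(), 0, bytesPerDim, b.packedValue(), 0, bytesPerDim))
        .thenComparingInt(Reader::docID);
  }

  public static void main(String[] args) {
    Comparator<Reader> cmp = mergeComparator(1);
    Reader low = new Reader(new byte[] {0x01}, 7);
    Reader high = new Reader(new byte[] {(byte) 0xFF}, 3);
    Reader tie = new Reader(new byte[] {0x01}, 9);

    System.out.println(cmp.compare(low, high) < 0); // true: 0x01 < 0xFF when compared unsigned
    System.out.println(cmp.compare(low, tie) < 0); // true: equal values, smaller docID wins
  }
}
```
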
@@ -18,6 +18,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Objects;
@@ -40,6 +41,7 @@
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FloatComparator;
import org.apache.lucene.util.PriorityQueue;
import org.apache.lucene.util.automaton.LevenshteinAutomata;

@@ -125,7 +127,8 @@ public void addTerms(String queryString, String fieldName) {
fieldVals.add(new FieldVals(fieldName, maxEdits, queryString));
}

private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws IOException {
private void addTerms(IndexReader reader, FieldVals f, PriorityQueue<ScoreTerm> q)
throws IOException {
if (f.queryString == null) return;
final Terms terms = MultiTerms.getTerms(reader, f.fieldName);
if (terms == null) {
@@ -141,8 +144,8 @@ private void addTerms(IndexReader reader, FieldVals f, ScoreTermQueue q) throws
String term = termAtt.toString();
if (!processedTerms.contains(term)) {
processedTerms.add(term);
ScoreTermQueue variantsQ =
new ScoreTermQueue(
PriorityQueue<ScoreTerm> variantsQ =
createScoreTermQueue(
MAX_VARIANTS_PER_TERM); // maxNum variants considered for any one term
float minScore = 0;
Term startTerm = new Term(f.fieldName, term);
@@ -214,7 +217,7 @@ private Query newTermQuery(IndexReader reader, Term term) throws IOException {
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
IndexReader reader = indexSearcher.getIndexReader();
ScoreTermQueue q = new ScoreTermQueue(MAX_NUM_TERMS);
PriorityQueue<ScoreTerm> q = createScoreTermQueue(MAX_NUM_TERMS);
// load up the list of possible terms
for (FieldVals f : fieldVals) {
addTerms(reader, f, q);
@@ -275,19 +278,11 @@ private static class ScoreTerm {
}
}

private static class ScoreTermQueue extends PriorityQueue<ScoreTerm> {
ScoreTermQueue(int size) {
super(size);
}

/* (non-Javadoc)
* @see org.apache.lucene.util.PriorityQueue#lessThan(java.lang.Object, java.lang.Object)
*/
@Override
protected boolean lessThan(ScoreTerm termA, ScoreTerm termB) {
if (termA.score == termB.score) return termA.term.compareTo(termB.term) > 0;
else return termA.score < termB.score;
}
private static PriorityQueue<ScoreTerm> createScoreTermQueue(int size) {
return PriorityQueue.usingComparator(
size,
FloatComparator.<ScoreTerm>comparing(st -> st.score)
.thenComparing(st -> st.term, Comparator.reverseOrder()));
}

@Override
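
The fuzzy-term queue keeps the highest-scoring `ScoreTerm`s, so its head must be the entry to evict first: lowest score, and for equal scores the lexicographically larger term (hence `Comparator.reverseOrder()` on the term key), matching the old `lessThan`. `FloatComparator` is the Lucene helper used above; assuming it orders float keys the way `Comparator.comparing` would, just without boxing, a JDK-only sketch of the same ordering looks like this (the `ScoreTermStub` record is illustrative and uses a plain `String` term):

```java
import java.util.Comparator;

class ScoreTermOrderSketch {

  // Hypothetical stand-in for ScoreTerm; a String term replaces org.apache.lucene.index.Term
  // only to keep the sketch self-contained.
  record ScoreTermStub(String term, float score) {}

  // Ascending score first; on ties, the lexicographically larger term is "less"
  // (reverseOrder), so it sits at the head and is popped first.
  static final Comparator<ScoreTermStub> ORDER =
      Comparator.comparing(ScoreTermStub::score)
          .thenComparing(ScoreTermStub::term, Comparator.reverseOrder());
}
```
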
@@ -164,13 +164,14 @@ private Collection<FieldMetadata> writeSingleSegment(
throws IOException {
List<FieldMetadata> fieldMetadataList =
createFieldMetadataList(new FieldsIterator(fields, fieldInfos), maxDoc);
TermIteratorQueue<FieldTerms> fieldTermsQueue =
PriorityQueue<TermIterator<FieldTerms>> fieldTermsQueue =
createFieldTermsQueue(fields, fieldMetadataList);
List<TermIterator<FieldTerms>> groupedFieldTerms = new ArrayList<>(fieldTermsQueue.size());
List<FieldMetadataTermState> termStates = new ArrayList<>(fieldTermsQueue.size());

while (fieldTermsQueue.size() != 0) {
TermIterator<FieldTerms> topFieldTerms = fieldTermsQueue.popTerms();
TermIterator<FieldTerms> topFieldTerms = fieldTermsQueue.pop();
assert topFieldTerms != null && topFieldTerms.term != null;
BytesRef term = BytesRef.deepCopyOf(topFieldTerms.term);
groupByTerm(fieldTermsQueue, topFieldTerms, groupedFieldTerms);
writePostingLines(term, groupedFieldTerms, normsProducer, termStates);
@@ -190,9 +191,10 @@ private List<FieldMetadata> createFieldMetadataList(Iterator<FieldInfo> fieldInf
return fieldMetadataList;
}

private TermIteratorQueue<FieldTerms> createFieldTermsQueue(
private PriorityQueue<TermIterator<FieldTerms>> createFieldTermsQueue(
Fields fields, List<FieldMetadata> fieldMetadataList) throws IOException {
TermIteratorQueue<FieldTerms> fieldQueue = new TermIteratorQueue<>(fieldMetadataList.size());
PriorityQueue<TermIterator<FieldTerms>> fieldQueue =
PriorityQueue.usingComparator(fieldMetadataList.size(), Comparator.naturalOrder());
for (FieldMetadata fieldMetadata : fieldMetadataList) {
Terms terms = fields.terms(fieldMetadata.getFieldInfo().name);
if (terms != null) {
@@ -207,7 +209,7 @@ private TermIteratorQueue<FieldTerms> createFieldTermsQueue(
}

private <T> void groupByTerm(
TermIteratorQueue<T> termIteratorQueue,
PriorityQueue<TermIterator<T>> termIteratorQueue,
TermIterator<T> topTermIterator,
List<TermIterator<T>> groupedTermIterators) {
groupedTermIterators.clear();
@@ -243,7 +245,8 @@ private void writePostingLines(
}

private <T> void nextTermForIterators(
List<? extends TermIterator<T>> termIterators, TermIteratorQueue<T> termIteratorQueue)
List<? extends TermIterator<T>> termIterators,
PriorityQueue<TermIterator<T>> termIteratorQueue)
throws IOException {
for (TermIterator<T> termIterator : termIterators) {
if (termIterator.nextTerm()) {
@@ -330,15 +333,17 @@ private Collection<FieldMetadata> mergeSegments(
mergeState.mergeFieldInfos.iterator(), mergeState.segmentInfo.maxDoc());
Map<String, MergingFieldTerms> fieldTermsMap =
createMergingFieldTermsMap(fieldMetadataList, mergeState.fieldsProducers.length);
TermIteratorQueue<SegmentTerms> segmentTermsQueue = createSegmentTermsQueue(segmentTermsList);
PriorityQueue<TermIterator<SegmentTerms>> segmentTermsQueue =
createSegmentTermsQueue(segmentTermsList);
List<TermIterator<SegmentTerms>> groupedSegmentTerms = new ArrayList<>(segmentTermsList.size());
Map<String, List<SegmentPostings>> fieldPostingsMap =
CollectionUtil.newHashMap(mergeState.fieldInfos.length);
List<MergingFieldTerms> groupedFieldTerms = new ArrayList<>(mergeState.fieldInfos.length);
List<FieldMetadataTermState> termStates = new ArrayList<>(mergeState.fieldInfos.length);

while (segmentTermsQueue.size() != 0) {
TermIterator<SegmentTerms> topSegmentTerms = segmentTermsQueue.popTerms();
TermIterator<SegmentTerms> topSegmentTerms = segmentTermsQueue.pop();
assert topSegmentTerms != null && topSegmentTerms.term != null;
BytesRef term = BytesRef.deepCopyOf(topSegmentTerms.term);
groupByTerm(segmentTermsQueue, topSegmentTerms, groupedSegmentTerms);
combineSegmentsFields(groupedSegmentTerms, fieldPostingsMap);
@@ -364,9 +369,10 @@ private Map<String, MergingFieldTerms> createMergingFieldTermsMap(
return fieldTermsMap;
}

private TermIteratorQueue<SegmentTerms> createSegmentTermsQueue(
private PriorityQueue<TermIterator<SegmentTerms>> createSegmentTermsQueue(
List<TermIterator<SegmentTerms>> segmentTermsList) throws IOException {
TermIteratorQueue<SegmentTerms> segmentQueue = new TermIteratorQueue<>(segmentTermsList.size());
PriorityQueue<TermIterator<SegmentTerms>> segmentQueue =
PriorityQueue.usingComparator(segmentTermsList.size(), Comparator.naturalOrder());
for (TermIterator<SegmentTerms> segmentTerms : segmentTermsList) {
if (segmentTerms.nextTerm()) {
// There is at least one term in the segment
@@ -447,26 +453,7 @@ PostingsEnum getPostings(String fieldName, PostingsEnum reuse, int flags) throws
}
}

private class TermIteratorQueue<T> extends PriorityQueue<TermIterator<T>> {

TermIteratorQueue(int numFields) {
super(numFields);
}

@Override
protected boolean lessThan(TermIterator<T> a, TermIterator<T> b) {
return a.compareTo(b) < 0;
}

TermIterator<T> popTerms() {
TermIterator<T> topTerms = pop();
assert topTerms != null;
assert topTerms.term != null;
return topTerms;
}
}

private abstract class TermIterator<T> implements Comparable<TermIterator<T>> {
private abstract static class TermIterator<T> implements Comparable<TermIterator<T>> {

BytesRef term;

@@ -485,7 +472,7 @@ public int compareTo(TermIterator<T> other) {
abstract int compareSecondary(TermIterator<T> other);
}

private class FieldTerms extends TermIterator<FieldTerms> {
private static class FieldTerms extends TermIterator<FieldTerms> {

final FieldMetadata fieldMetadata;
final TermsEnum termsEnum;
@@ -520,7 +507,7 @@ void resetIterator(BytesRef term, List<SegmentPostings> segmentPostingsList) {
}
}

private class SegmentTerms extends TermIterator<SegmentTerms> {
private static class SegmentTerms extends TermIterator<SegmentTerms> {

private final Integer segmentIndex;
private final STMergingBlockReader mergingBlockReader;
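
`TermIteratorQueue` existed only to order `TermIterator`s by their natural `compareTo` order and to assert a non-null popped term; with `Comparator.naturalOrder()` the queue can be created directly and those assertions move to the call sites, as the added `assert topFieldTerms != null ...` lines show. A minimal sketch of the idea, with a hypothetical `Comparable` element type:

```java
import java.util.Comparator;
import org.apache.lucene.util.PriorityQueue;

class NaturalOrderQueueSketch {

  // Hypothetical Comparable element standing in for TermIterator<T>.
  record Entry(int key) implements Comparable<Entry> {
    @Override
    public int compareTo(Entry other) {
      return Integer.compare(key, other.key);
    }
  }

  // The old lessThan(a, b) was a.compareTo(b) < 0, which is exactly what
  // naturalOrder() provides, so no dedicated queue subclass is needed.
  static PriorityQueue<Entry> newQueue(int maxSize) {
    return PriorityQueue.usingComparator(maxSize, Comparator.naturalOrder());
  }
}
```
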
13 changes: 6 additions & 7 deletions lucene/core/src/java/org/apache/lucene/index/DocIDMerger.java
@@ -145,13 +145,12 @@ private SortedDocIDMerger(List<T> subs, int maxCount) throws IOException {
}
this.subs = subs;
queue =
new PriorityQueue<T>(maxCount - 1) {
@Override
protected boolean lessThan(Sub a, Sub b) {
assert a.mappedDocID != b.mappedDocID;
return a.mappedDocID < b.mappedDocID;
}
};
PriorityQueue.usingComparator(
maxCount - 1,
(a, b) -> {
assert a.mappedDocID != b.mappedDocID;
return Integer.compare(a.mappedDocID, b.mappedDocID);
});
reset();
}

@@ -18,6 +18,7 @@

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.util.Comparator;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
@@ -159,24 +160,12 @@ public static Iterator mergedIterator(Iterator[] subs) {
return subs[0];
}

// sort by smaller docID, then larger delGen
PriorityQueue<Iterator> queue =
new PriorityQueue<Iterator>(subs.length) {
@Override
protected boolean lessThan(Iterator a, Iterator b) {
// sort by smaller docID
int cmp = Integer.compare(a.docID(), b.docID());
if (cmp == 0) {
// then by larger delGen
cmp = Long.compare(b.delGen(), a.delGen());

// delGens are unique across our subs:
assert cmp != 0;
}

return cmp < 0;
}
};

PriorityQueue.usingComparator(
subs.length,
Comparator.comparingInt(Iterator::docID)
.thenComparing(Comparator.comparingLong(Iterator::delGen).reversed()));
for (Iterator sub : subs) {
if (sub.nextDoc() != NO_MORE_DOCS) {
queue.add(sub);
35 changes: 17 additions & 18 deletions lucene/core/src/java/org/apache/lucene/index/MultiSorter.java
@@ -18,6 +18,7 @@
package org.apache.lucene.index;

import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.MergeState.DocMap;
import org.apache.lucene.search.Sort;
@@ -82,24 +83,22 @@ static MergeState.DocMap[] sort(Sort sort, List<CodecReader> readers) throws IOE
int leafCount = readers.size();

PriorityQueue<LeafAndDocID> queue =
new PriorityQueue<LeafAndDocID>(leafCount) {
@Override
public boolean lessThan(LeafAndDocID a, LeafAndDocID b) {
for (int i = 0; i < comparables.length; i++) {
int cmp = Long.compare(a.valuesAsComparableLongs[i], b.valuesAsComparableLongs[i]);
if (cmp != 0) {
return reverseMuls[i] * cmp < 0;
}
}

// tie-break by docID natural order:
if (a.readerIndex != b.readerIndex) {
return a.readerIndex < b.readerIndex;
} else {
return a.docID < b.docID;
}
}
};
PriorityQueue.usingComparator(
leafCount,
((Comparator<LeafAndDocID>)
(a, b) -> {
for (int i = 0; i < comparables.length; i++) {
int cmp =
Long.compare(
a.valuesAsComparableLongs[i], b.valuesAsComparableLongs[i]);
if (cmp != 0) {
return reverseMuls[i] * cmp;
}
}
return 0;
})
.thenComparingInt(ld -> ld.readerIndex)
.thenComparingInt(ld -> ld.docID));

PackedLongValues.Builder[] builders = new PackedLongValues.Builder[leafCount];

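One stylistic detail in the MultiSorter change (also used for the BKD comparator above): a lambda expression has no standalone type, so it cannot directly be the receiver of `thenComparingInt`; casting it to `Comparator<LeafAndDocID>` first gives the chain a target type. A tiny illustration with a hypothetical `Pair` record:

```java
import java.util.Comparator;

class ChainedLambdaSketch {

  // Hypothetical record; primary stands in for the sort key, secondary for the tie-break.
  record Pair(int primary, int secondary) {}

  // Without the cast, "(a, b) -> ...".thenComparingInt(...) would not compile,
  // because the bare lambda has no type to resolve thenComparingInt against.
  static final Comparator<Pair> ORDER =
      ((Comparator<Pair>) (a, b) -> Integer.compare(a.primary(), b.primary()))
          .thenComparingInt(Pair::secondary);
}
```
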
@@ -20,6 +20,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.lucene.util.FloatComparator;
import org.apache.lucene.util.PriorityQueue;

/** Base class for Scorers that score disjunctions. */
@@ -88,12 +89,7 @@ private TwoPhase(DocIdSetIterator approximation, float matchCost) {
super(approximation);
this.matchCost = matchCost;
unverifiedMatches =
new PriorityQueue<DisiWrapper>(numClauses) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
}
};
PriorityQueue.usingComparator(numClauses, FloatComparator.comparing(d -> d.matchCost));
}

DisiWrapper getSubMatches() throws IOException {