Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,6 @@
import org.opensearch.action.search.PutSearchPipelineTransportAction;
import org.opensearch.action.search.SearchAction;
import org.opensearch.action.search.SearchScrollAction;
import org.opensearch.action.search.StreamSearchAction;
import org.opensearch.action.search.StreamTransportSearchAction;
import org.opensearch.action.search.TransportClearScrollAction;
import org.opensearch.action.search.TransportCreatePitAction;
import org.opensearch.action.search.TransportDeletePitAction;
Expand Down Expand Up @@ -736,9 +734,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> void reg
actions.register(MultiGetAction.INSTANCE, TransportMultiGetAction.class, TransportShardMultiGetAction.class);
actions.register(BulkAction.INSTANCE, TransportBulkAction.class, TransportShardBulkAction.class);
actions.register(SearchAction.INSTANCE, TransportSearchAction.class);
if (FeatureFlags.isEnabled(FeatureFlags.STREAM_TRANSPORT)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we remove it?

actions.register(StreamSearchAction.INSTANCE, StreamTransportSearchAction.class);
}
// Streaming search handled via SearchAction with streamingSearchMode parameter
actions.register(SearchScrollAction.INSTANCE, TransportSearchScrollAction.class);
actions.register(MultiSearchAction.INSTANCE, TransportMultiSearchAction.class);
actions.register(ExplainAction.INSTANCE, TransportExplainAction.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ int getBatchReduceSize(int requestBatchedReduceSize, int minBatchReduceSize) {
return (hasAggs || hasTopDocs) ? Math.min(requestBatchedReduceSize, minBatchReduceSize) : minBatchReduceSize;
}

/**
* Protected accessor for progressListener to allow subclasses to access it.
* @return the search progress listener
*/
protected SearchProgressListener progressListener() {
return this.progressListener;
}

@Override
public void close() {
Releasables.close(pendingReduces);
Expand Down Expand Up @@ -239,6 +247,7 @@ private ReduceResult partialReduce(
}
for (QuerySearchResult result : toConsume) {
TopDocsAndMaxScore topDocs = result.consumeTopDocs();
// For streaming, avoid reassigning shardIndex if already set
SearchPhaseController.setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs);
}
Expand Down Expand Up @@ -273,7 +282,18 @@ private ReduceResult partialReduce(
SearchShardTarget target = result.getSearchShardTarget();
processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
}
progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
// For streaming search with TopDocs, use the new notification method
if (hasTopDocs && newTopDocs != null) {
progressListener.notifyPartialReduceWithTopDocs(
processedShards,
topDocsStats.getTotalHits(),
newTopDocs,
newAggs,
numReducePhases
);
} else {
progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
}
// we leave the results un-serialized because serializing is slow but we compute the serialized
// size as an estimate of the memory used by the newly reduced aggregations.
long serializedSize = hasAggs ? newAggs.getSerializedSize() : 0;
Expand Down Expand Up @@ -564,6 +584,7 @@ private synchronized List<TopDocs> consumeTopDocs() {
}
for (QuerySearchResult result : buffer) {
TopDocsAndMaxScore topDocs = result.consumeTopDocs();
// For streaming, avoid reassigning shardIndex if already set
SearchPhaseController.setShardIndex(topDocs.topDocs, result.getShardIndex());
topDocsList.add(topDocs.topDocs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

package org.opensearch.action.search;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.FieldDoc;
Expand Down Expand Up @@ -90,6 +92,7 @@
* @opensearch.internal
*/
public final class SearchPhaseController {
private static final Logger logger = LogManager.getLogger(SearchPhaseController.class);
private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];

private final NamedWriteableRegistry namedWriteableRegistry;
Expand Down Expand Up @@ -246,7 +249,14 @@ static TopDocs mergeTopDocs(Collection<TopDocs> results, int topN, int from) {
}

static void setShardIndex(TopDocs topDocs, int shardIndex) {
assert topDocs.scoreDocs.length == 0 || topDocs.scoreDocs[0].shardIndex == -1 : "shardIndex is already set";
// Idempotent assignment: in streaming flows partial reductions may touch the same TopDocs more than once.
if (topDocs.scoreDocs.length == 0) {
return;
}
if (topDocs.scoreDocs[0].shardIndex != -1) {
// Already set by a previous pass; avoid reassigning to prevent assertion failures
return;
}
for (ScoreDoc doc : topDocs.scoreDocs) {
doc.shardIndex = shardIndex;
}
Expand Down Expand Up @@ -795,40 +805,36 @@ QueryPhaseResultConsumer newSearchPhaseResults(
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isTaskCancelled
) {
return new QueryPhaseResultConsumer(
request,
executor,
circuitBreaker,
this,
listener,
namedWriteableRegistry,
numShards,
onPartialMergeFailure,
isTaskCancelled
);
}

/**
* Returns a new {@link StreamQueryPhaseResultConsumer} instance that reduces search responses incrementally.
*/
StreamQueryPhaseResultConsumer newStreamSearchPhaseResults(
Executor executor,
CircuitBreaker circuitBreaker,
SearchProgressListener listener,
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure
) {
return new StreamQueryPhaseResultConsumer(
request,
executor,
circuitBreaker,
this,
listener,
namedWriteableRegistry,
numShards,
onPartialMergeFailure
);
// Check if this is a streaming search request
String streamingMode = request.getStreamingSearchMode();
if (logger.isDebugEnabled()) {
logger.debug("Streaming mode on request: {}", streamingMode);
}
if (streamingMode != null) {
return new StreamQueryPhaseResultConsumer(
request,
executor,
circuitBreaker,
this,
listener,
namedWriteableRegistry,
numShards,
onPartialMergeFailure
);
} else {
// Regular QueryPhaseResultConsumer
return new QueryPhaseResultConsumer(
request,
executor,
circuitBreaker,
this,
listener,
namedWriteableRegistry,
numShards,
onPartialMergeFailure,
isTaskCancelled
);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,26 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc
*/
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}

/**
* Executed when a partial reduce with TopDocs is created for streaming search.
*
* @param shards The list of shards that are part of this reduce.
* @param totalHits The total number of hits in this reduce.
* @param topDocs The partial TopDocs result (may be null if no docs).
* @param aggs The partial result for aggregations.
* @param reducePhase The version number for this reduce.
*/
protected void onPartialReduceWithTopDocs(
List<SearchShard> shards,
TotalHits totalHits,
org.apache.lucene.search.TopDocs topDocs,
InternalAggregations aggs,
int reducePhase
) {
// Default implementation delegates to the original method for backward compatibility
onPartialReduce(shards, totalHits, aggs, reducePhase);
}

/**
* Executed once when the final reduce is created.
*
Expand Down Expand Up @@ -165,6 +185,20 @@ final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits, In
}
}

final void notifyPartialReduceWithTopDocs(
List<SearchShard> shards,
TotalHits totalHits,
org.apache.lucene.search.TopDocs topDocs,
InternalAggregations aggs,
int reducePhase
) {
try {
onPartialReduceWithTopDocs(shards, totalHits, topDocs, aggs, reducePhase);
} catch (Exception e) {
logger.warn(() -> new ParameterizedMessage("Failed to execute progress listener on partial reduce with TopDocs"), e);
}
}

protected final void notifyFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
try {
onFinalReduce(shards, totalHits, aggs, reducePhase);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla

private Boolean phaseTook = null;

private boolean streamingScoring = false;
private String streamingSearchMode = null; // Will use StreamingSearchMode.SCORED_UNSORTED if null

public SearchRequest() {
this.localClusterAlias = null;
this.absoluteStartMillis = DEFAULT_ABSOLUTE_START_MILLIS;
Expand All @@ -145,6 +148,7 @@ public SearchRequest(SearchRequest searchRequest) {
searchRequest.absoluteStartMillis,
searchRequest.finalReduce
);
this.streamingScoring = searchRequest.streamingScoring;
}

/**
Expand Down Expand Up @@ -280,6 +284,14 @@ public SearchRequest(StreamInput in) throws IOException {
if (in.getVersion().onOrAfter(Version.V_2_12_0)) {
phaseTook = in.readOptionalBoolean();
}
// Read streaming fields - gated on version for BWC
if (in.getVersion().onOrAfter(Version.V_3_3_0)) {
streamingScoring = in.readBoolean();
streamingSearchMode = in.readOptionalString();
} else {
streamingScoring = false;
streamingSearchMode = null;
}
}

@Override
Expand Down Expand Up @@ -314,6 +326,11 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion().onOrAfter(Version.V_2_12_0)) {
out.writeOptionalBoolean(phaseTook);
}
// Write streaming fields - gated on version for BWC
if (out.getVersion().onOrAfter(Version.V_3_3_0)) {
out.writeBoolean(streamingScoring);
out.writeOptionalString(streamingSearchMode);
}
}

@Override
Expand Down Expand Up @@ -695,6 +712,36 @@ public void setPhaseTook(Boolean phaseTook) {
this.phaseTook = phaseTook;
}

/**
* Enable streaming scoring for this search request.
*/
public void setStreamingScoring(boolean streamingScoring) {
this.streamingScoring = streamingScoring;
}

/**
* Check if streaming scoring is enabled for this search request.
*/
public boolean isStreamingScoring() {
return streamingScoring;
}

/**
* Sets the streaming search mode for this request.
* @param mode The streaming search mode to use
*/
public void setStreamingSearchMode(String mode) {
this.streamingSearchMode = mode;
}

/**
* Gets the streaming search mode for this request.
* @return The streaming search mode, or null if not set
*/
public String getStreamingSearchMode() {
return streamingSearchMode;
}

/**
* Returns a threshold that enforces a pre-filter roundtrip to pre-filter search shards based on query rewriting if the number of shards
* the search request expands to exceeds the threshold, or <code>null</code> if the threshold is unspecified.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ public class SearchResponse extends ActionResponse implements StatusToXContentOb
private final long tookInMillis;
private final PhaseTook phaseTook;

// Fields for streaming responses
private boolean isPartial = false;
private int sequenceNumber = 0;
private int totalPartials = 0;

public SearchResponse(StreamInput in) throws IOException {
super(in);
internalResponse = new InternalSearchResponse(in);
Expand Down Expand Up @@ -302,6 +307,31 @@ public String getScrollId() {
return scrollId;
}

// Streaming response methods
public boolean isPartial() {
return isPartial;
}

public void setPartial(boolean partial) {
this.isPartial = partial;
}

public int getSequenceNumber() {
return sequenceNumber;
}

public void setSequenceNumber(int sequenceNumber) {
this.sequenceNumber = sequenceNumber;
}

public int getTotalPartials() {
return totalPartials;
}

public void setTotalPartials(int totalPartials) {
this.totalPartials = totalPartials;
}

/**
* Returns the encoded string of the search context that the search request is used to executed
*/
Expand Down
Loading
Loading