Skip to content

Commit 99513c0

Browse files
authored
Refine SearchProgressListener internal API (#53373)
The following cumulative improvements have been made: - rename `onReduce` and `notifyReduce` to `onFinalReduce` and `notifyFinalReduce` - add unit test for `SearchShard` - on* methods in `SearchProgressListener` shouldn't need to be public as they should never be called directly, they only need to be overridden hence they can be made protected. They are actually called directly from a test which required some adapting, like making `AsyncSearchTask.Listener` class package private instead of private - Instead of overriding `getProgressListener` in `AsyncSearchTask`, as it feels weird to override a getter method, added a specific method that allows to retrieve the Listener directly without needing to cast it. Made the getter and setter for the listener final in the base class. - rename `SearchProgressListener#searchShards` methods to `buildSearchShards` and make it static given that it accesses no instance members - make `SearchShard` and `SearchShardTask` classes final
1 parent fd030dc commit 99513c0

File tree

15 files changed

+138
-79
lines changed

15 files changed

+138
-79
lines changed

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
5454
this.searchPhaseController = searchPhaseController;
5555
SearchProgressListener progressListener = task.getProgressListener();
5656
SearchSourceBuilder sourceBuilder = request.source();
57-
progressListener.notifyListShards(progressListener.searchShards(this.shardsIts),
58-
progressListener.searchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0);
57+
progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts),
58+
SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0);
5959
}
6060

6161
@Override

server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -665,8 +665,8 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
665665
numReducePhases++;
666666
index = 1;
667667
if (hasAggs || hasTopDocs) {
668-
progressListener.notifyPartialReduce(progressListener.searchShards(processedShards),
669-
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases);
668+
progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
669+
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases);
670670
}
671671
}
672672
final int i = index++;
@@ -695,7 +695,7 @@ private synchronized List<TopDocs> getRemainingTopDocs() {
695695
public ReducedQueryPhase reduce() {
696696
ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(),
697697
getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false, performFinalReduce);
698-
progressListener.notifyReduce(progressListener.searchShards(results.asList()),
698+
progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()),
699699
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
700700
return reducePhase;
701701
}
@@ -751,8 +751,8 @@ ReducedQueryPhase reduce() {
751751
List<SearchPhaseResult> resultList = results.asList();
752752
final ReducedQueryPhase reducePhase =
753753
reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, request.isFinalReduce());
754-
listener.notifyReduce(listener.searchShards(resultList), reducePhase.totalHits,
755-
reducePhase.aggregations, reducePhase.numReducePhases);
754+
listener.notifyFinalReduce(SearchProgressListener.buildSearchShards(resultList),
755+
reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
756756
return reducePhase;
757757
}
758758
};

server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ abstract class SearchProgressListener {
5353
* @param clusters The statistics for remote clusters included in the search.
5454
* @param fetchPhase <code>true</code> if the search needs a fetch phase, <code>false</code> otherwise.
5555
**/
56-
public void onListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) {}
56+
protected void onListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) {}
5757

5858
/**
5959
* Executed when a shard returns a query result.
6060
*
6161
* @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards} )}.
6262
*/
63-
public void onQueryResult(int shardIndex) {}
63+
protected void onQueryResult(int shardIndex) {}
6464

6565
/**
6666
* Executed when a shard reports a query failure.
@@ -69,7 +69,7 @@ public void onQueryResult(int shardIndex) {}
6969
* @param shardTarget The last shard target that thrown an exception.
7070
* @param exc The cause of the failure.
7171
*/
72-
public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}
72+
protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}
7373

7474
/**
7575
* Executed when a partial reduce is created. The number of partial reduce can be controlled via
@@ -80,7 +80,7 @@ public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Except
8080
* @param aggs The partial result for aggregations.
8181
* @param reducePhase The version number for this reduce.
8282
*/
83-
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
83+
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
8484

8585
/**
8686
* Executed once when the final reduce is created.
@@ -90,22 +90,22 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
9090
* @param aggs The final result for aggregations.
9191
* @param reducePhase The version number for this reduce.
9292
*/
93-
public void onReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
93+
protected void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
9494

9595
/**
9696
* Executed when a shard returns a fetch result.
9797
*
9898
* @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}.
9999
*/
100-
public void onFetchResult(int shardIndex) {}
100+
protected void onFetchResult(int shardIndex) {}
101101

102102
/**
103103
* Executed when a shard reports a fetch failure.
104104
*
105105
* @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards})}.
106106
* @param exc The cause of the failure.
107107
*/
108-
public void onFetchFailure(int shardIndex, Exception exc) {}
108+
protected void onFetchFailure(int shardIndex, Exception exc) {}
109109

110110
final void notifyListShards(List<SearchShard> shards, List<SearchShard> skippedShards, Clusters clusters, boolean fetchPhase) {
111111
this.shards = shards;
@@ -142,9 +142,9 @@ final void notifyPartialReduce(List<SearchShard> shards, TotalHits totalHits, In
142142
}
143143
}
144144

145-
final void notifyReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
145+
protected final void notifyFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
146146
try {
147-
onReduce(shards, totalHits, aggs, reducePhase);
147+
onFinalReduce(shards, totalHits, aggs, reducePhase);
148148
} catch (Exception e) {
149149
logger.warn(() -> new ParameterizedMessage("Failed to execute progress listener on reduce"), e);
150150
}
@@ -168,22 +168,22 @@ final void notifyFetchFailure(int shardIndex, Exception exc) {
168168
}
169169
}
170170

171-
final List<SearchShard> searchShards(List<? extends SearchPhaseResult> results) {
171+
static List<SearchShard> buildSearchShards(List<? extends SearchPhaseResult> results) {
172172
return results.stream()
173173
.filter(Objects::nonNull)
174174
.map(SearchPhaseResult::getSearchShardTarget)
175175
.map(e -> new SearchShard(e.getClusterAlias(), e.getShardId()))
176176
.collect(Collectors.toUnmodifiableList());
177177
}
178178

179-
final List<SearchShard> searchShards(SearchShardTarget[] results) {
179+
static List<SearchShard> buildSearchShards(SearchShardTarget[] results) {
180180
return Arrays.stream(results)
181181
.filter(Objects::nonNull)
182182
.map(e -> new SearchShard(e.getClusterAlias(), e.getShardId()))
183183
.collect(Collectors.toUnmodifiableList());
184184
}
185185

186-
final List<SearchShard> searchShards(GroupShardsIterator<SearchShardIterator> its) {
186+
static List<SearchShard> buildSearchShards(GroupShardsIterator<SearchShardIterator> its) {
187187
return StreamSupport.stream(its.spliterator(), false)
188188
.map(e -> new SearchShard(e.getClusterAlias(), e.shardId()))
189189
.collect(Collectors.toUnmodifiableList());

server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ final class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<Se
5757
this.progressListener = task.getProgressListener();
5858
final SearchProgressListener progressListener = task.getProgressListener();
5959
final SearchSourceBuilder sourceBuilder = request.source();
60-
progressListener.notifyListShards(progressListener.searchShards(this.shardsIts),
61-
progressListener.searchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0);
60+
progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts),
61+
SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0);
6262
}
6363

6464
protected void executePhaseOnShard(final SearchShardIterator shardIt, final ShardRouting shard,

server/src/main/java/org/elasticsearch/action/search/SearchShard.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
* A class that encapsulates the {@link ShardId} and the cluster alias
3030
* of a shard used during the search action.
3131
*/
32-
public class SearchShard implements Comparable<SearchShard> {
32+
public final class SearchShard implements Comparable<SearchShard> {
3333
@Nullable
3434
private final String clusterAlias;
3535
private final ShardId shardId;
@@ -40,8 +40,7 @@ public SearchShard(@Nullable String clusterAlias, ShardId shardId) {
4040
}
4141

4242
/**
43-
* Return the cluster alias if the shard is on a remote cluster and <code>null</code>
44-
* otherwise (local).
43+
* Return the cluster alias if we are executing a cross cluster search request, <code>null</code> otherwise.
4544
*/
4645
@Nullable
4746
public String getClusterAlias() {
@@ -51,7 +50,6 @@ public String getClusterAlias() {
5150
/**
5251
* Return the {@link ShardId} of this shard.
5352
*/
54-
@Nullable
5553
public ShardId getShardId() {
5654
return shardId;
5755
}

server/src/main/java/org/elasticsearch/action/search/SearchShardTask.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,4 @@ public SearchShardTask(long id, String type, String action, String description,
4040
public boolean shouldCancelChildrenOnCancellation() {
4141
return false;
4242
}
43-
4443
}

server/src/main/java/org/elasticsearch/action/search/SearchTask.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ public SearchTask(long id, String type, String action, String description, TaskI
3737
/**
3838
* Attach a {@link SearchProgressListener} to this task.
3939
*/
40-
public void setProgressListener(SearchProgressListener progressListener) {
40+
public final void setProgressListener(SearchProgressListener progressListener) {
4141
this.progressListener = progressListener;
4242
}
4343

4444
/**
4545
* Return the {@link SearchProgressListener} attached to this task.
4646
*/
47-
public SearchProgressListener getProgressListener() {
47+
public final SearchProgressListener getProgressListener() {
4848
return progressListener;
4949
}
5050

server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
829829
}
830830

831831
@Override
832-
public void onReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
832+
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
833833
totalHitsListener.set(totalHits);
834834
finalAggsListener.set(aggs);
835835
numReduceListener.incrementAndGet();

server/src/test/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
178178
}
179179

180180
@Override
181-
public void onReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
181+
public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
182182
numReduces.incrementAndGet();
183183
}
184184

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.action.search;
21+
22+
import org.elasticsearch.index.Index;
23+
import org.elasticsearch.index.shard.ShardId;
24+
import org.elasticsearch.test.ESTestCase;
25+
import org.elasticsearch.test.EqualsHashCodeTestUtils;
26+
27+
import java.util.ArrayList;
28+
import java.util.Collections;
29+
import java.util.List;
30+
31+
public class SearchShardTests extends ESTestCase {
32+
33+
public void testEqualsAndHashcode() {
34+
String index = randomAlphaOfLengthBetween(5, 10);
35+
SearchShard searchShard = new SearchShard(randomBoolean() ? null : randomAlphaOfLengthBetween(3, 10),
36+
new ShardId(index, index + "-uuid", randomIntBetween(0, 1024)));
37+
EqualsHashCodeTestUtils.checkEqualsAndHashCode(searchShard,
38+
s -> new SearchShard(s.getClusterAlias(), s.getShardId()),
39+
s -> {
40+
if (randomBoolean()) {
41+
return new SearchShard(s.getClusterAlias() == null ? randomAlphaOfLengthBetween(3, 10) : null, s.getShardId());
42+
} else {
43+
String indexName = s.getShardId().getIndexName();
44+
int shardId = s.getShardId().getId();
45+
if (randomBoolean()) {
46+
indexName += randomAlphaOfLength(5);
47+
} else {
48+
shardId += randomIntBetween(1, 1024);
49+
}
50+
return new SearchShard(s.getClusterAlias(), new ShardId(indexName, indexName + "-uuid", shardId));
51+
}
52+
});
53+
}
54+
55+
public void testCompareTo() {
56+
List<SearchShard> searchShards = new ArrayList<>();
57+
Index index0 = new Index("index0", "index0-uuid");
58+
Index index1 = new Index("index1", "index1-uuid");
59+
searchShards.add(new SearchShard(null, new ShardId(index0, 0)));
60+
searchShards.add(new SearchShard(null, new ShardId(index1, 0)));
61+
searchShards.add(new SearchShard(null, new ShardId(index0, 1)));
62+
searchShards.add(new SearchShard(null, new ShardId(index1, 1)));
63+
searchShards.add(new SearchShard(null, new ShardId(index0, 2)));
64+
searchShards.add(new SearchShard(null, new ShardId(index1, 2)));
65+
searchShards.add(new SearchShard("", new ShardId(index0, 0)));
66+
searchShards.add(new SearchShard("", new ShardId(index1, 0)));
67+
searchShards.add(new SearchShard("", new ShardId(index0, 1)));
68+
searchShards.add(new SearchShard("", new ShardId(index1, 1)));
69+
70+
searchShards.add(new SearchShard("remote0", new ShardId(index0, 0)));
71+
searchShards.add(new SearchShard("remote0", new ShardId(index1, 0)));
72+
searchShards.add(new SearchShard("remote0", new ShardId(index0, 1)));
73+
searchShards.add(new SearchShard("remote0", new ShardId(index0, 2)));
74+
searchShards.add(new SearchShard("remote1", new ShardId(index0, 0)));
75+
searchShards.add(new SearchShard("remote1", new ShardId(index1, 0)));
76+
searchShards.add(new SearchShard("remote1", new ShardId(index0, 1)));
77+
searchShards.add(new SearchShard("remote1", new ShardId(index1, 1)));
78+
79+
List<SearchShard> sorted = new ArrayList<>(searchShards);
80+
Collections.sort(sorted);
81+
assertEquals(searchShards, sorted);
82+
}
83+
}

0 commit comments

Comments
 (0)