Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Ruirui Zhang <mariazrr@amazon.com>
  • Loading branch information
ruai0511 committed Sep 10, 2024
1 parent f1d3bcf commit a8bb3a5
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,52 @@
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

/**
* A request to get QueryGroupStats
*/
@ExperimentalApi
public class QueryGroupStatsRequest extends BaseNodesRequest<QueryGroupStatsRequest> {

private final Set<String> queryGroupIds;
private final Boolean breach;

protected QueryGroupStatsRequest(StreamInput in) throws IOException {
super(in);
}

public QueryGroupStatsRequest() {
super(false, (String[]) null);
this.queryGroupIds = new HashSet<>(Set.of(in.readStringArray()));
this.breach = in.readOptionalBoolean();
}

/**
* Get QueryGroup stats from nodes based on the nodes ids specified. If none are passed, stats
* for all nodes will be returned.
*/
public QueryGroupStatsRequest(String... nodesIds) {
super(nodesIds);
public QueryGroupStatsRequest(String[] nodesIds, Set<String> queryGroupIds, boolean breach) {
super(false, nodesIds);
this.queryGroupIds = queryGroupIds;
this.breach = breach;
}

public QueryGroupStatsRequest() {
super(false, (String[]) null);
queryGroupIds = new HashSet<>();
this.breach = false;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeStringArray(queryGroupIds.toArray(new String[0]));
out.writeOptionalBoolean(breach);
}

public Set<String> getQueryGroupIds() {
return queryGroupIds;
}

public boolean isBreach() {
return breach;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -80,7 +81,8 @@ protected QueryGroupStats newNodeResponse(StreamInput in) throws IOException {

@Override
protected QueryGroupStats nodeOperation(NodeQueryGroupStatsRequest nodeQueryGroupStatsRequest) {
return queryGroupService.nodeStats();
QueryGroupStatsRequest request = nodeQueryGroupStatsRequest.request;
return queryGroupService.nodeStats(request.getQueryGroupIds(), request.isBreach());
}

/**
Expand All @@ -106,5 +108,9 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
request.writeTo(out);
}

public DiscoveryNode[] getDiscoveryNodes() {
return this.request.concreteNodes();
}
}
}
10 changes: 7 additions & 3 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -1134,9 +1134,13 @@ protected Node(
taskHeaders,
tracer
);
final QueryGroupService queryGroupService = new QueryGroupService(transportService); // We will need to replace this with actual
// instance of the
// queryGroupService
final QueryGroupService queryGroupService = new QueryGroupService(transportService.getLocalNode(), clusterService); // We will
// need to
// replace
// this with
// actual
// instance of the
// queryGroupService
queryGroupServiceSetOnce.set(queryGroupService);

final QueryGroupRequestOperationListener queryGroupRequestOperationListener = new QueryGroupRequestOperationListener(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Set;

import static java.util.Arrays.asList;
import static java.util.Collections.unmodifiableList;
Expand All @@ -31,7 +32,14 @@ public class RestQueryGroupStatsAction extends BaseRestHandler {

@Override
public List<Route> routes() {
return unmodifiableList(asList(new Route(GET, "/_wlm/query_group_stats"), new Route(GET, "/_wlm/query_group_stats/{nodeId}")));
return unmodifiableList(
asList(
new Route(GET, "query_group/stats"),
new Route(GET, "query_group/stats/{queryGroupId}"),
new Route(GET, "query_group/stats/nodes/{nodeId}"),
new Route(GET, "query_group/stats/{queryGroupId}/nodes/{nodeId}")
)
);
}

@Override
Expand All @@ -42,7 +50,9 @@ public String getName() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
String[] nodesIds = Strings.splitStringByCommaToArray(request.param("nodeId"));
QueryGroupStatsRequest queryGroupStatsRequest = new QueryGroupStatsRequest(nodesIds);
Set<String> queryGroupIds = Strings.tokenizeByCommaToSet(request.param("queryGroupId", "_all"));
Boolean breach = request.hasParam("breach") ? Boolean.parseBoolean(request.param("boolean")) : null;
QueryGroupStatsRequest queryGroupStatsRequest = new QueryGroupStatsRequest(nodesIds, queryGroupIds, breach);
return channel -> client.admin()
.cluster()
.queryGroupStats(queryGroupStatsRequest, new RestActions.NodesResponseRestListener<>(channel));
Expand Down
52 changes: 38 additions & 14 deletions server/src/main/java/org/opensearch/wlm/QueryGroupService.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@

package org.opensearch.wlm;

import org.opensearch.cluster.metadata.QueryGroup;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.transport.TransportService;
import org.opensearch.wlm.stats.QueryGroupState;
import org.opensearch.wlm.stats.QueryGroupStats;
import org.opensearch.wlm.stats.QueryGroupStats.QueryGroupStatsHolder;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
* As of now this is a stub and main implementation PR will be raised soon.Coming PR will collate these changes with core QueryGroupService changes
Expand All @@ -26,16 +29,18 @@ public class QueryGroupService {
// This map does not need to be concurrent since we will process the cluster state change serially and update
// this map with new additions and deletions of entries. QueryGroupState is thread safe
private final Map<String, QueryGroupState> queryGroupStateMap;
private final TransportService transportService;
private final DiscoveryNode discoveryNode;
private final ClusterService clusterService;

@Inject
public QueryGroupService(TransportService transportService) {
this(transportService, new HashMap<>());
public QueryGroupService(DiscoveryNode discoveryNode, ClusterService clusterService) {
this(discoveryNode, clusterService, new HashMap<>());
}

@Inject
public QueryGroupService(TransportService transportService, Map<String, QueryGroupState> queryGroupStateMap) {
this.transportService = transportService;
public QueryGroupService(DiscoveryNode discoveryNode, ClusterService clusterService, Map<String, QueryGroupState> queryGroupStateMap) {
this.discoveryNode = discoveryNode;
this.clusterService = clusterService;
this.queryGroupStateMap = queryGroupStateMap;
}

Expand All @@ -54,19 +59,38 @@ public void incrementFailuresFor(final String queryGroupId) {
}

/**
*
* @return node level query group stats
*/
public QueryGroupStats nodeStats() {
public QueryGroupStats nodeStats(Set<String> queryGroupIds, Boolean requestedBreached) {
final Map<String, QueryGroupStatsHolder> statsHolderMap = new HashMap<>();
for (Map.Entry<String, QueryGroupState> queryGroupsState : queryGroupStateMap.entrySet()) {
final String queryGroupId = queryGroupsState.getKey();
final QueryGroupState currentState = queryGroupsState.getValue();

statsHolderMap.put(queryGroupId, QueryGroupStatsHolder.from(currentState));
}
queryGroupStateMap.forEach((queryGroupId, currentState) -> {
boolean shouldInclude = (queryGroupIds.size() == 1 && queryGroupIds.contains("_all")) || queryGroupIds.contains(queryGroupId);

if (shouldInclude) {
if (requestedBreached == null || requestedBreached == resourceLimitBreached(queryGroupId, currentState)) {
statsHolderMap.put(queryGroupId, QueryGroupStatsHolder.from(currentState));
}
}
});

return new QueryGroupStats(discoveryNode, statsHolderMap);
}

/**
* @return if the QueryGroup breaches any resource limit based on the LastRecordedUsage
*/
public boolean resourceLimitBreached(String id, QueryGroupState currentState) {
QueryGroup queryGroup = clusterService.state().metadata().queryGroups().get(id);

return new QueryGroupStats(transportService.getLocalNode(), statsHolderMap);
return currentState.getResourceState()
.entrySet()
.stream()
.anyMatch(
entry -> entry.getValue().getLastRecordedUsage() > queryGroup.getMutableQueryGroupFragment()
.getResourceLimits()
.get(entry.getKey())
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
public class QueryGroupStats extends BaseNodeResponse implements ToXContentObject, Writeable {
private final Map<String, QueryGroupStatsHolder> stats;

public Map<String, QueryGroupStatsHolder> getStats() {
return stats;
}

public QueryGroupStats(DiscoveryNode node, Map<String, QueryGroupStatsHolder> stats) {
super(node);
this.stats = stats;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@

import org.opensearch.action.admin.cluster.wlm.QueryGroupStatsRequest;
import org.opensearch.action.admin.cluster.wlm.TransportQueryGroupStatsAction;
import org.opensearch.action.admin.cluster.wlm.TransportQueryGroupStatsAction.NodeQueryGroupStatsRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.test.transport.CapturingTransport;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.wlm.QueryGroupService;

import java.io.IOException;
Expand All @@ -38,7 +35,7 @@ public class TransportQueryGroupStatsActionTests extends TransportNodesActionTes
*/
public void testQueryGroupStatsActionWithRetentionOfDiscoveryNodesList() {
QueryGroupStatsRequest request = new QueryGroupStatsRequest();
Map<String, List<MockNodeQueryGroupStatsRequest>> combinedSentRequest = performQueryGroupStatsAction(request);
Map<String, List<NodeQueryGroupStatsRequest>> combinedSentRequest = performQueryGroupStatsAction(request);

assertNotNull(combinedSentRequest);
combinedSentRequest.forEach((node, capturedRequestList) -> {
Expand All @@ -47,15 +44,21 @@ public void testQueryGroupStatsActionWithRetentionOfDiscoveryNodesList() {
});
}

private Map<String, List<MockNodeQueryGroupStatsRequest>> performQueryGroupStatsAction(QueryGroupStatsRequest request) {
TransportNodesAction action = getTestTransportQueryGroupStatsAction();
private Map<String, List<NodeQueryGroupStatsRequest>> performQueryGroupStatsAction(QueryGroupStatsRequest request) {
TransportNodesAction action = new TransportQueryGroupStatsAction(
THREAD_POOL,
clusterService,
transportService,
mock(QueryGroupService.class),
new ActionFilters(Collections.emptySet())
);
PlainActionFuture<QueryGroupStatsRequest> listener = new PlainActionFuture<>();
action.new AsyncAction(null, request, listener).start();
Map<String, List<CapturingTransport.CapturedRequest>> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear();
Map<String, List<MockNodeQueryGroupStatsRequest>> combinedSentRequest = new HashMap<>();
Map<String, List<NodeQueryGroupStatsRequest>> combinedSentRequest = new HashMap<>();

capturedRequests.forEach((node, capturedRequestList) -> {
List<MockNodeQueryGroupStatsRequest> sentRequestList = new ArrayList<>();
List<NodeQueryGroupStatsRequest> sentRequestList = new ArrayList<>();

capturedRequestList.forEach(preSentRequest -> {
BytesStreamOutput out = new BytesStreamOutput();
Expand All @@ -64,7 +67,7 @@ private Map<String, List<MockNodeQueryGroupStatsRequest>> performQueryGroupStats
(TransportQueryGroupStatsAction.NodeQueryGroupStatsRequest) preSentRequest.request;
QueryGroupStatsRequestFromCoordinator.writeTo(out);
StreamInput in = out.bytes().streamInput();
MockNodeQueryGroupStatsRequest QueryGroupStatsRequest = new MockNodeQueryGroupStatsRequest(in);
NodeQueryGroupStatsRequest QueryGroupStatsRequest = new NodeQueryGroupStatsRequest(in);
sentRequestList.add(QueryGroupStatsRequest);
} catch (IOException e) {
throw new RuntimeException(e);
Expand All @@ -76,37 +79,4 @@ private Map<String, List<MockNodeQueryGroupStatsRequest>> performQueryGroupStats

return combinedSentRequest;
}

private TestTransportQueryGroupStatsAction getTestTransportQueryGroupStatsAction() {
return new TestTransportQueryGroupStatsAction(
THREAD_POOL,
clusterService,
transportService,
mock(QueryGroupService.class),
new ActionFilters(Collections.emptySet())
);
}

private static class TestTransportQueryGroupStatsAction extends TransportQueryGroupStatsAction {
public TestTransportQueryGroupStatsAction(
ThreadPool threadPool,
ClusterService clusterService,
TransportService transportService,
QueryGroupService queryGroupService,
ActionFilters actionFilters
) {
super(threadPool, clusterService, transportService, queryGroupService, actionFilters);
}
}

private static class MockNodeQueryGroupStatsRequest extends TransportQueryGroupStatsAction.NodeQueryGroupStatsRequest {

public MockNodeQueryGroupStatsRequest(StreamInput in) throws IOException {
super(in);
}

public DiscoveryNode[] getDiscoveryNodes() {
return this.request.concreteNodes();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.wlm.listeners;

import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -24,8 +25,10 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
Expand Down Expand Up @@ -100,7 +103,7 @@ public void testMultiThreadedValidQueryGroupRequestFailures() {
TransportService mockTransportService = mock(TransportService.class);
DiscoveryNode mockDiscoveryNode = mock(DiscoveryNode.class);
when(mockTransportService.getLocalNode()).thenReturn(mockDiscoveryNode);
queryGroupService = new QueryGroupService(mockTransportService, queryGroupStateMap);
queryGroupService = new QueryGroupService(mockTransportService.getLocalNode(), mock(ClusterService.class), queryGroupStateMap);

sut = new QueryGroupRequestOperationListener(queryGroupService, testThreadPool);

Expand All @@ -123,7 +126,9 @@ public void testMultiThreadedValidQueryGroupRequestFailures() {
}
});

QueryGroupStats actualStats = queryGroupService.nodeStats();
Set<String> set = new HashSet<>();
set.add("_all");
QueryGroupStats actualStats = queryGroupService.nodeStats(set, null);

QueryGroupStats expectedStats = new QueryGroupStats(
mock(DiscoveryNode.class),
Expand Down Expand Up @@ -184,12 +189,15 @@ private void assertSuccess(
TransportService mockTransportService = mock(TransportService.class);
DiscoveryNode mockDiscoveryNode = mock(DiscoveryNode.class);
when(mockTransportService.getLocalNode()).thenReturn(mockDiscoveryNode);
queryGroupService = new QueryGroupService(mockTransportService, queryGroupStateMap);
queryGroupService = new QueryGroupService(mockTransportService.getLocalNode(), mock(ClusterService.class), queryGroupStateMap);

sut = new QueryGroupRequestOperationListener(queryGroupService, testThreadPool);
sut.onRequestFailure(null, null);

QueryGroupStats actualStats = queryGroupService.nodeStats();
Set<String> set = new HashSet<>();
set.add("_all");
QueryGroupStats actualStats = queryGroupService.nodeStats(set, null);

assertEquals(expectedStats, actualStats);
}

Expand Down

0 comments on commit a8bb3a5

Please sign in to comment.