Skip to content

Commit

Permalink
Add query group level rejection logic (#15428) (#15512)
Browse files Browse the repository at this point in the history
* add rejection listener



* add rejection listener unit test



* add rejection logic for shard level requests



* add changelog entry



* apply spotless check



* remove unused files and fix precommit



* refactor code



* add package info file



* remove unused method from QueryGroupService stub



---------

Signed-off-by: Kaushal Kumar <ravi.kaushal97@gmail.com>
  • Loading branch information
kaushalmahi12 authored Aug 30, 2024
1 parent 3c9e08b commit d1cd7a2
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add concurrent search support for Derived Fields ([#15326](https://github.com/opensearch-project/OpenSearch/pull/15326))
- [Workload Management] Add query group stats constructs ([#15343](https://github.com/opensearch-project/OpenSearch/pull/15343)))
- Add runAs to Subject interface and introduce IdentityAwarePlugin extension point ([#14630](https://github.com/opensearch-project/OpenSearch/pull/14630))
- [Workload Management] Add rejection logic for co-ordinator and shard level requests ([#15428](https://github.com/opensearch-project/OpenSearch/pull/15428)))

### Dependencies
- Bump `netty` from 4.1.111.Final to 4.1.112.Final ([#15081](https://github.com/opensearch-project/OpenSearch/pull/15081))
Expand Down
18 changes: 16 additions & 2 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@
import org.opensearch.transport.TransportService;
import org.opensearch.usage.UsageService;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.wlm.QueryGroupService;
import org.opensearch.wlm.WorkloadManagementTransportInterceptor;
import org.opensearch.wlm.listeners.QueryGroupRequestRejectionOperationListener;

import javax.net.ssl.SNIHostName;

Expand Down Expand Up @@ -996,11 +998,22 @@ protected Node(
List<IdentityAwarePlugin> identityAwarePlugins = pluginsService.filterPlugins(IdentityAwarePlugin.class);
identityService.initializeIdentityAwarePlugins(identityAwarePlugins);

final QueryGroupRequestRejectionOperationListener queryGroupRequestRejectionListener =
new QueryGroupRequestRejectionOperationListener(
new QueryGroupService(), // We will need to replace this with actual instance of the queryGroupService
threadPool
);

// register all standard SearchRequestOperationsCompositeListenerFactory to the SearchRequestOperationsCompositeListenerFactory
final SearchRequestOperationsCompositeListenerFactory searchRequestOperationsCompositeListenerFactory =
new SearchRequestOperationsCompositeListenerFactory(
Stream.concat(
Stream.of(searchRequestStats, searchRequestSlowLog, searchTaskRequestOperationsListener),
Stream.of(
searchRequestStats,
searchRequestSlowLog,
searchTaskRequestOperationsListener,
queryGroupRequestRejectionListener
),
pluginComponents.stream()
.filter(p -> p instanceof SearchRequestOperationsListener)
.map(p -> (SearchRequestOperationsListener) p)
Expand Down Expand Up @@ -1050,7 +1063,8 @@ protected Node(
);

WorkloadManagementTransportInterceptor workloadManagementTransportInterceptor = new WorkloadManagementTransportInterceptor(
threadPool
threadPool,
new QueryGroupService() // We will need to replace this with actual implementation
);

final Collection<SecureSettingsFactory> secureSettingsFactories = pluginsService.filterPlugins(Plugin.class)
Expand Down
32 changes: 32 additions & 0 deletions server/src/main/java/org/opensearch/wlm/QueryGroupService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm;

import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;

/**
* This is stub at this point in time and will be replace by an acutal one in couple of days
*/
public class QueryGroupService {
/**
*
* @param queryGroupId query group identifier
*/
public void rejectIfNeeded(String queryGroupId) {
if (queryGroupId == null) return;
boolean reject = false;
final StringBuilder reason = new StringBuilder();
// TODO: At this point this is dummy and we need to decide whether to cancel the request based on last
// reported resource usage for the queryGroup. We also need to increment the rejection count here for the
// query group
if (reject) {
throw new OpenSearchRejectedExecutionException("QueryGroup " + queryGroupId + " is already contended." + reason.toString());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
*/
public class WorkloadManagementTransportInterceptor implements TransportInterceptor {
private final ThreadPool threadPool;
private final QueryGroupService queryGroupService;

public WorkloadManagementTransportInterceptor(ThreadPool threadPool) {
public WorkloadManagementTransportInterceptor(final ThreadPool threadPool, final QueryGroupService queryGroupService) {
this.threadPool = threadPool;
this.queryGroupService = queryGroupService;
}

@Override
Expand All @@ -32,7 +34,7 @@ public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler(
boolean forceExecution,
TransportRequestHandler<T> actualHandler
) {
return new RequestHandler<T>(threadPool, actualHandler);
return new RequestHandler<T>(threadPool, actualHandler, queryGroupService);
}

/**
Expand All @@ -43,16 +45,20 @@ public static class RequestHandler<T extends TransportRequest> implements Transp

private final ThreadPool threadPool;
TransportRequestHandler<T> actualHandler;
private final QueryGroupService queryGroupService;

public RequestHandler(ThreadPool threadPool, TransportRequestHandler<T> actualHandler) {
public RequestHandler(ThreadPool threadPool, TransportRequestHandler<T> actualHandler, QueryGroupService queryGroupService) {
this.threadPool = threadPool;
this.actualHandler = actualHandler;
this.queryGroupService = queryGroupService;
}

@Override
public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
if (isSearchWorkloadRequest(task)) {
((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext());
final String queryGroupId = ((QueryGroupTask) (task)).getQueryGroupId();
queryGroupService.rejectIfNeeded(queryGroupId);
}
actualHandler.messageReceived(request, channel, task);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm.listeners;

import org.opensearch.action.search.SearchRequestContext;
import org.opensearch.action.search.SearchRequestOperationsListener;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.QueryGroupService;
import org.opensearch.wlm.QueryGroupTask;

/**
* This listener is used to perform the rejections for incoming requests into a queryGroup
*/
public class QueryGroupRequestRejectionOperationListener extends SearchRequestOperationsListener {

private final QueryGroupService queryGroupService;
private final ThreadPool threadPool;

public QueryGroupRequestRejectionOperationListener(QueryGroupService queryGroupService, ThreadPool threadPool) {
this.queryGroupService = queryGroupService;
this.threadPool = threadPool;
}

/**
* This method assumes that the queryGroupId is already populated in the thread context
* @param searchRequestContext SearchRequestContext instance
*/
@Override
protected void onRequestStart(SearchRequestContext searchRequestContext) {
final String queryGroupId = threadPool.getThreadContext().getHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER);
queryGroupService.rejectIfNeeded(queryGroupId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

/**
* WLM related listener constructs
*/
package org.opensearch.wlm.listeners;
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class WorkloadManagementTransportInterceptorTests extends OpenSearchTestC
public void setUp() throws Exception {
super.setUp();
threadPool = new TestThreadPool(getTestName());
sut = new WorkloadManagementTransportInterceptor(threadPool);
sut = new WorkloadManagementTransportInterceptor(threadPool, new QueryGroupService());
}

public void tearDown() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.wlm;

import org.opensearch.action.index.IndexRequest;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -20,36 +21,49 @@

import java.util.Collections;

import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;

public class WorkloadManagementTransportRequestHandlerTests extends OpenSearchTestCase {
private WorkloadManagementTransportInterceptor.RequestHandler<TransportRequest> sut;
private ThreadPool threadPool;
private QueryGroupService queryGroupService;

private TestTransportRequestHandler<TransportRequest> actualHandler;

public void setUp() throws Exception {
super.setUp();
threadPool = new TestThreadPool(getTestName());
actualHandler = new TestTransportRequestHandler<>();
queryGroupService = mock(QueryGroupService.class);

sut = new WorkloadManagementTransportInterceptor.RequestHandler<>(threadPool, actualHandler);
sut = new WorkloadManagementTransportInterceptor.RequestHandler<>(threadPool, actualHandler, queryGroupService);
}

public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdown();
}

public void testMessageReceivedForSearchWorkload() throws Exception {
public void testMessageReceivedForSearchWorkload_nonRejectionCase() throws Exception {
ShardSearchRequest request = mock(ShardSearchRequest.class);
QueryGroupTask spyTask = getSpyTask();

doNothing().when(queryGroupService).rejectIfNeeded(anyString());
sut.messageReceived(request, mock(TransportChannel.class), spyTask);
assertTrue(sut.isSearchWorkloadRequest(spyTask));
}

public void testMessageReceivedForSearchWorkload_RejectionCase() throws Exception {
ShardSearchRequest request = mock(ShardSearchRequest.class);
QueryGroupTask spyTask = getSpyTask();
doThrow(OpenSearchRejectedExecutionException.class).when(queryGroupService).rejectIfNeeded(anyString());

assertThrows(OpenSearchRejectedExecutionException.class, () -> sut.messageReceived(request, mock(TransportChannel.class), spyTask));
}

public void testMessageReceivedForNonSearchWorkload() throws Exception {
IndexRequest indexRequest = mock(IndexRequest.class);
Task task = mock(Task.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.wlm.listeners;

import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.QueryGroupService;
import org.opensearch.wlm.QueryGroupTask;

import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;

public class QueryGroupRequestRejectionOperationListenerTests extends OpenSearchTestCase {
ThreadPool testThreadPool;
QueryGroupService queryGroupService;
QueryGroupRequestRejectionOperationListener sut;

public void setUp() throws Exception {
super.setUp();
testThreadPool = new TestThreadPool("RejectionTestThreadPool");
queryGroupService = mock(QueryGroupService.class);
sut = new QueryGroupRequestRejectionOperationListener(queryGroupService, testThreadPool);
}

public void tearDown() throws Exception {
super.tearDown();
testThreadPool.shutdown();
}

public void testRejectionCase() {
final String testQueryGroupId = "asdgasgkajgkw3141_3rt4t";
testThreadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, testQueryGroupId);
doThrow(OpenSearchRejectedExecutionException.class).when(queryGroupService).rejectIfNeeded(testQueryGroupId);
assertThrows(OpenSearchRejectedExecutionException.class, () -> sut.onRequestStart(null));
}

public void testNonRejectionCase() {
final String testQueryGroupId = "asdgasgkajgkw3141_3rt4t";
testThreadPool.getThreadContext().putHeader(QueryGroupTask.QUERY_GROUP_ID_HEADER, testQueryGroupId);
doNothing().when(queryGroupService).rejectIfNeeded(testQueryGroupId);

sut.onRequestStart(null);
}
}

0 comments on commit d1cd7a2

Please sign in to comment.