diff --git a/CHANGELOG.md b/CHANGELOG.md index d2b51e573aab0..4e5752536f977 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x] ### Added - Fix for hasInitiatedFetching to fix allocation explain and manual reroute APIs (([#14972](https://github.com/opensearch-project/OpenSearch/pull/14972)) +- [Workload Management] Add queryGroupId to Task ([14708](https://github.com/opensearch-project/OpenSearch/pull/14708)) - Add setting to ignore throttling nodes for allocation of unassigned primaries in remote restore ([#14991](https://github.com/opensearch-project/OpenSearch/pull/14991)) - Add basic aggregation support for derived fields ([#14618](https://github.com/opensearch-project/OpenSearch/pull/14618)) - Add ThreadContextPermission for markAsSystemContext and allow core to perform the method ([#15016](https://github.com/opensearch-project/OpenSearch/pull/15016)) diff --git a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java index dfecf4f462c4d..ed2943db94420 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchShardTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchShardTask.java @@ -37,8 +37,8 @@ import org.opensearch.core.tasks.TaskId; import org.opensearch.search.fetch.ShardFetchSearchRequest; import org.opensearch.search.internal.ShardSearchRequest; -import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.SearchBackpressureTask; +import org.opensearch.wlm.QueryGroupTask; import java.util.Map; import java.util.function.Supplier; @@ -50,7 +50,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class SearchShardTask extends CancellableTask implements SearchBackpressureTask { +public class SearchShardTask extends QueryGroupTask implements SearchBackpressureTask { // generating metadata in a lazy way since source can be quite big private final MemoizedSupplier metadataSupplier; diff --git a/server/src/main/java/org/opensearch/action/search/SearchTask.java b/server/src/main/java/org/opensearch/action/search/SearchTask.java index d3c1043c50cce..2a1a961e7607b 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchTask.java +++ b/server/src/main/java/org/opensearch/action/search/SearchTask.java @@ -35,8 +35,8 @@ import org.opensearch.common.annotation.PublicApi; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.tasks.TaskId; -import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.SearchBackpressureTask; +import org.opensearch.wlm.QueryGroupTask; import java.util.Map; import java.util.function.Supplier; @@ -49,7 +49,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class SearchTask extends CancellableTask implements SearchBackpressureTask { +public class SearchTask extends QueryGroupTask implements SearchBackpressureTask { // generating description in a lazy way since source can be quite big private final Supplier descriptionSupplier; private SearchProgressListener progressListener = SearchProgressListener.NOOP; diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 7d3237d43cd5c..88bf7ebea8e52 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -101,6 +101,7 @@ import org.opensearch.transport.RemoteTransportException; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportService; +import org.opensearch.wlm.QueryGroupTask; import java.util.ArrayList; import java.util.Arrays; @@ -442,6 +443,12 @@ private void executeRequest( ); searchRequestContext.getSearchRequestOperationsListener().onRequestStart(searchRequestContext); + // At this point either the QUERY_GROUP_ID header will be present in ThreadContext either via ActionFilter + // or HTTP header (HTTP header will be deprecated once ActionFilter is implemented) + if (task instanceof QueryGroupTask) { + ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); + } + PipelinedRequest searchRequest; ActionListener listener; try { diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 205fd34f89c2b..734bb1815b6d4 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -261,6 +261,7 @@ import org.opensearch.transport.TransportService; import org.opensearch.usage.UsageService; import org.opensearch.watcher.ResourceWatcherService; +import org.opensearch.wlm.WorkloadManagementTransportInterceptor; import javax.net.ssl.SNIHostName; @@ -1041,6 +1042,10 @@ protected Node( admissionControlService ); + WorkloadManagementTransportInterceptor workloadManagementTransportInterceptor = new WorkloadManagementTransportInterceptor( + threadPool + ); + final Collection secureSettingsFactories = pluginsService.filterPlugins(Plugin.class) .stream() .map(p -> p.getSecureSettingFactory(settings)) @@ -1048,7 +1053,10 @@ protected Node( .map(Optional::get) .collect(Collectors.toList()); - List transportInterceptors = List.of(admissionControlTransportInterceptor); + List transportInterceptors = List.of( + admissionControlTransportInterceptor, + workloadManagementTransportInterceptor + ); final NetworkModule networkModule = new NetworkModule( settings, pluginsService.filterPlugins(NetworkPlugin.class), diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java b/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java new file mode 100644 index 0000000000000..4eb413be61b72 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupTask.java @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.tasks.CancellableTask; + +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +import static org.opensearch.search.SearchService.NO_TIMEOUT; + +/** + * Base class to define QueryGroup tasks + */ +public class QueryGroupTask extends CancellableTask { + + private static final Logger logger = LogManager.getLogger(QueryGroupTask.class); + public static final String QUERY_GROUP_ID_HEADER = "queryGroupId"; + public static final Supplier DEFAULT_QUERY_GROUP_ID_SUPPLIER = () -> "DEFAULT_QUERY_GROUP"; + private String queryGroupId; + + public QueryGroupTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { + this(id, type, action, description, parentTaskId, headers, NO_TIMEOUT); + } + + public QueryGroupTask( + long id, + String type, + String action, + String description, + TaskId parentTaskId, + Map headers, + TimeValue cancelAfterTimeInterval + ) { + super(id, type, action, description, parentTaskId, headers, cancelAfterTimeInterval); + } + + /** + * This method should always be called after calling setQueryGroupId at least once on this object + * @return task queryGroupId + */ + public final String getQueryGroupId() { + if (queryGroupId == null) { + logger.warn("QueryGroup _id can't be null, It should be set before accessing it. This is abnormal behaviour "); + } + return queryGroupId; + } + + /** + * sets the queryGroupId from threadContext into the task itself, + * This method was defined since the queryGroupId can only be evaluated after task creation + * @param threadContext current threadContext + */ + public final void setQueryGroupId(final ThreadContext threadContext) { + this.queryGroupId = Optional.ofNullable(threadContext) + .map(threadContext1 -> threadContext1.getHeader(QUERY_GROUP_ID_HEADER)) + .orElse(DEFAULT_QUERY_GROUP_ID_SUPPLIER.get()); + } + + @Override + public boolean shouldCancelChildrenOnCancellation() { + return false; + } +} diff --git a/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java new file mode 100644 index 0000000000000..848df8712549a --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/WorkloadManagementTransportInterceptor.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportInterceptor; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +/** + * This class is used to intercept search traffic requests and populate the queryGroupId header in task headers + */ +public class WorkloadManagementTransportInterceptor implements TransportInterceptor { + private final ThreadPool threadPool; + + public WorkloadManagementTransportInterceptor(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + @Override + public TransportRequestHandler interceptHandler( + String action, + String executor, + boolean forceExecution, + TransportRequestHandler actualHandler + ) { + return new RequestHandler(threadPool, actualHandler); + } + + /** + * This class is mainly used to populate the queryGroupId header + * @param T is Search related request + */ + public static class RequestHandler implements TransportRequestHandler { + + private final ThreadPool threadPool; + TransportRequestHandler actualHandler; + + public RequestHandler(ThreadPool threadPool, TransportRequestHandler actualHandler) { + this.threadPool = threadPool; + this.actualHandler = actualHandler; + } + + @Override + public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { + if (isSearchWorkloadRequest(task)) { + ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); + } + actualHandler.messageReceived(request, channel, task); + } + + boolean isSearchWorkloadRequest(Task task) { + return task instanceof QueryGroupTask; + } + } +} diff --git a/server/src/test/java/org/opensearch/wlm/QueryGroupTaskTests.java b/server/src/test/java/org/opensearch/wlm/QueryGroupTaskTests.java new file mode 100644 index 0000000000000..d292809c30124 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/QueryGroupTaskTests.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import java.util.Collections; + +import static org.opensearch.wlm.QueryGroupTask.DEFAULT_QUERY_GROUP_ID_SUPPLIER; +import static org.opensearch.wlm.QueryGroupTask.QUERY_GROUP_ID_HEADER; + +public class QueryGroupTaskTests extends OpenSearchTestCase { + private ThreadPool threadPool; + private QueryGroupTask sut; + + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()); + sut = new QueryGroupTask(123, "transport", "Search", "test task", null, Collections.emptyMap()); + } + + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + } + + public void testSuccessfulSetQueryGroupId() { + sut.setQueryGroupId(threadPool.getThreadContext()); + assertEquals(DEFAULT_QUERY_GROUP_ID_SUPPLIER.get(), sut.getQueryGroupId()); + + threadPool.getThreadContext().putHeader(QUERY_GROUP_ID_HEADER, "akfanglkaglknag2332"); + + sut.setQueryGroupId(threadPool.getThreadContext()); + assertEquals("akfanglkaglknag2332", sut.getQueryGroupId()); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java new file mode 100644 index 0000000000000..db4e5e45d49ed --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportInterceptorTests.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; +import org.opensearch.wlm.WorkloadManagementTransportInterceptor.RequestHandler; + +import static org.opensearch.threadpool.ThreadPool.Names.SAME; + +public class WorkloadManagementTransportInterceptorTests extends OpenSearchTestCase { + + private ThreadPool threadPool; + private WorkloadManagementTransportInterceptor sut; + + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()); + sut = new WorkloadManagementTransportInterceptor(threadPool); + } + + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + } + + public void testInterceptHandler() { + TransportRequestHandler requestHandler = sut.interceptHandler("Search", SAME, false, null); + assertTrue(requestHandler instanceof RequestHandler); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java new file mode 100644 index 0000000000000..789c02345e774 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/WorkloadManagementTransportRequestHandlerTests.java @@ -0,0 +1,75 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.action.index.IndexRequest; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportRequest; +import org.opensearch.transport.TransportRequestHandler; + +import java.util.Collections; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; + +public class WorkloadManagementTransportRequestHandlerTests extends OpenSearchTestCase { + private WorkloadManagementTransportInterceptor.RequestHandler sut; + private ThreadPool threadPool; + + private TestTransportRequestHandler actualHandler; + + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getTestName()); + actualHandler = new TestTransportRequestHandler<>(); + + sut = new WorkloadManagementTransportInterceptor.RequestHandler<>(threadPool, actualHandler); + } + + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + } + + public void testMessageReceivedForSearchWorkload() throws Exception { + ShardSearchRequest request = mock(ShardSearchRequest.class); + QueryGroupTask spyTask = getSpyTask(); + + sut.messageReceived(request, mock(TransportChannel.class), spyTask); + assertTrue(sut.isSearchWorkloadRequest(spyTask)); + } + + public void testMessageReceivedForNonSearchWorkload() throws Exception { + IndexRequest indexRequest = mock(IndexRequest.class); + Task task = mock(Task.class); + sut.messageReceived(indexRequest, mock(TransportChannel.class), task); + assertFalse(sut.isSearchWorkloadRequest(task)); + assertEquals(1, actualHandler.invokeCount); + } + + private static QueryGroupTask getSpyTask() { + final QueryGroupTask task = new QueryGroupTask(123, "transport", "Search", "test task", null, Collections.emptyMap()); + + return spy(task); + } + + private static class TestTransportRequestHandler implements TransportRequestHandler { + int invokeCount = 0; + + @Override + public void messageReceived(TransportRequest request, TransportChannel channel, Task task) throws Exception { + invokeCount += 1; + } + }; +}