Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- [Rule based auto-tagging] Add Delete Rule API ([#18184](https://github.com/opensearch-project/OpenSearch/pull/18184))
- Implement parallel shard refresh behind cluster settings ([#17782](https://github.com/opensearch-project/OpenSearch/pull/17782))
- Bump OpenSearch Core main branch to 3.0.0 ([#18039](https://github.com/opensearch-project/OpenSearch/pull/18039))
- [Rule based Auto-tagging] Add wlm `ActionFilter` ([#17791](https://github.com/opensearch-project/OpenSearch/pull/17791))
- Update API of Message in index to add the timestamp for lag calculation in ingestion polling ([#17977](https://github.com/opensearch-project/OpenSearch/pull/17977/))
- Add Warm Disk Threshold Allocation Decider for Warm shards ([#18082](https://github.com/opensearch-project/OpenSearch/pull/18082))
- Add composite directory factory ([#17988](https://github.com/opensearch-project/OpenSearch/pull/17988))
Expand Down
1 change: 1 addition & 0 deletions modules/autotagging-commons/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


opensearchplugin {
name = "rule-framework"
description = 'OpenSearch Rule Framework plugin'
classname = 'org.opensearch.rule.RuleFrameworkPlugin'
}
Expand Down
2 changes: 1 addition & 1 deletion modules/autotagging-commons/spi/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ base {
}

dependencies {
api project(':modules:autotagging-commons:common')
implementation project(':modules:autotagging-commons:common')
}

disableTasks("forbiddenApisMain")
Expand Down
11 changes: 9 additions & 2 deletions plugins/workload-management/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ apply plugin: 'opensearch.internal-cluster-test'
opensearchplugin {
description = 'OpenSearch Workload Management Plugin.'
classname = 'org.opensearch.plugin.wlm.WorkloadManagementPlugin'
extendedPlugins = [] // Remove autotagging-commons since it's not a plugin
extendedPlugins = ['rule-framework']
}

dependencies {
implementation project(':modules:autotagging-commons:common')
implementation project(':modules:autotagging-commons:spi')
compileOnly project(':modules:autotagging-commons:spi')
compileOnly project(':modules:autotagging-commons')
testImplementation project(':modules:autotagging-commons')
testImplementation project(':modules:autotagging-commons:common')
}

testClusters.all {
testDistribution = 'archive'
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.plugin.wlm;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.IndicesRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilter;
import org.opensearch.action.support.ActionFilterChain;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.plugin.wlm.rule.attribute_extractor.IndicesExtractor;
import org.opensearch.rule.InMemoryRuleProcessingService;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.WorkloadGroupTask;

import java.util.List;
import java.util.Optional;

/**
* This class is responsible to evaluate and assign the WORKLOAD_GROUP_ID header in ThreadContext
*/
public class AutoTaggingActionFilter implements ActionFilter {
private final InMemoryRuleProcessingService ruleProcessingService;
ThreadPool threadPool;

/**
* Main constructor
* @param ruleProcessingService provides access to in memory view of rules
* @param threadPool to access assign the label
*/
public AutoTaggingActionFilter(InMemoryRuleProcessingService ruleProcessingService, ThreadPool threadPool) {
this.ruleProcessingService = ruleProcessingService;
this.threadPool = threadPool;
}

@Override
public int order() {
return Integer.MAX_VALUE;
}

@Override
public <Request extends ActionRequest, Response extends ActionResponse> void apply(
Task task,
String action,
Request request,
ActionListener<Response> listener,
ActionFilterChain<Request, Response> chain
) {
final boolean isValidRequest = request instanceof SearchRequest;

if (!isValidRequest) {
chain.proceed(task, action, request, listener);
return;
}
Optional<String> label = ruleProcessingService.evaluateLabel(List.of(new IndicesExtractor((IndicesRequest) request)));

label.ifPresent(s -> threadPool.getThreadContext().putHeader(WorkloadGroupTask.WORKLOAD_GROUP_ID_HEADER, s));
chain.proceed(task, action, request, listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.plugin.wlm;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilter;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -44,10 +45,12 @@
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestHandler;
import org.opensearch.rule.InMemoryRuleProcessingService;
import org.opensearch.rule.RulePersistenceService;
import org.opensearch.rule.autotagging.FeatureType;
import org.opensearch.rule.service.IndexStoredRulePersistenceService;
import org.opensearch.rule.spi.RuleFrameworkExtension;
import org.opensearch.rule.storage.DefaultAttributeValueStore;
import org.opensearch.rule.storage.IndexBasedRuleQueryMapper;
import org.opensearch.rule.storage.XContentRuleParser;
import org.opensearch.script.ScriptService;
Expand Down Expand Up @@ -75,6 +78,8 @@ public class WorkloadManagementPlugin extends Plugin implements ActionPlugin, Sy

private final RulePersistenceServiceHolder rulePersistenceServiceHolder = new RulePersistenceServiceHolder();

private AutoTaggingActionFilter autoTaggingActionFilter;

/**
* Default constructor
*/
Expand All @@ -101,9 +106,19 @@ public Collection<Object> createComponents(
new XContentRuleParser(WorkloadGroupFeatureType.INSTANCE),
new IndexBasedRuleQueryMapper()
);
InMemoryRuleProcessingService ruleProcessingService = new InMemoryRuleProcessingService(
WorkloadGroupFeatureType.INSTANCE,
DefaultAttributeValueStore::new
);
autoTaggingActionFilter = new AutoTaggingActionFilter(ruleProcessingService, threadPool);
return Collections.emptyList();
}

@Override
public List<ActionFilter> getActionFilters() {
return List.of(autoTaggingActionFilter);
}

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return List.of(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.plugin.wlm;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilterChain;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.rule.InMemoryRuleProcessingService;
import org.opensearch.rule.autotagging.Attribute;
import org.opensearch.rule.autotagging.FeatureType;
import org.opensearch.rule.storage.DefaultAttributeValueStore;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.wlm.WorkloadGroupTask;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;

import static org.mockito.Mockito.anyList;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class AutoTaggingActionFilterTests extends OpenSearchTestCase {

AutoTaggingActionFilter autoTaggingActionFilter;
InMemoryRuleProcessingService ruleProcessingService;
ThreadPool threadPool;

public void setUp() throws Exception {
super.setUp();
threadPool = new TestThreadPool("AutoTaggingActionFilterTests");
ruleProcessingService = spy(new InMemoryRuleProcessingService(WLMFeatureType.WLM, DefaultAttributeValueStore::new));
autoTaggingActionFilter = new AutoTaggingActionFilter(ruleProcessingService, threadPool);
}

public void tearDown() throws Exception {
super.tearDown();
threadPool.shutdownNow();
}

public void testOrder() {
assertEquals(Integer.MAX_VALUE, autoTaggingActionFilter.order());
}

public void testApplyForValidRequest() {
SearchRequest request = mock(SearchRequest.class);
ActionFilterChain<ActionRequest, ActionResponse> mockFilterChain = mock(TestActionFilterChain.class);
when(request.indices()).thenReturn(new String[] { "foo" });
try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) {
when(ruleProcessingService.evaluateLabel(anyList())).thenReturn(Optional.of("TestQG_ID"));
autoTaggingActionFilter.apply(mock(Task.class), "Test", request, null, mockFilterChain);

assertEquals("TestQG_ID", threadPool.getThreadContext().getHeader(WorkloadGroupTask.WORKLOAD_GROUP_ID_HEADER));
verify(ruleProcessingService, times(1)).evaluateLabel(anyList());
}
}

public void testApplyForInValidRequest() {
ActionFilterChain<ActionRequest, ActionResponse> mockFilterChain = mock(TestActionFilterChain.class);
CancelTasksRequest request = new CancelTasksRequest();
autoTaggingActionFilter.apply(mock(Task.class), "Test", request, null, mockFilterChain);

verify(ruleProcessingService, times(0)).evaluateLabel(anyList());
}

public enum WLMFeatureType implements FeatureType {
WLM;

@Override
public String getName() {
return "";
}

@Override
public Map<String, Attribute> getAllowedAttributesRegistry() {
return Map.of("test_attribute", TestAttribute.TEST_ATTRIBUTE);
}

@Override
public void registerFeatureType() {}
}

public enum TestAttribute implements Attribute {
TEST_ATTRIBUTE("test_attribute"),
INVALID_ATTRIBUTE("invalid_attribute");

private final String name;

TestAttribute(String name) {
this.name = name;
}

@Override
public String getName() {
return name;
}

@Override
public void validateAttribute() {}

@Override
public void writeTo(StreamOutput out) throws IOException {}
}

private static class TestActionFilterChain implements ActionFilterChain<ActionRequest, ActionResponse> {
@Override
public void proceed(Task task, String action, ActionRequest request, ActionListener<ActionResponse> listener) {

}
}
}
Loading