Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Segment Replication] Allocation and rebalancing based on average primary shard count per index #6422

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added
- Add GeoTile and GeoHash Grid aggregations on GeoShapes. ([#5589](https://github.com/opensearch-project/OpenSearch/pull/5589))
- Disallow multiple data paths for search nodes ([#6427](https://github.com/opensearch-project/OpenSearch/pull/6427))
- [Segment Replication] Allocation and rebalancing based on average primary shard count per index ([#6422](https://github.com/opensearch-project/OpenSearch/pull/6422))

### Dependencies
- Bump `org.apache.logging.log4j:log4j-core` from 2.18.0 to 2.20.0 ([#6490](https://github.com/opensearch-project/OpenSearch/pull/6490))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.indices.replication;

import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.routing.IndexRoutingTable;
import org.opensearch.cluster.routing.RoutingNode;
import org.opensearch.cluster.routing.RoutingNodes;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.index.IndexModule;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.test.InternalTestCluster;
import org.opensearch.test.OpenSearchIntegTestCase;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.opensearch.cluster.routing.ShardRoutingState.STARTED;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;

import org.opensearch.cluster.OpenSearchAllocationTestCase.ShardAllocations;

@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 0)
public class SegmentReplicationAllocationIT extends SegmentReplicationBaseIT {

private void createIndex(String idxName, int shardCount, int replicaCount, boolean isSegRep) {
Settings.Builder builder = Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, shardCount)
.put(IndexModule.INDEX_QUERY_CACHE_ENABLED_SETTING.getKey(), false)
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, replicaCount);
if (isSegRep) {
builder = builder.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT);
} else {
builder = builder.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.DOCUMENT);
}
prepareCreate(idxName, builder).get();
}

public void enablePreferPrimaryBalance() {
assertAcked(
client().admin()
.cluster()
.prepareUpdateSettings()
.setPersistentSettings(
Settings.builder().put(BalancedShardsAllocator.PREFER_PER_INDEX_PRIMARY_SHARD_BALANCE.getKey(), "true")
)
);
}

/**
* This test verifies the happy path where primary shard allocation is balanced when multiple indices are created.
*
* This test in general passes without primary shard balance as well due to nature of allocation algorithm which
* assigns all primary shards first followed by replica copies.
*/
public void testBalancedPrimaryAllocation() throws Exception {
internalCluster().startClusterManagerOnlyNode();
final int maxReplicaCount = 2;
final int maxShardCount = 5;
final int nodeCount = randomIntBetween(maxReplicaCount + 1, 10);
final int numberOfIndices = randomIntBetween(5, 10);

final List<String> nodeNames = new ArrayList<>();
logger.info("--> Creating {} nodes", nodeCount);
for (int i = 0; i < nodeCount; i++) {
nodeNames.add(internalCluster().startNode());
}
enablePreferPrimaryBalance();
int shardCount, replicaCount;
ClusterState state;
for (int i = 0; i < numberOfIndices; i++) {
shardCount = randomIntBetween(1, maxShardCount);
replicaCount = randomIntBetween(0, maxReplicaCount);
createIndex("test" + i, shardCount, replicaCount, i % 2 == 0);
logger.info("--> Creating index {} with shard count {} and replica count {}", "test" + i, shardCount, replicaCount);
ensureGreen(TimeValue.timeValueSeconds(60));
state = client().admin().cluster().prepareState().execute().actionGet().getState();
logger.info(ShardAllocations.printShardDistribution(state));
}
verifyPerIndexPrimaryBalance();
}

/**
* This test verifies balanced primary shard allocation for a single index with large shard count in event of node
* going down and a new node joining the cluster. The results in shard distribution skewness and re-balancing logic
* ensures the primary shard distribution is balanced.
*
*/
public void testSingleIndexShardAllocation() throws Exception {
internalCluster().startClusterManagerOnlyNode();
final int maxReplicaCount = 1;
final int maxShardCount = 50;
final int nodeCount = 5;

final List<String> nodeNames = new ArrayList<>();
logger.info("--> Creating {} nodes", nodeCount);
for (int i = 0; i < nodeCount; i++) {
nodeNames.add(internalCluster().startNode());
}
enablePreferPrimaryBalance();

ClusterState state;
createIndex("test", maxShardCount, maxReplicaCount, true);
ensureGreen(TimeValue.timeValueSeconds(60));
state = client().admin().cluster().prepareState().execute().actionGet().getState();
logger.info(ShardAllocations.printShardDistribution(state));
verifyPerIndexPrimaryBalance();

// Remove a node
internalCluster().stopRandomNode(InternalTestCluster.nameFilter(nodeNames.get(0)));
ensureGreen(TimeValue.timeValueSeconds(60));
state = client().admin().cluster().prepareState().execute().actionGet().getState();
logger.info(ShardAllocations.printShardDistribution(state));
verifyPerIndexPrimaryBalance();

// Add a new node
internalCluster().startDataOnlyNode();
ensureGreen(TimeValue.timeValueSeconds(60));
state = client().admin().cluster().prepareState().execute().actionGet().getState();
logger.info(ShardAllocations.printShardDistribution(state));
verifyPerIndexPrimaryBalance();
}

/**
* Similar to testSingleIndexShardAllocation test but creates multiple indices, multiple node adding in and getting
* removed. The test asserts post each such event that primary shard distribution is balanced across single index.
*/
public void testAllocationWithDisruption() throws Exception {
internalCluster().startClusterManagerOnlyNode();
final int maxReplicaCount = 2;
final int maxShardCount = 5;
final int nodeCount = randomIntBetween(maxReplicaCount + 1, 10);
final int numberOfIndices = randomIntBetween(1, 10);

logger.info("--> Creating {} nodes", nodeCount);
final List<String> nodeNames = new ArrayList<>();
for (int i = 0; i < nodeCount; i++) {
nodeNames.add(internalCluster().startNode());
}
enablePreferPrimaryBalance();

int shardCount, replicaCount, totalShardCount = 0, totalReplicaCount = 0;
ClusterState state;
for (int i = 0; i < numberOfIndices; i++) {
shardCount = randomIntBetween(1, maxShardCount);
totalShardCount += shardCount;
replicaCount = randomIntBetween(1, maxReplicaCount);
totalReplicaCount += replicaCount;
logger.info("--> Creating index test{} with primary {} and replica {}", i, shardCount, replicaCount);
createIndex("test" + i, shardCount, replicaCount, i % 2 == 0);
ensureGreen(TimeValue.timeValueSeconds(60));
if (logger.isTraceEnabled()) {
state = client().admin().cluster().prepareState().execute().actionGet().getState();
logger.info(ShardAllocations.printShardDistribution(state));
}
}
state = client().admin().cluster().prepareState().execute().actionGet().getState();
logger.info(ShardAllocations.printShardDistribution(state));
verifyPerIndexPrimaryBalance();

final int additionalNodeCount = randomIntBetween(1, 5);
logger.info("--> Adding {} nodes", additionalNodeCount);

internalCluster().startNodes(additionalNodeCount);
ensureGreen(TimeValue.timeValueSeconds(60));
state = client().admin().cluster().prepareState().execute().actionGet().getState();
logger.info(ShardAllocations.printShardDistribution(state));
verifyPerIndexPrimaryBalance();

logger.info("--> Stop one third nodes");
for (int i = 0; i < nodeCount; i += 3) {
internalCluster().stopRandomNode(InternalTestCluster.nameFilter(nodeNames.get(i)));
// give replica a chance to promote as primary before terminating node containing the replica
ensureGreen(TimeValue.timeValueSeconds(60));
}
state = client().admin().cluster().prepareState().execute().actionGet().getState();
logger.info(ShardAllocations.printShardDistribution(state));
verifyPerIndexPrimaryBalance();
}

/**
* Utility method which ensures cluster has balanced primary shard distribution across a single index.
* @throws Exception
*/
private void verifyPerIndexPrimaryBalance() throws Exception {
assertBusy(() -> {
final ClusterState currentState = client().admin().cluster().prepareState().execute().actionGet().getState();
RoutingNodes nodes = currentState.getRoutingNodes();
for (ObjectObjectCursor<String, IndexRoutingTable> index : currentState.getRoutingTable().indicesRouting()) {
final int totalPrimaryShards = index.value.primaryShardsActive();
final int avgPrimaryShardsPerNode = (int) Math.ceil(totalPrimaryShards * 1f / currentState.getRoutingNodes().size());
for (RoutingNode node : nodes) {
final int primaryCount = node.shardsWithState(index.key, STARTED)
.stream()
.filter(ShardRouting::primary)
.collect(Collectors.toList())
.size();
assertTrue(primaryCount <= avgPrimaryShardsPerNode);
}
}
}, 60, TimeUnit.SECONDS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,58 +8,48 @@
import org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator;
import org.opensearch.cluster.routing.allocation.allocator.ShardsBalancer;

import java.util.ArrayList;
import java.util.List;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;

import static org.opensearch.cluster.routing.allocation.RebalanceConstraints.PREFER_PRIMARY_SHARD_BALANCE_NODE_BREACH_ID;
import static org.opensearch.cluster.routing.allocation.RebalanceConstraints.isPrimaryShardsPerIndexPerNodeBreached;

/**
* Allocation constraints specify conditions which, if breached, reduce the
* priority of a node for receiving shard allocations.
* priority of a node for receiving unassigned shard allocations.
*
* @opensearch.internal
*/
public class AllocationConstraints {
public final long CONSTRAINT_WEIGHT = 1000000L;
private List<Predicate<ConstraintParams>> constraintPredicates;

/**
*
* This constraint is only applied for unassigned shards to avoid overloading a newly added node.
* Weight calculation in other scenarios like shard movement and re-balancing remain unaffected by this constraint.
*/
public final static String INDEX_SHARD_PER_NODE_BREACH_CONSTRAINT_ID = "index.shard.breach.constraint";
private Map<String, Constraint> constraints;

public AllocationConstraints() {
dreamer-89 marked this conversation as resolved.
Show resolved Hide resolved
this.constraintPredicates = new ArrayList<>(1);
this.constraintPredicates.add(isIndexShardsPerNodeBreached());
this.constraints = new HashMap<>();
this.constraints.putIfAbsent(
INDEX_SHARD_PER_NODE_BREACH_CONSTRAINT_ID,
new Constraint(INDEX_SHARD_PER_NODE_BREACH_CONSTRAINT_ID, isIndexShardsPerNodeBreached())
);
this.constraints.putIfAbsent(
PREFER_PRIMARY_SHARD_BALANCE_NODE_BREACH_ID,
new Constraint(PREFER_PRIMARY_SHARD_BALANCE_NODE_BREACH_ID, isPrimaryShardsPerIndexPerNodeBreached())
);
}

class ConstraintParams {
private ShardsBalancer balancer;
private BalancedShardsAllocator.ModelNode node;
private String index;

ConstraintParams(ShardsBalancer balancer, BalancedShardsAllocator.ModelNode node, String index) {
this.balancer = balancer;
this.node = node;
this.index = index;
}
public void updateAllocationConstraint(String constraint, boolean enable) {
this.constraints.get(constraint).setEnable(enable);
}

/**
* Evaluates configured allocation constraint predicates for given node - index
* combination; and returns a weight value based on the number of breached
* constraints.
*
* Constraint weight should be added to the weight calculated via weight
* function, to reduce priority of allocating on nodes with breached
* constraints.
*
* This weight function is used only in case of unassigned shards to avoid overloading a newly added node.
* Weight calculation in other scenarios like shard movement and re-balancing remain unaffected by this function.
*/
public long weight(ShardsBalancer balancer, BalancedShardsAllocator.ModelNode node, String index) {
int constraintsBreached = 0;
ConstraintParams params = new ConstraintParams(balancer, node, index);
for (Predicate<ConstraintParams> predicate : constraintPredicates) {
if (predicate.test(params)) {
constraintsBreached++;
}
}
return constraintsBreached * CONSTRAINT_WEIGHT;
Constraint.ConstraintParams params = new Constraint.ConstraintParams(balancer, node, index);
return params.weight(constraints);
}

/**
Expand All @@ -76,12 +66,11 @@ public long weight(ShardsBalancer balancer, BalancedShardsAllocator.ModelNode no
* This constraint is breached when balancer attempts to allocate more than
* average shards per index per node.
*/
private Predicate<ConstraintParams> isIndexShardsPerNodeBreached() {
public static Predicate<Constraint.ConstraintParams> isIndexShardsPerNodeBreached() {
return (params) -> {
int currIndexShardsOnNode = params.node.numShards(params.index);
int allowedIndexShardsPerNode = (int) Math.ceil(params.balancer.avgShardsPerNode(params.index));
int currIndexShardsOnNode = params.getNode().numShards(params.getIndex());
int allowedIndexShardsPerNode = (int) Math.ceil(params.getBalancer().avgShardsPerNode(params.getIndex()));
return (currIndexShardsOnNode >= allowedIndexShardsPerNode);
};
}

}
Loading