Skip to content

Commit

Permalink
Fixing the CPU Intensive RemoveAll with Lists in Sticky & Load Based …
Browse files Browse the repository at this point in the history
…Partition Assignment Strategies (#965)

* Fixing the CPU Intensive RemoveAll with Lists in Sticky & Load Based Partition Assignment Strategies

* Use new HashSet to avoid modifications on unmodifiable collection

---------

Co-authored-by: Shrinand Thakkar <sthakkar@sthakkar-mn2.linkedin.biz>
  • Loading branch information
shrinandthakkar and Shrinand Thakkar authored Oct 27, 2023
1 parent 178956c commit 0faec8e
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public LoadBasedPartitionAssigner(int defaultPartitionBytesInKBRate, int default
*/
public Map<String, Set<DatastreamTask>> assignPartitions(
ClusterThroughputInfo throughputInfo, Map<String, Set<DatastreamTask>> currentAssignment,
List<String> unassignedPartitions, DatastreamGroupPartitionsMetadata partitionMetadata, int maxPartitionsPerTask) {
Set<String> unassignedPartitions, DatastreamGroupPartitionsMetadata partitionMetadata, int maxPartitionsPerTask) {
String datastreamGroupName = partitionMetadata.getDatastreamGroup().getName();
LOG.info("START: assignPartitions for datasteam={}", datastreamGroupName);
Map<String, PartitionThroughputInfo> partitionInfoMap = new HashMap<>(throughputInfo.getPartitionInfoMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -107,17 +108,17 @@ public Map<String, Set<DatastreamTask>> assignPartitions(Map<String, Set<Datastr
}

String datastreamGroupName = datastreamGroup.getName();
Pair<List<String>, Integer> assignedPartitionsAndTaskCount = getAssignedPartitionsAndTaskCountForDatastreamGroup(
Pair<Set<String>, Integer> assignedPartitionsAndTaskCount = getAssignedPartitionsAndTaskCountForDatastreamGroup(
currentAssignment, datastreamGroupName);
List<String> assignedPartitions = assignedPartitionsAndTaskCount.getKey();
Set<String> assignedPartitions = assignedPartitionsAndTaskCount.getKey();
int taskCount = assignedPartitionsAndTaskCount.getValue();
LOG.info("Old partition assignment info, assignment: {}", currentAssignment);
Validate.isTrue(taskCount > 0, String.format("No tasks found for datastream group %s", datastreamGroup));
Validate.isTrue(currentAssignment.size() > 0,
"Zero tasks assigned. Retry leader partition assignment");

// Calculating unassigned partitions
List<String> unassignedPartitions = new ArrayList<>(datastreamPartitions.getPartitions());
Set<String> unassignedPartitions = new HashSet<>(datastreamPartitions.getPartitions());
unassignedPartitions.removeAll(assignedPartitions);

ClusterThroughputInfo clusterThroughputInfo = new ClusterThroughputInfo(StringUtils.EMPTY, Collections.emptyMap());
Expand Down Expand Up @@ -192,7 +193,7 @@ public Map<String, Set<DatastreamTask>> assignPartitions(Map<String, Set<Datastr

@VisibleForTesting
Map<String, Set<DatastreamTask>> doAssignment(ClusterThroughputInfo clusterThroughputInfo,
Map<String, Set<DatastreamTask>> currentAssignment, List<String> unassignedPartitions,
Map<String, Set<DatastreamTask>> currentAssignment, Set<String> unassignedPartitions,
DatastreamGroupPartitionsMetadata datastreamPartitions) {
Map<String, Set<DatastreamTask>> assignment = _assigner.assignPartitions(
clusterThroughputInfo, currentAssignment, unassignedPartitions, datastreamPartitions, _maxPartitionPerTask);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package com.linkedin.datastream.server.assignment;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand Down Expand Up @@ -51,13 +50,13 @@ public LoadBasedTaskCountEstimator(int taskCapacityMBps, int taskCapacityUtiliza
* Gets the estimated number of tasks based on per-partition throughput information.
* NOTE: This does not take into account numPartitionsPerTask configuration
* @param throughputInfo Per-partition throughput information
* @param assignedPartitions The list of assigned partitions
* @param unassignedPartitions The list of unassigned partitions
* @param assignedPartitions The set of assigned partitions
* @param unassignedPartitions The set of unassigned partitions
* @param datastreamName Name of the datastream
* @return The estimated number of tasks
*/
public int getTaskCount(ClusterThroughputInfo throughputInfo, List<String> assignedPartitions,
List<String> unassignedPartitions, String datastreamName) {
public int getTaskCount(ClusterThroughputInfo throughputInfo, Set<String> assignedPartitions,
Set<String> unassignedPartitions, String datastreamName) {
Validate.notNull(throughputInfo, "null throughputInfo");
Validate.notNull(assignedPartitions, "null assignedPartitions");
Validate.notNull(unassignedPartitions, "null unassignedPartitions");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ public Integer getPartitionsPerTask(DatastreamGroup datastreamGroup) {
return resolveConfigWithMetadata(datastreamGroup, CFG_PARTITIONS_PER_TASK, _partitionsPerTask);
}

protected Pair<List<String>, Integer> getAssignedPartitionsAndTaskCountForDatastreamGroup(
protected Pair<Set<String>, Integer> getAssignedPartitionsAndTaskCountForDatastreamGroup(
Map<String, Set<DatastreamTask>> currentAssignment, String datastreamGroupName) {
List<String> assignedPartitions = new ArrayList<>();
Set<String> assignedPartitions = new HashSet<>();
int taskCount = 0;
for (Set<DatastreamTask> tasks : currentAssignment.values()) {
Set<DatastreamTask> dgTask = tasks.stream().filter(t -> datastreamGroupName.equals(t.getTaskPrefix()))
Expand Down Expand Up @@ -218,9 +218,9 @@ public Map<String, Set<DatastreamTask>> assignPartitions(Map<String,
String dgName = datastreamGroup.getName();

// Step 1: collect the # of tasks and figured out the unassigned partitions
Pair<List<String>, Integer> assignedPartitionsAndTaskCount =
Pair<Set<String>, Integer> assignedPartitionsAndTaskCount =
getAssignedPartitionsAndTaskCountForDatastreamGroup(currentAssignment, dgName);
List<String> assignedPartitions = assignedPartitionsAndTaskCount.getKey();
Set<String> assignedPartitions = assignedPartitionsAndTaskCount.getKey();
int totalTaskCount = assignedPartitionsAndTaskCount.getValue();
Validate.isTrue(totalTaskCount > 0, String.format("No tasks found for datastream group %s", dgName));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
*/
package com.linkedin.datastream.server;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.HashSet;
import java.util.Set;

import org.testng.Assert;
import org.testng.annotations.BeforeClass;
Expand Down Expand Up @@ -42,8 +42,8 @@ public void setup() {
@Test
public void emptyAssignmentReturnsZeroTasksTest() {
ClusterThroughputInfo throughputInfo = _provider.getThroughputInfo("pizza");
List<String> assignedPartitions = Collections.emptyList();
List<String> unassignedPartitions = Collections.emptyList();
Set<String> assignedPartitions = Collections.emptySet();
Set<String> unassignedPartitions = Collections.emptySet();
LoadBasedTaskCountEstimator estimator = new LoadBasedTaskCountEstimator(TASK_CAPACITY_MBPS,
TASK_CAPACITY_UTILIZATION_PCT, DEFAULT_BYTES_IN_KB_RATE, DEFAULT_MSGS_IN_RATE);
int taskCount = estimator.getTaskCount(throughputInfo, assignedPartitions, unassignedPartitions, "test");
Expand All @@ -53,9 +53,9 @@ public void emptyAssignmentReturnsZeroTasksTest() {
@Test
public void lowThroughputAssignmentReturnsOneTaskTest() {
ClusterThroughputInfo throughputInfo = _provider.getThroughputInfo("pizza");
List<String> assignedPartitions = new ArrayList<>();
Set<String> assignedPartitions = new HashSet<>();
assignedPartitions.add("Pepperoni-1");
List<String> unassignedPartitions = Collections.emptyList();
Set<String> unassignedPartitions = Collections.emptySet();
LoadBasedTaskCountEstimator estimator = new LoadBasedTaskCountEstimator(TASK_CAPACITY_MBPS,
TASK_CAPACITY_UTILIZATION_PCT, DEFAULT_BYTES_IN_KB_RATE, DEFAULT_MSGS_IN_RATE);
int taskCount = estimator.getTaskCount(throughputInfo, assignedPartitions, unassignedPartitions, "test");
Expand All @@ -65,8 +65,8 @@ public void lowThroughputAssignmentReturnsOneTaskTest() {
@Test
public void highThroughputAssignmentTest() {
ClusterThroughputInfo throughputInfo = _provider.getThroughputInfo("ice-cream");
List<String> assignedPartitions = Collections.emptyList();
List<String> unassignedPartitions = new ArrayList<>(throughputInfo.getPartitionInfoMap().keySet());
Set<String> assignedPartitions = Collections.emptySet();
Set<String> unassignedPartitions = throughputInfo.getPartitionInfoMap().keySet();
LoadBasedTaskCountEstimator estimator = new LoadBasedTaskCountEstimator(TASK_CAPACITY_MBPS,
TASK_CAPACITY_UTILIZATION_PCT, DEFAULT_BYTES_IN_KB_RATE, DEFAULT_MSGS_IN_RATE);
int taskCount = estimator.getTaskCount(throughputInfo, assignedPartitions, unassignedPartitions, "test");
Expand All @@ -81,8 +81,8 @@ public void highThroughputAssignmentTest() {
@Test
public void highThroughputAssignmentTest2() {
ClusterThroughputInfo throughputInfo = _provider.getThroughputInfo("donut");
List<String> assignedPartitions = Collections.emptyList();
List<String> unassignedPartitions = new ArrayList<>(throughputInfo.getPartitionInfoMap().keySet());
Set<String> assignedPartitions = Collections.emptySet();
Set<String> unassignedPartitions = throughputInfo.getPartitionInfoMap().keySet();
LoadBasedTaskCountEstimator estimator = new LoadBasedTaskCountEstimator(TASK_CAPACITY_MBPS,
TASK_CAPACITY_UTILIZATION_PCT, DEFAULT_BYTES_IN_KB_RATE, DEFAULT_MSGS_IN_RATE);
int taskCount = estimator.getTaskCount(throughputInfo, assignedPartitions, unassignedPartitions, "test");
Expand All @@ -92,8 +92,8 @@ public void highThroughputAssignmentTest2() {
@Test
public void partitionsHaveDefaultWeightTest() {
ClusterThroughputInfo throughputInfo = new ClusterThroughputInfo("dummy", new HashMap<>());
List<String> assignedPartitions = Collections.emptyList();
List<String> unassignedPartitions = Arrays.asList("P1", "P2");
Set<String> assignedPartitions = Collections.emptySet();
Set<String> unassignedPartitions = new HashSet<>(Arrays.asList("P1", "P2"));
LoadBasedTaskCountEstimator estimator = new LoadBasedTaskCountEstimator(TASK_CAPACITY_MBPS,
TASK_CAPACITY_UTILIZATION_PCT, DEFAULT_BYTES_IN_KB_RATE, DEFAULT_MSGS_IN_RATE);
int taskCount = estimator.getTaskCount(throughputInfo, assignedPartitions, unassignedPartitions, "test");
Expand All @@ -103,8 +103,8 @@ public void partitionsHaveDefaultWeightTest() {
@Test
public void throughputTaskEstimatorWithTopicLevelInformation() {
ClusterThroughputInfo throughputInfo = _provider.getThroughputInfo("fruit");
List<String> assignedPartitions = Collections.emptyList();
List<String> unassignedPartitions = Arrays.asList("apple-0", "apple-1", "apple-2", "banana-0");
Set<String> assignedPartitions = Collections.emptySet();
Set<String> unassignedPartitions = new HashSet<>(Arrays.asList("apple-0", "apple-1", "apple-2", "banana-0"));
LoadBasedTaskCountEstimator estimator = new LoadBasedTaskCountEstimator(TASK_CAPACITY_MBPS,
TASK_CAPACITY_UTILIZATION_PCT, DEFAULT_BYTES_IN_KB_RATE, DEFAULT_MSGS_IN_RATE);
int taskCount = estimator.getTaskCount(throughputInfo, assignedPartitions, unassignedPartitions, "test");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package com.linkedin.datastream.server.assignment;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -56,7 +55,7 @@ public void setup() {

@Test
public void assignFromScratchTest() {
List<String> unassignedPartitions = Arrays.asList("P1", "P2", "P3");
Set<String> unassignedPartitions = new HashSet<>(Arrays.asList("P1", "P2", "P3"));
ClusterThroughputInfo throughputInfo = getDummyClusterThroughputInfo(unassignedPartitions);

Datastream ds1 = DatastreamTestUtils.createDatastreams(DummyConnector.CONNECTOR_TYPE, "ds1")[0];
Expand Down Expand Up @@ -107,9 +106,9 @@ public void assignFromScratchTest() {

@Test
public void newAssignmentRetainsTasksFromOtherDatastreamsTest() {
List<String> assignedPartitions = Arrays.asList("P1", "P2");
List<String> unassignedPartitions = Collections.singletonList("P3");
List<String> allPartitions = new ArrayList<>(assignedPartitions);
Set<String> assignedPartitions = new HashSet<>(Arrays.asList("P1", "P2"));
Set<String> unassignedPartitions = Collections.singleton("P3");
Set<String> allPartitions = new HashSet<>(assignedPartitions);
allPartitions.addAll(unassignedPartitions);
ClusterThroughputInfo throughputInfo = getDummyClusterThroughputInfo(allPartitions);

Expand Down Expand Up @@ -169,7 +168,7 @@ public void newAssignmentRetainsTasksFromOtherDatastreamsTest() {
@Test
public void assignmentDistributesPartitionsWhenThroughputInfoIsMissingTest() {
// this tests the round-robin assignment of partitions that don't have throughput info
List<String> unassignedPartitions = Arrays.asList("P1", "P2", "P3", "P4");
Set<String> unassignedPartitions = new HashSet<>(Arrays.asList("P1", "P2", "P3", "P4"));
ClusterThroughputInfo throughputInfo = new ClusterThroughputInfo("dummy", new HashMap<>());

Datastream ds1 = DatastreamTestUtils.createDatastreams(DummyConnector.CONNECTOR_TYPE, "ds1")[0];
Expand Down Expand Up @@ -208,7 +207,7 @@ public void assignmentDistributesPartitionsWhenThroughputInfoIsMissingTest() {

@Test
public void lightestTaskGetsNewPartitionTest() {
List<String> unassignedPartitions = Collections.singletonList("P4");
Set<String> unassignedPartitions = Collections.singleton("P4");
Map<String, PartitionThroughputInfo> throughputInfoMap = new HashMap<>();
throughputInfoMap.put("P1", new PartitionThroughputInfo(5, 5, "P1"));
throughputInfoMap.put("P2", new PartitionThroughputInfo(5, 5, "P2"));
Expand Down Expand Up @@ -246,7 +245,7 @@ public void lightestTaskGetsNewPartitionTest() {

@Test
public void lightestTaskGetsNewPartitionWithTopicMetricsTest() {
List<String> unassignedPartitions = Arrays.asList("P-2", "P-3");
Set<String> unassignedPartitions = new HashSet<>(Arrays.asList("P-2", "P-3"));
Map<String, PartitionThroughputInfo> throughputInfoMap = new HashMap<>();
throughputInfoMap.put("P-1", new PartitionThroughputInfo(5, 5, "P-1"));
throughputInfoMap.put("R", new PartitionThroughputInfo(5, 5, "R"));
Expand Down Expand Up @@ -288,7 +287,7 @@ public void lightestTaskGetsNewPartitionWithTopicMetricsTest() {

@Test
public void throwsExceptionWhenNotEnoughRoomForAllPartitionsTest() {
List<String> unassignedPartitions = Arrays.asList("P4", "P5");
Set<String> unassignedPartitions = new HashSet<>(Arrays.asList("P4", "P5"));
Map<String, PartitionThroughputInfo> throughputInfoMap = new HashMap<>();
ClusterThroughputInfo throughputInfo = new ClusterThroughputInfo("dummy", throughputInfoMap);

Expand All @@ -312,7 +311,7 @@ public void throwsExceptionWhenNotEnoughRoomForAllPartitionsTest() {

@Test
public void taskWithRoomGetsNewPartitionTest() {
List<String> unassignedPartitions = Collections.singletonList("P4");
Set<String> unassignedPartitions = Collections.singleton("P4");
Map<String, PartitionThroughputInfo> throughputInfoMap = new HashMap<>();
throughputInfoMap.put("P1", new PartitionThroughputInfo(5, 5, "P1"));
throughputInfoMap.put("P2", new PartitionThroughputInfo(5, 5, "P2"));
Expand Down Expand Up @@ -393,7 +392,7 @@ private DatastreamTask createTaskForDatastream(Datastream datastream, List<Strin
return task;
}

private ClusterThroughputInfo getDummyClusterThroughputInfo(List<String> partitions) {
private ClusterThroughputInfo getDummyClusterThroughputInfo(Set<String> partitions) {
Map<String, PartitionThroughputInfo> partitionThroughputMap = new HashMap<>();
for (String partitionName : partitions) {
int bytesInRate = 5;
Expand Down

0 comments on commit 0faec8e

Please sign in to comment.