Skip to content

Commit

Permalink
Add per task thread config
Browse files Browse the repository at this point in the history
  • Loading branch information
jxu-nflx committed Jan 19, 2022
1 parent 50bb10f commit f384e1b
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public class ClientProperties {

private Map<String, String> taskToDomain = new HashMap<>();

private Map<String, Integer> taskThreadCount = new HashMap<>();

private int shutdownGracePeriodSeconds = 10;

public String getRootUri() {
Expand Down Expand Up @@ -90,4 +92,12 @@ public int getShutdownGracePeriodSeconds() {
/**
 * Sets the shutdown grace period held by these client properties.
 *
 * @param shutdownGracePeriodSeconds grace period in seconds; the field defaults to 10 when this
 *     setter is never called
 */
public void setShutdownGracePeriodSeconds(int shutdownGracePeriodSeconds) {
this.shutdownGracePeriodSeconds = shutdownGracePeriodSeconds;
}

/**
 * Returns the configured per-task thread counts.
 *
 * @return the map held by this object (the same instance, not a copy); it is an empty
 *     {@code HashMap} unless replaced through the setter
 */
public Map<String, Integer> getTaskThreadCount() {
return taskThreadCount;
}

/**
 * Replaces the per-task thread count map.
 *
 * @param taskThreadCount map to store; the reference is kept as-is, with no copy and no null
 *     check
 */
public void setTaskThreadCount(Map<String, Integer> taskThreadCount) {
this.taskThreadCount = taskThreadCount;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public TaskClient taskClient(ClientProperties clientProperties) {
public TaskRunnerConfigurer taskRunnerConfigurer(
TaskClient taskClient, ClientProperties clientProperties) {
return new TaskRunnerConfigurer.Builder(taskClient, workers)
.withTaskThreadCount(clientProperties.getTaskThreadCount())
.withThreadCount(clientProperties.getThreadCount())
.withSleepWhenRetry((int) clientProperties.getSleepWhenRetryDuration().toMillis())
.withUpdateRetryCount(clientProperties.getUpdateRetryCount())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -51,7 +52,7 @@ class TaskPollExecutor {
private final TaskClient taskClient;
private final int updateRetryCount;
private final ExecutorService executorService;
private final PollingSemaphore pollingSemaphore;
private final Map<String, PollingSemaphore> pollingSemaphoreMap;
private final Map<String /*taskType*/, String /*domain*/> taskToDomain;

private static final String DOMAIN = "domain";
Expand All @@ -64,23 +65,36 @@ class TaskPollExecutor {
int threadCount,
int updateRetryCount,
Map<String, String> taskToDomain,
String workerNamePrefix) {
String workerNamePrefix,
Map<String, Integer> taskThreadCount) {
this.eurekaClient = eurekaClient;
this.taskClient = taskClient;
this.updateRetryCount = updateRetryCount;
this.taskToDomain = taskToDomain;

LOGGER.info("Initialized the TaskPollExecutor with {} threads", threadCount);
this.pollingSemaphoreMap = new HashMap<>();
int totalThreadCount = 0;
if (!taskThreadCount.isEmpty()) {
for (Map.Entry<String, Integer> entry : taskThreadCount.entrySet()) {
String taskType = entry.getKey();
int count = entry.getValue();
totalThreadCount += count;
pollingSemaphoreMap.put(taskType, new PollingSemaphore(count));
}
} else {
totalThreadCount = threadCount;
// shared poll for all workers
pollingSemaphoreMap.put(ALL_WORKERS, new PollingSemaphore(threadCount));
}

LOGGER.info("Initialized the TaskPollExecutor with {} threads", totalThreadCount);
this.executorService =
Executors.newFixedThreadPool(
threadCount,
totalThreadCount,
new BasicThreadFactory.Builder()
.namingPattern(workerNamePrefix)
.uncaughtExceptionHandler(uncaughtExceptionHandler)
.build());

this.pollingSemaphore = new PollingSemaphore(threadCount);
}

void pollAndExecute(Worker worker) {
Expand All @@ -106,13 +120,15 @@ void pollAndExecute(Worker worker) {
return;
}

String taskType = worker.getTaskDefName();
PollingSemaphore pollingSemaphore = getPollingSemaphore(taskType);

Task task;
try {
if (!pollingSemaphore.canPoll()) {
return;
}

String taskType = worker.getTaskDefName();
String domain =
Optional.ofNullable(PropertyFactory.getString(taskType, DOMAIN, null))
.orElseGet(
Expand Down Expand Up @@ -141,7 +157,7 @@ void pollAndExecute(Worker worker) {

CompletableFuture<Task> taskCompletableFuture =
CompletableFuture.supplyAsync(
() -> processTask(task, worker), executorService);
() -> processTask(task, worker, pollingSemaphore), executorService);

taskCompletableFuture.whenComplete(this::finalizeTask);
} else {
Expand Down Expand Up @@ -181,7 +197,7 @@ void shutdownExecutorService(ExecutorService executorService, int timeout) {
LOGGER.error("Uncaught exception. Thread {} will exit now", thread, error);
};

private Task processTask(Task task, Worker worker) {
private Task processTask(Task task, Worker worker, PollingSemaphore pollingSemaphore) {
LOGGER.debug(
"Executing task: {} of type: {} in worker: {} at {}",
task.getTaskId(),
Expand Down Expand Up @@ -317,4 +333,12 @@ private void handleException(Throwable t, TaskResult result, Worker worker, Task

updateWithRetry(updateRetryCount, task, result, worker);
}

/**
 * Looks up the semaphore that limits polling for the given task type.
 *
 * <p>A semaphore registered under the task type itself wins; otherwise the entry registered
 * under {@code ALL_WORKERS} is used.
 *
 * @param taskType task definition name of the worker that wants to poll
 * @return the matching semaphore, or {@code null} when the map holds neither the task type nor
 *     an {@code ALL_WORKERS} entry
 */
private PollingSemaphore getPollingSemaphore(String taskType) {
    // Choose the key first, then do a single lookup.
    final String semaphoreKey = pollingSemaphoreMap.containsKey(taskType) ? taskType : ALL_WORKERS;
    return pollingSemaphoreMap.get(semaphoreKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.netflix.conductor.client.exception.ConductorClientException;
import com.netflix.conductor.client.http.TaskClient;
import com.netflix.conductor.client.worker.Worker;
import com.netflix.discovery.EurekaClient;
Expand All @@ -28,6 +32,11 @@

/** Configures automated polling of tasks and execution via the registered {@link Worker}s. */
public class TaskRunnerConfigurer {
private static final Logger LOGGER = LoggerFactory.getLogger(TaskRunnerConfigurer.class);
private static final String INVALID_THREAD_COUNT =
"Invalid worker thread count specified, use either shared thread pool or config thread count per task";
private static final String MISSING_TASK_THREAD_COUNT =
"Missing task thread count config for %s";

private ScheduledExecutorService scheduledExecutorService;

Expand All @@ -40,6 +49,7 @@ public class TaskRunnerConfigurer {
private final int shutdownGracePeriodSeconds;
private final String workerNamePrefix;
private final Map<String /*taskType*/, String /*domain*/> taskToDomain;
private final Map<String /*taskType*/, Integer /*threadCount*/> taskThreadCount;

private TaskPollExecutor taskPollExecutor;

Expand All @@ -48,14 +58,34 @@ public class TaskRunnerConfigurer {
* @see TaskRunnerConfigurer#init()
*/
private TaskRunnerConfigurer(Builder builder) {
    // Exactly one threading mode is allowed: a shared pool sized by builder.threadCount
    // (-1 means "not set"), or one thread count per task type.
    if (builder.threadCount != -1 && !builder.taskThreadCount.isEmpty()) {
        LOGGER.error(INVALID_THREAD_COUNT);
        throw new ConductorClientException(INVALID_THREAD_COUNT);
    } else if (!builder.taskThreadCount.isEmpty()) {
        // Per-task mode: every registered worker must have an explicit thread count.
        for (Worker worker : builder.workers) {
            if (!builder.taskThreadCount.containsKey(worker.getTaskDefName())) {
                String message =
                        String.format(MISSING_TASK_THREAD_COUNT, worker.getTaskDefName());
                LOGGER.error(message);
                throw new ConductorClientException(message);
            }
            workers.add(worker);
        }
        this.taskThreadCount = builder.taskThreadCount;
        // Sentinel: the shared pool size is unused in per-task mode.
        this.threadCount = -1;
    } else {
        // Shared mode: default the pool size to one thread per worker when unset.
        builder.workers.forEach(workers::add);
        this.taskThreadCount = builder.taskThreadCount;
        this.threadCount = (builder.threadCount == -1) ? workers.size() : builder.threadCount;
    }

    // Workers and threadCount are populated exactly once, in the branches above. Repeating
    // that here would register every worker twice and reassign the final threadCount field.
    this.eurekaClient = builder.eurekaClient;
    this.taskClient = builder.taskClient;
    this.sleepWhenRetry = builder.sleepWhenRetry;
    this.updateRetryCount = builder.updateRetryCount;
    this.workerNamePrefix = builder.workerNamePrefix;
    this.taskToDomain = builder.taskToDomain;
    this.shutdownGracePeriodSeconds = builder.shutdownGracePeriodSeconds;
}

Expand All @@ -71,6 +101,7 @@ public static class Builder {
private EurekaClient eurekaClient;
private final TaskClient taskClient;
private Map<String /*taskType*/, String /*domain*/> taskToDomain = new HashMap<>();
private Map<String /*taskType*/, Integer /*threadCount*/> taskThreadCount = new HashMap<>();

public Builder(TaskClient taskClient, Iterable<Worker> workers) {
Preconditions.checkNotNull(taskClient, "TaskClient cannot be null");
Expand Down Expand Up @@ -151,6 +182,11 @@ public Builder withTaskToDomain(Map<String, String> taskToDomain) {
return this;
}

/**
 * Configures a dedicated thread count per task type, as an alternative to the shared thread
 * pool.
 *
 * @param taskThreadCount map of task type to thread count; must not be null (an empty map
 *     means no per-task configuration)
 * @return this builder, for chaining
 * @throws NullPointerException if {@code taskThreadCount} is null
 */
public Builder withTaskThreadCount(Map<String, Integer> taskThreadCount) {
    // Fail fast with a clear message; a null map would otherwise surface later as a
    // message-less NullPointerException when the configurer is constructed.
    Preconditions.checkNotNull(taskThreadCount, "taskThreadCount cannot be null");
    this.taskThreadCount = taskThreadCount;
    return this;
}

/**
* Builds an instance of the TaskRunnerConfigurer.
*
Expand All @@ -162,11 +198,16 @@ public TaskRunnerConfigurer build() {
}
}

/**
 * @return Thread Count for the shared executor pool; {@code -1} when per-task thread counts are
 *     configured instead
 */
public int getThreadCount() {
return threadCount;
}

/**
 * @return Thread Count for individual task type, keyed by task type; this is the map taken
 *     from the builder (not a copy) and is empty when the shared executor pool is used
 */
public Map<String, Integer> getTaskThreadCount() {
return taskThreadCount;
}

/** @return seconds before forcing shutdown of worker */
public int getShutdownGracePeriodSeconds() {
return shutdownGracePeriodSeconds;
Expand Down Expand Up @@ -204,7 +245,8 @@ public synchronized void init() {
threadCount,
updateRetryCount,
taskToDomain,
workerNamePrefix);
workerNamePrefix,
taskThreadCount);

this.scheduledExecutorService = Executors.newScheduledThreadPool(workers.size());
workers.forEach(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ public void testTaskExecutionException() {
});
TaskClient taskClient = Mockito.mock(TaskClient.class);
TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(null, taskClient, 1, 1, new HashMap<>(), "test-worker-%d");
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-%d", new HashMap<>());

when(taskClient.pollTask(any(), any(), any())).thenReturn(testTask());
when(taskClient.ack(any(), any())).thenReturn(true);
Expand Down Expand Up @@ -112,7 +113,8 @@ public TaskResult answer(InvocationOnMock invocation) {

TaskClient taskClient = Mockito.mock(TaskClient.class);
TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(null, taskClient, 1, 1, new HashMap<>(), "test-worker-");
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
when(taskClient.pollTask(any(), any(), any())).thenReturn(task);
when(taskClient.ack(any(), any())).thenReturn(true);
CountDownLatch latch = new CountDownLatch(3);
Expand Down Expand Up @@ -168,7 +170,8 @@ public void testLargePayloadCanFailUpdateWithRetry() {
.evaluateAndUploadLargePayload(any(TaskResult.class), any());

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(null, taskClient, 1, 3, new HashMap<>(), "test-worker-");
new TaskPollExecutor(
null, taskClient, 1, 3, new HashMap<>(), "test-worker-", new HashMap<>());
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -202,7 +205,8 @@ public void testTaskPollException() {
.thenReturn(task);

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(null, taskClient, 1, 1, new HashMap<>(), "test-worker-");
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -237,7 +241,8 @@ public void testTaskPoll() {
when(taskClient.pollTask(any(), any(), any())).thenReturn(new Task()).thenReturn(task);

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(null, taskClient, 1, 1, new HashMap<>(), "test-worker-");
new TaskPollExecutor(
null, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -266,7 +271,8 @@ public void testTaskPollDomain() {
Map<String, String> taskToDomain = new HashMap<>();
taskToDomain.put(TEST_TASK_DEF_NAME, testDomain);
TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(null, taskClient, 1, 1, taskToDomain, "test-worker-");
new TaskPollExecutor(
null, taskClient, 1, 1, taskToDomain, "test-worker-", new HashMap<>());

String workerName = "test-worker";
Worker worker = mock(Worker.class);
Expand Down Expand Up @@ -306,7 +312,8 @@ public void testPollOutOfDiscoveryForTask() {
when(taskClient.pollTask(any(), any(), any())).thenReturn(new Task()).thenReturn(task);

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(client, taskClient, 1, 1, new HashMap<>(), "test-worker-");
new TaskPollExecutor(
client, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -345,7 +352,8 @@ public void testPollOutOfDiscoveryAsDefaultFalseForTask()
when(taskClient.pollTask(any(), any(), any())).thenReturn(new Task()).thenReturn(task);

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(client, taskClient, 1, 1, new HashMap<>(), "test-worker-");
new TaskPollExecutor(
client, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -385,7 +393,8 @@ public void testPollOutOfDiscoveryAsExplicitFalseForTask()
when(taskClient.pollTask(any(), any(), any())).thenReturn(new Task()).thenReturn(task);

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(client, taskClient, 1, 1, new HashMap<>(), "test-worker-");
new TaskPollExecutor(
client, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand Down Expand Up @@ -424,7 +433,8 @@ public void testPollOutOfDiscoveryIsIgnoredWhenDiscoveryIsUp() {
when(taskClient.pollTask(any(), any(), any())).thenReturn(new Task()).thenReturn(task);

TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(client, taskClient, 1, 1, new HashMap<>(), "test-worker-");
new TaskPollExecutor(
client, taskClient, 1, 1, new HashMap<>(), "test-worker-", new HashMap<>());
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
Expand All @@ -446,6 +456,39 @@ public void testPollOutOfDiscoveryIsIgnoredWhenDiscoveryIsUp() {
verify(taskClient).updateTask(any());
}

/**
 * Checks that a TaskPollExecutor built with only a per-task thread count map polls the task
 * client for that task type.
 */
@Test
public void testTaskThreadCount() {
TaskClient taskClient = Mockito.mock(TaskClient.class);

// One polling thread dedicated to the test task type.
Map<String, Integer> taskThreadCount = new HashMap<>();
taskThreadCount.put(TEST_TASK_DEF_NAME, 1);

// Shared thread count is passed as -1; only the per-task map carries a thread count.
TaskPollExecutor taskPollExecutor =
new TaskPollExecutor(
null, taskClient, -1, 1, new HashMap<>(), "test-worker-", taskThreadCount);

String workerName = "test-worker";
Worker worker = mock(Worker.class);
when(worker.getTaskDefName()).thenReturn(TEST_TASK_DEF_NAME);
when(worker.getIdentity()).thenReturn(workerName);

// Release the latch as soon as the executor polls for the test task type; the stubbed poll
// returns null.
CountDownLatch latch = new CountDownLatch(1);
doAnswer(
invocation -> {
latch.countDown();
return null;
})
.when(taskClient)
.pollTask(TEST_TASK_DEF_NAME, workerName, null);

// Poll immediately and then once per second on a background thread.
// NOTE(review): this scheduler is never shut down within the test — confirm that is intended.
Executors.newSingleThreadScheduledExecutor()
.scheduleAtFixedRate(
() -> taskPollExecutor.pollAndExecute(worker), 0, 1, TimeUnit.SECONDS);

// Block until the first poll has happened, then verify it was issued with the expected args.
Uninterruptibles.awaitUninterruptibly(latch);
verify(taskClient).pollTask(TEST_TASK_DEF_NAME, workerName, null);
}

private Task testTask() {
Task task = new Task();
task.setTaskId(UUID.randomUUID().toString());
Expand Down
Loading

0 comments on commit f384e1b

Please sign in to comment.