Separate WorkLoadManager
This separates the WorkLoadManager functionality from the main djl-serving
project into its own module.
First, the module enables isolated usage of the work load manager without serving.

It also converts the wlm to focus on CompletableFutures rather than Netty.
This makes it easier to leverage from pure Java; in particular, it
enables Translators to leverage the wlm.
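
To illustrate the intended standalone usage, here is a minimal sketch (not part of this commit) of how a custom Java service could submit work to the wlm and consume the resulting CompletableFuture. It assumes the wlm module exposes a directly constructible WorkLoadManager and that a ModelInfo for an already loaded model is available; how that ModelInfo is created is not shown in this diff, so the loadModelInfo helper below is hypothetical.

```java
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager;
import java.util.concurrent.CompletableFuture;

public class StandaloneWlmExample {

    public static void main(String[] args) {
        // Assumption: the wlm module exposes a directly constructible WorkLoadManager.
        WorkLoadManager wlm = new WorkLoadManager();

        // Assumption: a ModelInfo describing an already loaded model is obtained elsewhere;
        // its construction is not part of this diff.
        ModelInfo model = loadModelInfo();

        // Build a request payload and submit it as a Job. The wlm queues the job,
        // batches it, and dispatches it to a free worker thread.
        Input input = new Input();
        // Populate the input payload here; the available setters depend on the DJL version.

        CompletableFuture<Output> future = wlm.runJob(new Job(model, input));

        // The caller deals only with plain CompletableFutures instead of netty handlers.
        future.whenComplete((output, throwable) -> {
            if (throwable != null) {
                System.err.println("Inference failed: " + throwable);
            } else {
                System.out.println("Got response: " + output.getContent());
            }
        });
    }

    private static ModelInfo loadModelInfo() {
        // Hypothetical helper: stands in for however the surrounding service creates
        // or registers its ModelInfo.
        throw new UnsupportedOperationException("model loading not shown");
    }
}
```

Errors such as WlmShutdownException or WlmCapacityException (added in this commit) surface through the future's throwable, as the new ModelManager.runJob in the diff below demonstrates.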
zachgk committed Sep 29, 2021
1 parent e459aa7 commit 9465621
Showing 36 changed files with 944 additions and 348 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -24,6 +24,10 @@ You can install extra extensions to enable the following models:
DJL serving is built on top of [Deep Java Library](https://djl.ai). You can visit
[DJL github repository](https://github.com/deepjavalibrary/djl) to learn more about DJL.

It is also possible to leverage only the worker thread pool using the separate [WorkLoadManager](wlm) module.
The standalone WorkLoadManager can be used by customers who want to take advantage of DJL serving's model batching
and threading while integrating it into their own custom Java service.

![Architecture Diagram](https://resources.djl.ai/images/djl-serving/architecture.png)

## Key features
1 change: 1 addition & 0 deletions serving/build.gradle
@@ -4,6 +4,7 @@ plugins {
}

dependencies {
api project(":wlm")
api platform("ai.djl:bom:${project.version}")
api "ai.djl:api"
api "io.netty:netty-all:${netty_version}"
2 changes: 1 addition & 1 deletion serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -13,12 +13,12 @@
package ai.djl.serving;

import ai.djl.repository.FilenameUtils;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.plugins.FolderScanPluginManager;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.Connector;
import ai.djl.serving.util.ServerGroups;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import ai.djl.util.cuda.CudaUtils;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
@@ -15,11 +15,11 @@
import ai.djl.ModelException;
import ai.djl.modality.Input;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
@@ -179,25 +179,7 @@ private void predict(
ConfigManager.getInstance().getMaxBatchDelay(),
ConfigManager.getInstance().getMaxIdleTime())
.thenApply(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, -1)))
.thenAccept(
m -> {
try {
if (!modelManager.addJob(new Job(ctx, m, input))) {
throw new ServiceUnavailableException(
"No worker is available to serve request: "
+ modelName);
}
} catch (ModelNotFoundException e) {
logger.warn("Unexpected error", e);
NettyUtils.sendError(ctx, e);
}
})
.exceptionally(
t -> {
logger.warn("Unexpected error", t);
NettyUtils.sendError(ctx, t);
return null;
});
.thenApply(m -> modelManager.runJob(ctx, new Job(m, input)));
return;
}

@@ -206,11 +188,6 @@ private void predict(
return;
}

Job job = new Job(ctx, model, input);
if (!modelManager.addJob(job)) {
logger.error("unable to process prediction. no free worker available.");
throw new ServiceUnavailableException(
"No worker is available to serve request: " + modelName);
}
modelManager.runJob(ctx, new Job(model, input));
}
}
@@ -14,10 +14,10 @@

import ai.djl.ModelException;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.models.Endpoint;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Endpoint;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
@@ -10,8 +10,9 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.wlm;
package ai.djl.serving.models;

import ai.djl.serving.wlm.ModelInfo;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.wlm;
package ai.djl.serving.models;

import ai.djl.Device;
import ai.djl.ModelException;
@@ -22,6 +22,19 @@
import ai.djl.serving.http.BadRequestException;
import ai.djl.serving.http.DescribeModelResponse;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager;
import ai.djl.serving.wlm.WorkerThread;
import ai.djl.serving.wlm.util.WlmCapacityException;
import ai.djl.serving.wlm.util.WlmShutdownException;
import ai.djl.translate.TranslateException;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
@@ -236,14 +249,68 @@ public Set<String> getStartupModels() {
}

/**
* Adds an inference job to the job queue. Assign the job to the next free worker.
* Runs an inference job by assigning the job to the next free worker.
*
* @param ctx the netty channel handler context where the job response will be sent
* @param job an inference job to be executed
* @return {@code true} if submit success
* @throws ModelNotFoundException if the model is not registered
* @return a future that completes after the netty response has been sent
*/
public boolean addJob(Job job) throws ModelNotFoundException {
return wlm.addJob(job);
@SuppressWarnings("PMD.InvalidLogMessageFormat")
public CompletableFuture<Output> runJob(ChannelHandlerContext ctx, Job job) {
return wlm.runJob(job)
.whenComplete(
(output, throwable) -> {
if (throwable != null) {
HttpResponseStatus status;
if (throwable instanceof TranslateException) {
status = HttpResponseStatus.BAD_REQUEST;
} else if (throwable instanceof WlmShutdownException) {
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
logger.error("Unable to process prediction. Worker shutdown");
} else if (throwable instanceof WlmCapacityException) {
logger.error(
"Unable to process prediction. Worker capacity exceeded");
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
} else {
logger.warn("Unexpected error", throwable);
status = HttpResponseStatus.INTERNAL_SERVER_ERROR;
}

/*
* We can load the models based on the configuration file. Since this Job is
* not driven by external connections, we could have an empty context for
* this job. We shouldn't try to send a response to ctx if this was not
* triggered by external clients.
*/
if (ctx != null) {
NettyUtils.sendError(ctx, status, throwable);
}
} else { // Handle output
FullHttpResponse resp =
new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false);
for (Map.Entry<String, String> entry :
output.getProperties().entrySet()) {
resp.headers().set(entry.getKey(), entry.getValue());
}
resp.content().writeBytes(output.getContent());

/*
* We can load the models based on the configuration file. Since this Job is
* not driven by external connections, we could have an empty context for
* this job. We shouldn't try to send a response to ctx if this was not
* triggered by external clients.
*/
if (ctx != null) {
NettyUtils.sendHttpResponse(ctx, resp, true);
}
}

logger.debug(
"Waiting time: {}, Backend time: {}",
job.getScheduled() - job.getBegin(),
System.currentTimeMillis() - job.getScheduled());
});
}

/**
@@ -11,4 +11,4 @@
* and limitations under the License.
*/
/** Contains classes that manage model lifecycle. */
package ai.djl.serving.wlm;
package ai.djl.serving.models;
36 changes: 4 additions & 32 deletions serving/src/main/java/ai/djl/serving/util/ConfigManager.java
@@ -12,14 +12,12 @@
*/
package ai.djl.serving.util;

import ai.djl.Device;
import ai.djl.ndarray.NDManager;
import ai.djl.serving.Arguments;
import ai.djl.serving.wlm.util.WlmConfigManager;
import ai.djl.util.Utils;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.SelfSignedCertificate;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
@@ -127,6 +125,9 @@ public static void init(Arguments args) {
"log4j2.contextSelector",
"org.apache.logging.log4j.core.async.AsyncLoggerContextSelector");
}

// Sets corresponding config in the WlmConfigManager
WlmConfigManager.getInstance().setDebug(instance.isDebug());
}

/**
@@ -209,35 +210,6 @@ public int getMaxBatchDelay() {
return getIntProperty(MAX_BATCH_DELAY, 300);
}

/**
* Returns the default number of workers for a new registered model.
*
* @param manager the {@code NDManager} the model uses
* @param target the target number of worker
* @return the default number of workers for a new registered model
*/
public int getDefaultWorkers(NDManager manager, int target) {
if (target == 0) {
return 0;
} else if (target == -1 && isDebug()) {
return 1;
}
if (Device.Type.GPU.equals(manager.getDevice().getDeviceType())) {
if ("MXNet".equals(manager.getEngine().getEngineName())) {
// FIXME: MXNet GPU Model doesn't support multi-threading
return 1;
} else if (target == -1) {
target = 2; // default to max 2 workers per GPU
}
return target;
}

if (target > 0) {
return target;
}
return Runtime.getRuntime().availableProcessors();
}

/**
* Returns the model server home directory.
*
