Separate WorkLoadManager (#11)
* Separate WorkLoadManager

This separates the WorkLoadManager functionality from the main djl-serving
project into a separate module.
First, the new module enables isolated usage of the WorkLoadManager without serving.

It also helps convert the wlm to focus on CompletableFutures rather than Netty,
which makes it easier to leverage from pure Java. In particular, this also
enables Translators to leverage the wlm.

* Update gradle build script

* Move anonymous block into function

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
zachgk and frankfliu authored Sep 30, 2021
1 parent e9e6c7e commit 11a9132
Showing 37 changed files with 749 additions and 350 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -26,6 +26,10 @@ You can install extra extensions to enable the following models:
DJL serving is built on top of [Deep Java Library](https://djl.ai). You can visit
[DJL github repository](https://github.com/deepjavalibrary/djl) to learn more about DJL.

It is also possible to leverage only the worker thread pool by using the separate [WorkLoadManager](wlm) module.
The separate WorkLoadManager can be used by customers who want to take advantage of DJL serving's model batching
and threading but integrate it into their own custom Java service.

![Architecture Diagram](https://resources.djl.ai/images/djl-serving/architecture.png)

## Key features
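The README paragraph added above describes embedding the worker pool in a custom Java service. Below is a minimal sketch of what that might look like, assuming only the types visible in this commit (`WorkLoadManager`, `ModelInfo`, `Job`, and DJL's `Input`/`Output`) and a no-argument `WorkLoadManager` constructor; the `CustomService` class is illustrative, and constructing a `ModelInfo` or `Input` is omitted here.

```java
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager;
import java.util.concurrent.CompletableFuture;

/** Hypothetical service that reuses the wlm worker pool without the HTTP frontend. */
public class CustomService {

    private final WorkLoadManager wlm = new WorkLoadManager();

    /**
     * Submits a request for the given model to the shared worker pool.
     * No netty context is involved; the caller decides what to do with the Output.
     */
    public CompletableFuture<Output> predict(ModelInfo model, Input input) {
        return wlm.runJob(new Job(model, input))
                .whenComplete(
                        (output, throwable) -> {
                            if (throwable != null) {
                                // e.g. a wlm shutdown or capacity error
                                System.err.println("Inference failed: " + throwable);
                            }
                        });
    }
}
```

The key difference from the serving handlers is that the result comes back as a `CompletableFuture<Output>` rather than being written to a netty channel.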
2 changes: 1 addition & 1 deletion build.gradle
@@ -8,7 +8,7 @@ def djl_version = System.getenv("DJL_VERSION")
djl_version = (djl_version == null) ? "0.13.0-SNAPSHOT" : djl_version

allprojects {
group 'ai.djl'
group 'ai.djl.serving'
version "${djl_version}"

repositories {
3 changes: 1 addition & 2 deletions serving/build.gradle
@@ -5,9 +5,8 @@ plugins {

dependencies {
api platform("ai.djl:bom:${project.version}")
api "ai.djl:api"
api project(":wlm")
api "io.netty:netty-all:${netty_version}"
api "org.slf4j:slf4j-api:${slf4j_version}"

//noinspection GradlePackageUpdate
implementation "commons-cli:commons-cli:${commons_cli_version}"
2 changes: 1 addition & 1 deletion serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -13,12 +13,12 @@
package ai.djl.serving;

import ai.djl.repository.FilenameUtils;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.plugins.FolderScanPluginManager;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.Connector;
import ai.djl.serving.util.ServerGroups;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import ai.djl.util.cuda.CudaUtils;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
@@ -15,11 +15,11 @@
import ai.djl.ModelException;
import ai.djl.modality.Input;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
@@ -179,25 +179,7 @@ private void predict(
ConfigManager.getInstance().getMaxBatchDelay(),
ConfigManager.getInstance().getMaxIdleTime())
.thenApply(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, -1)))
.thenAccept(
m -> {
try {
if (!modelManager.addJob(new Job(ctx, m, input))) {
throw new ServiceUnavailableException(
"No worker is available to serve request: "
+ modelName);
}
} catch (ModelNotFoundException e) {
logger.warn("Unexpected error", e);
NettyUtils.sendError(ctx, e);
}
})
.exceptionally(
t -> {
logger.warn("Unexpected error", t);
NettyUtils.sendError(ctx, t);
return null;
});
.thenAccept(m -> modelManager.runJob(ctx, new Job(m, input)));
return;
}

@@ -206,11 +188,6 @@ private void predict(
return;
}

Job job = new Job(ctx, model, input);
if (!modelManager.addJob(job)) {
logger.error("unable to process prediction. no free worker available.");
throw new ServiceUnavailableException(
"No worker is available to serve request: " + modelName);
}
modelManager.runJob(ctx, new Job(model, input));
}
}
@@ -14,10 +14,10 @@

import ai.djl.ModelException;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.models.Endpoint;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Endpoint;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
@@ -10,8 +10,9 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.wlm;
package ai.djl.serving.models;

import ai.djl.serving.wlm.ModelInfo;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.wlm;
package ai.djl.serving.models;

import ai.djl.Device;
import ai.djl.ModelException;
@@ -22,6 +22,19 @@
import ai.djl.serving.http.BadRequestException;
import ai.djl.serving.http.DescribeModelResponse;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager;
import ai.djl.serving.wlm.WorkerThread;
import ai.djl.serving.wlm.util.WlmCapacityException;
import ai.djl.serving.wlm.util.WlmShutdownException;
import ai.djl.translate.TranslateException;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
@@ -236,14 +249,25 @@ public Set<String> getStartupModels() {
}

/**
* Adds an inference job to the job queue. Assign the job to the next free worker.
* Runs an inference job by assigning the job to the next free worker.
*
* @param ctx the netty channel handler context where the job response will be sent
* @param job an inference job to be executed
* @return {@code true} if submit success
* @throws ModelNotFoundException if the model is not registered
*/
public boolean addJob(Job job) throws ModelNotFoundException {
return wlm.addJob(job);
public void runJob(ChannelHandlerContext ctx, Job job) {
wlm.runJob(job)
.whenComplete(
(o, t) -> {
if (t != null) {
onException(t, ctx);
} else {
sendOutput(o, ctx);
}
logger.trace(
"Waiting time: {}, Backend time: {}",
job.getScheduled() - job.getBegin(),
System.currentTimeMillis() - job.getScheduled());
});
}

/**
@@ -315,4 +339,49 @@ public CompletableFuture<String> workerStatus() {
return response;
});
}

void sendOutput(Output output, ChannelHandlerContext ctx) {
FullHttpResponse resp =
new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false);
for (Map.Entry<String, String> entry : output.getProperties().entrySet()) {
resp.headers().set(entry.getKey(), entry.getValue());
}
resp.content().writeBytes(output.getContent());

/*
* We can load the models based on the configuration file. Since this Job is
* not driven by an external connection, we could have an empty context for
* this job. We shouldn't try to send a response to ctx if this is not triggered
* by external clients.
*/
if (ctx != null) {
NettyUtils.sendHttpResponse(ctx, resp, true);
}
}

void onException(Throwable t, ChannelHandlerContext ctx) {
HttpResponseStatus status;
if (t instanceof TranslateException) {
status = HttpResponseStatus.BAD_REQUEST;
} else if (t instanceof WlmShutdownException) {
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
logger.error("Unable to process prediction. Worker shutdown");
} else if (t instanceof WlmCapacityException) {
logger.error("Unable to process prediction. Worker capacity exceeded");
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
} else {
logger.warn("Unexpected error", t);
status = HttpResponseStatus.INTERNAL_SERVER_ERROR;
}

/*
* We can load the models based on the configuration file. Since this Job is
* not driven by an external connection, we could have an empty context for
* this job. We shouldn't try to send a response to ctx if this is not triggered
* by external clients.
*/
if (ctx != null) {
NettyUtils.sendError(ctx, status, t);
}
}
}
@@ -11,4 +11,4 @@
* and limitations under the License.
*/
/** Contains classes that manage model lifecycle. */
package ai.djl.serving.wlm;
package ai.djl.serving.models;
35 changes: 4 additions & 31 deletions serving/src/main/java/ai/djl/serving/util/ConfigManager.java
@@ -12,9 +12,8 @@
*/
package ai.djl.serving.util;

import ai.djl.Device;
import ai.djl.ndarray.NDManager;
import ai.djl.serving.Arguments;
import ai.djl.serving.wlm.util.WlmConfigManager;
import ai.djl.util.Utils;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
@@ -126,6 +125,9 @@ public static void init(Arguments args) {
"log4j2.contextSelector",
"org.apache.logging.log4j.core.async.AsyncLoggerContextSelector");
}

// Sets corresponding config in the WlmConfigManager
WlmConfigManager.getInstance().setDebug(instance.isDebug());
}

/**
@@ -208,35 +210,6 @@ public int getMaxBatchDelay() {
return getIntProperty(MAX_BATCH_DELAY, 300);
}

/**
* Returns the default number of workers for a new registered model.
*
* @param manager the {@code NDManager} the model uses
* @param target the target number of worker
* @return the default number of workers for a new registered model
*/
public int getDefaultWorkers(NDManager manager, int target) {
if (target == 0) {
return 0;
} else if (target == -1 && isDebug()) {
return 1;
}
if (Device.Type.GPU.equals(manager.getDevice().getDeviceType())) {
if ("MXNet".equals(manager.getEngine().getEngineName())) {
// FIXME: MXNet GPU Model doesn't support multi-threading
return 1;
} else if (target == -1) {
target = 2; // default to max 2 workers per GPU
}
return target;
}

if (target > 0) {
return target;
}
return Runtime.getRuntime().availableProcessors();
}

/**
* Returns the model server home directory.
*
