diff --git a/README.md b/README.md
index eb82ea7509..71e0d646dd 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,10 @@ You can install extra extensions to enable the following models:
 
 DJL serving is built on top of [Deep Java Library](https://djl.ai).
 You can visit [DJL github repository](https://github.com/deepjavalibrary/djl) to learn more about DJL.
 
+It is also possible to leverage only the worker thread pool using the separate [WorkLoadManager](wlm) module.
+The separate WorkLoadManager can be used by customers who want to take advantage of DJL serving's model batching
+and threading while integrating it into their own custom Java service.
+
 ![Architecture Diagram](https://resources.djl.ai/images/djl-serving/architecture.png)
 
 ## Key features

diff --git a/serving/build.gradle b/serving/build.gradle
index 876fda5d22..89a192414c 100644
--- a/serving/build.gradle
+++ b/serving/build.gradle
@@ -4,6 +4,7 @@ plugins {
 }
 
 dependencies {
+    api project(":wlm")
     api platform("ai.djl:bom:${project.version}")
     api "ai.djl:api"
     api "io.netty:netty-all:${netty_version}"

diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java
index f952ecb66f..37f808970c 100644
--- a/serving/src/main/java/ai/djl/serving/ModelServer.java
+++ b/serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -13,12 +13,12 @@ package ai.djl.serving;
 
 import ai.djl.repository.FilenameUtils;
+import ai.djl.serving.models.ModelManager;
 import ai.djl.serving.plugins.FolderScanPluginManager;
 import ai.djl.serving.util.ConfigManager;
 import ai.djl.serving.util.Connector;
 import ai.djl.serving.util.ServerGroups;
 import ai.djl.serving.wlm.ModelInfo;
-import ai.djl.serving.wlm.ModelManager;
 import ai.djl.util.cuda.CudaUtils;
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.channel.ChannelFuture;

diff --git a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java
index bbce5fa97e..27ae660a5c 100644
--- a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java
+++ b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java
@@ -15,11 +15,11 @@
 import ai.djl.ModelException;
 import ai.djl.modality.Input;
 import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.serving.models.ModelManager;
 import ai.djl.serving.util.ConfigManager;
 import ai.djl.serving.util.NettyUtils;
 import ai.djl.serving.wlm.Job;
 import ai.djl.serving.wlm.ModelInfo;
-import ai.djl.serving.wlm.ModelManager;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.codec.http.FullHttpRequest;
 import io.netty.handler.codec.http.HttpMethod;
@@ -179,25 +179,7 @@ private void predict(
                             ConfigManager.getInstance().getMaxBatchDelay(),
                             ConfigManager.getInstance().getMaxIdleTime())
                     .thenApply(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, -1)))
-                    .thenAccept(
-                            m -> {
-                                try {
-                                    if (!modelManager.addJob(new Job(ctx, m, input))) {
-                                        throw new ServiceUnavailableException(
-                                                "No worker is available to serve request: "
-                                                        + modelName);
-                                    }
-                                } catch (ModelNotFoundException e) {
-                                    logger.warn("Unexpected error", e);
-                                    NettyUtils.sendError(ctx, e);
-                                }
-                            })
-                    .exceptionally(
-                            t -> {
-                                logger.warn("Unexpected error", t);
-                                NettyUtils.sendError(ctx, t);
-                                return null;
-                            });
+                    .thenApply(m -> modelManager.runJob(ctx, new Job(m, input)));
             return;
         }
 
@@ -206,11 +188,6 @@ private void predict(
             return;
         }
 
-        Job job = new Job(ctx, model, input);
-        if (!modelManager.addJob(job)) {
-
logger.error("unable to process prediction. no free worker available."); - throw new ServiceUnavailableException( - "No worker is available to serve request: " + modelName); - } + modelManager.runJob(ctx, new Job(model, input)); } } diff --git a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java index 7815b85e86..2e710bd2a2 100644 --- a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java @@ -14,10 +14,10 @@ import ai.djl.ModelException; import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.serving.models.Endpoint; +import ai.djl.serving.models.ModelManager; import ai.djl.serving.util.NettyUtils; -import ai.djl.serving.wlm.Endpoint; import ai.djl.serving.wlm.ModelInfo; -import ai.djl.serving.wlm.ModelManager; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; diff --git a/serving/src/main/java/ai/djl/serving/wlm/Endpoint.java b/serving/src/main/java/ai/djl/serving/models/Endpoint.java similarity index 98% rename from serving/src/main/java/ai/djl/serving/wlm/Endpoint.java rename to serving/src/main/java/ai/djl/serving/models/Endpoint.java index 7ad2fbde5b..e7e695073a 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/Endpoint.java +++ b/serving/src/main/java/ai/djl/serving/models/Endpoint.java @@ -10,8 +10,9 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.serving.wlm; +package ai.djl.serving.models; +import ai.djl.serving.wlm.ModelInfo; import java.util.ArrayList; import java.util.List; import java.util.Map; diff --git a/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java b/serving/src/main/java/ai/djl/serving/models/ModelManager.java similarity index 72% rename from serving/src/main/java/ai/djl/serving/wlm/ModelManager.java rename to serving/src/main/java/ai/djl/serving/models/ModelManager.java index 539d964d77..2edf7ce99f 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java +++ b/serving/src/main/java/ai/djl/serving/models/ModelManager.java @@ -10,7 +10,7 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
 */
-package ai.djl.serving.wlm;
+package ai.djl.serving.models;
 
 import ai.djl.Device;
 import ai.djl.ModelException;
@@ -22,6 +22,19 @@
 import ai.djl.serving.http.BadRequestException;
 import ai.djl.serving.http.DescribeModelResponse;
 import ai.djl.serving.util.ConfigManager;
+import ai.djl.serving.util.NettyUtils;
+import ai.djl.serving.wlm.Job;
+import ai.djl.serving.wlm.ModelInfo;
+import ai.djl.serving.wlm.WorkLoadManager;
+import ai.djl.serving.wlm.WorkerThread;
+import ai.djl.serving.wlm.util.WlmCapacityException;
+import ai.djl.serving.wlm.util.WlmShutdownException;
+import ai.djl.translate.TranslateException;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http.DefaultFullHttpResponse;
+import io.netty.handler.codec.http.FullHttpResponse;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.netty.handler.codec.http.HttpVersion;
 import java.io.IOException;
 import java.util.HashSet;
 import java.util.List;
@@ -236,14 +249,68 @@ public Set getStartupModels() {
     }
 
     /**
-     * Adds an inference job to the job queue. Assign the job to the next free worker.
+     * Runs an inference job by assigning the job to the next free worker.
      *
+     * @param ctx the netty channel handler context where the job response will be sent
      * @param job an inference job to be executed
-     * @return {@code true} if submit success
-     * @throws ModelNotFoundException if the model is not registered
+     * @return a future that completes after the netty response has been sent
      */
-    public boolean addJob(Job job) throws ModelNotFoundException {
-        return wlm.addJob(job);
+    @SuppressWarnings("PMD.InvalidLogMessageFormat")
+    public CompletableFuture<Output> runJob(ChannelHandlerContext ctx, Job job) {
+        return wlm.runJob(job)
+                .whenComplete(
+                        (output, throwable) -> {
+                            if (throwable != null) {
+                                HttpResponseStatus status;
+                                if (throwable instanceof TranslateException) {
+                                    status = HttpResponseStatus.BAD_REQUEST;
+                                } else if (throwable instanceof WlmShutdownException) {
+                                    status = HttpResponseStatus.SERVICE_UNAVAILABLE;
+                                    logger.error("Unable to process prediction. Worker shutdown");
+                                } else if (throwable instanceof WlmCapacityException) {
+                                    logger.error(
+                                            "Unable to process prediction. Worker capacity exceeded");
+                                    status = HttpResponseStatus.SERVICE_UNAVAILABLE;
+                                } else {
+                                    logger.warn("Unexpected error", throwable);
+                                    status = HttpResponseStatus.INTERNAL_SERVER_ERROR;
+                                }
+
+                                /*
+                                 * Models can be loaded based on the configuration file. Since such a job
+                                 * is not driven by an external connection, it can have an empty context,
+                                 * and we shouldn't try to send a response to ctx if the job was not
+                                 * triggered by an external client.
+                                 */
+                                if (ctx != null) {
+                                    NettyUtils.sendError(ctx, status, throwable);
+                                }
+                            } else { // Handle output
+                                FullHttpResponse resp =
+                                        new DefaultFullHttpResponse(
+                                                HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false);
+                                for (Map.Entry<String, String> entry :
+                                        output.getProperties().entrySet()) {
+                                    resp.headers().set(entry.getKey(), entry.getValue());
+                                }
+                                resp.content().writeBytes(output.getContent());
+
+                                /*
+                                 * Models can be loaded based on the configuration file. Since such a job
+                                 * is not driven by an external connection, it can have an empty context,
+                                 * and we shouldn't try to send a response to ctx if the job was not
+                                 * triggered by an external client.
+ */ + if (ctx != null) { + NettyUtils.sendHttpResponse(ctx, resp, true); + } + } + + logger.debug( + "Waiting time: {}, Backend time: {}", + job.getScheduled() - job.getBegin(), + System.currentTimeMillis() - job.getScheduled()); + }); } /** diff --git a/serving/src/main/java/ai/djl/serving/wlm/package-info.java b/serving/src/main/java/ai/djl/serving/models/package-info.java similarity index 95% rename from serving/src/main/java/ai/djl/serving/wlm/package-info.java rename to serving/src/main/java/ai/djl/serving/models/package-info.java index a1eb814e56..e8657db53b 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/package-info.java +++ b/serving/src/main/java/ai/djl/serving/models/package-info.java @@ -11,4 +11,4 @@ * and limitations under the License. */ /** Contains classes that manage model lifecycle. */ -package ai.djl.serving.wlm; +package ai.djl.serving.models; diff --git a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java index 9c66c0c0e9..e4bb50878d 100644 --- a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java +++ b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java @@ -12,14 +12,12 @@ */ package ai.djl.serving.util; -import ai.djl.Device; -import ai.djl.ndarray.NDManager; import ai.djl.serving.Arguments; +import ai.djl.serving.wlm.util.WlmConfigManager; import ai.djl.util.Utils; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; - import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; @@ -127,6 +125,9 @@ public static void init(Arguments args) { "log4j2.contextSelector", "org.apache.logging.log4j.core.async.AsyncLoggerContextSelector"); } + + // Sets corresponding config in the WlmConfigManager + WlmConfigManager.getInstance().setDebug(instance.isDebug()); } /** @@ -209,35 +210,6 @@ public int getMaxBatchDelay() { return getIntProperty(MAX_BATCH_DELAY, 300); } - /** - * Returns the default number of workers for a new registered model. - * - * @param manager the {@code NDManager} the model uses - * @param target the target number of worker - * @return the default number of workers for a new registered model - */ - public int getDefaultWorkers(NDManager manager, int target) { - if (target == 0) { - return 0; - } else if (target == -1 && isDebug()) { - return 1; - } - if (Device.Type.GPU.equals(manager.getDevice().getDeviceType())) { - if ("MXNet".equals(manager.getEngine().getEngineName())) { - // FIXME: MXNet GPU Model doesn't support multi-threading - return 1; - } else if (target == -1) { - target = 2; // default to max 2 workers per GPU - } - return target; - } - - if (target > 0) { - return target; - } - return Runtime.getRuntime().availableProcessors(); - } - /** * Returns the model server home directory. * diff --git a/serving/src/main/java/ai/djl/serving/wlm/Job.java b/serving/src/main/java/ai/djl/serving/wlm/Job.java deleted file mode 100644 index 26f7a09f2f..0000000000 --- a/serving/src/main/java/ai/djl/serving/wlm/Job.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.serving.wlm; - -import ai.djl.modality.Input; -import ai.djl.modality.Output; -import ai.djl.serving.util.NettyUtils; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** A class represents an inference job. */ -public class Job { - - private static final Logger logger = LoggerFactory.getLogger(Job.class); - - private ChannelHandlerContext ctx; - - private ModelInfo modelInfo; - private Input input; - private long begin; - private long scheduled; - - /** - * Constructs an new {@code Job} instance. - * - * @param ctx the {@code ChannelHandlerContext} - * @param modelInfo the model to run the job - * @param input the input data - */ - public Job(ChannelHandlerContext ctx, ModelInfo modelInfo, Input input) { - this.ctx = ctx; - this.modelInfo = modelInfo; - this.input = input; - - begin = System.currentTimeMillis(); - scheduled = begin; - } - - /** - * Returns the request id. - * - * @return the request id - */ - public String getRequestId() { - return input.getRequestId(); - } - - /** - * Returns the model that associated with this job. - * - * @return the model that associated with this job - */ - public ModelInfo getModel() { - return modelInfo; - } - - /** - * Returns the input data. - * - * @return the input data - */ - public Input getInput() { - return input; - } - - /** Marks the job has been scheduled. */ - public void setScheduled() { - scheduled = System.currentTimeMillis(); - } - - /** - * Sends the response back to the client. - * - * @param output the output - */ - public void sendOutput(Output output) { - FullHttpResponse resp = - new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); - for (Map.Entry entry : output.getProperties().entrySet()) { - resp.headers().set(entry.getKey(), entry.getValue()); - } - resp.content().writeBytes(output.getContent()); - - /* - * We can load the models based on the configuration file.Since this Job is - * not driven by the external connections, we could have a empty context for - * this job. We shouldn't try to send a response to ctx if this is not triggered - * by external clients. - */ - if (ctx != null) { - NettyUtils.sendHttpResponse(ctx, resp, true); - } - - logger.debug( - "Waiting time: {}, Backend time: {}", - scheduled - begin, - System.currentTimeMillis() - scheduled); - } - - /** - * Sends error to the client. - * - * @param status the HTTP status - * @param error the exception - */ - public void sendError(HttpResponseStatus status, Throwable error) { - /* - * We can load the models based on the configuration file.Since this Job is - * not driven by the external connections, we could have a empty context for - * this job. We shouldn't try to send a response to ctx if this is not triggered - * by external clients. 
- */ - if (ctx != null) { - NettyUtils.sendError(ctx, status, error); - } - - logger.debug( - "Waiting time: {}, Inference time: {}", - scheduled - begin, - System.currentTimeMillis() - begin); - } -} diff --git a/serving/src/main/java/ai/djl/serving/wlm/ScaleCapacityExceededException.java b/serving/src/main/java/ai/djl/serving/wlm/ScaleCapacityExceededException.java deleted file mode 100644 index 982a81087d..0000000000 --- a/serving/src/main/java/ai/djl/serving/wlm/ScaleCapacityExceededException.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.serving.wlm; - -/** - * Is thrown when capacity of workers is reached during autoscaling. - * - * @author erik.bamberg@web.de - */ -public class ScaleCapacityExceededException extends Exception { - - /** serialVersionUDI for this class cause exceptions are serializable. */ - private static final long serialVersionUID = 1633130362838844091L; - - /** No arguments. */ - public ScaleCapacityExceededException() {} - - /** - * construct using a message. - * - * @param message the message. - */ - public ScaleCapacityExceededException(String message) { - super(message); - } - - /** - * construct using a cause. - * - * @param cause the root cause. - */ - public ScaleCapacityExceededException(Throwable cause) { - super(cause); - } - - /** - * construct using a message and a clause. - * - * @param message the message. - * @param cause the root cause. - */ - public ScaleCapacityExceededException(String message, Throwable cause) { - super(message, cause); - } - - /** - * construct using a message cause and flags. - * - * @param message the message. - * @param cause the root cause. - * @param enableSuppression enable suppression or not. - * @param writableStackTrace flag if writableStackTrace. 
- */ - public ScaleCapacityExceededException( - String message, - Throwable cause, - boolean enableSuppression, - boolean writableStackTrace) { - super(message, cause, enableSuppression, writableStackTrace); - } -} diff --git a/serving/src/main/puml/architecture.puml b/serving/src/main/puml/architecture.puml index 30d8ef5150..60d3a5c400 100644 --- a/serving/src/main/puml/architecture.puml +++ b/serving/src/main/puml/architecture.puml @@ -26,6 +26,7 @@ package "DJL Serving - single process" { HTTP - REST_API } + package "WorkLoad Manager" as wlm { frame "Worker thread pool" as wp { package Workers [ resent18_v1 (GPU0) @@ -39,8 +40,9 @@ package "DJL Serving - single process" { } queue "Job queue\nauto batch" as jq - [Model Manager] as wlm + } + [Model Manager] as mm frame Models { package Engines [ PyTorch @@ -54,10 +56,10 @@ package "DJL Serving - single process" { } REST_API -> jq - REST_API ---> wlm + REST_API ---> mm jq -> Workers : auto scale - jq ...> wlm - wlm -right-> Engines + jq ...> mm + mm -right-> Engines Engines -[hidden]up- [Translator] Translator <-up- Workers } @@ -66,7 +68,7 @@ frame "Python Workers" { [preprocess] -[hidden]-- [postprocess] } -wlm -down-> URL : load model +mm -down-> URL : load model Translator -up.> preprocess : optional Translator -down.> postprocess : optional @enduml diff --git a/settings.gradle b/settings.gradle index d158bbd3c6..94b449ea07 100644 --- a/settings.gradle +++ b/settings.gradle @@ -3,3 +3,4 @@ include ':central' include ':plugins:plugin-management-plugin' include ':plugins:static-file-plugin' include ':serving' +include ':wlm' diff --git a/wlm/README.md b/wlm/README.md new file mode 100644 index 0000000000..fac74ea145 --- /dev/null +++ b/wlm/README.md @@ -0,0 +1,32 @@ +# DJL Serving - WorkLoadManager + +The djl-serving serving can be divided into a frontend and backend. +The frontend is a [netty](https://netty.io/) webserver that manages incoming requests and operators the control plane. +The backend WorkLoadManager handles the model batching, workers, and threading for high-performance inference. + +For those who already have a web server infrastructure but want to operate high-performance inference, it is possible to use only the WorkLoadManager. +For this reason, we have it split apart into a separate module. + +Using the WorkLoadManager is quite simple. First, create a new one through the constructor: + +```java +WorkLoadManager wlm = new WorkLoadManager(); +``` + +You can also configure the WorkLoadManager by using the static `WlmConfigManager`. + +Then, you can construct an instance of the `ModelInfo` for each model you will want to run through `wlm`. +With the `ModelInfo`, you are able to build a `Job` once you receive input: + +```java +ModelInfo modelInfo = new ModelInfo(...); +Job job = new Job(modelInfo, input); +``` + +Once you have your job, it can be submitted to the WorkLoadManager. +It will automatically spin up workers if none are created and manage worker numbers. +Then, it returns a `CompletableFuture` for the result. 
+ +```java +CompletableFuture futureResult = wlm.runJob(job); +``` diff --git a/wlm/build.gradle b/wlm/build.gradle new file mode 100644 index 0000000000..ae01bc1098 --- /dev/null +++ b/wlm/build.gradle @@ -0,0 +1,101 @@ +plugins { + id "maven-publish" + id "signing" +} + +dependencies { + api platform("ai.djl:bom:${project.version}") + api "ai.djl:api" + api "org.slf4j:slf4j-api:${slf4j_version}" + + testImplementation("org.testng:testng:${testng_version}") { + exclude group: "junit", module: "junit" + } +} + +java { + withJavadocJar() + withSourcesJar() +} + +javadoc { + title "DJL Serving WorkLoadManager ${version}" + options.encoding = "UTF-8" + options.overview "src/main/javadoc/overview.html" + options.addBooleanOption("-allow-script-in-comments", true) +} + +task uploadJavadoc(type: Exec) { + dependsOn javadoc + commandLine "sh", "-c", "find . -name .DS_Store | xargs rm && aws s3 sync build/docs/javadoc s3://javadoc-djl-ai/${project.name}/${version} > build/upload.log" +} + +signing { + required(project.hasProperty("staging") || project.hasProperty("snapshot")) + def signingKey = findProperty("signingKey") + def signingPassword = findProperty("signingPassword") + useInMemoryPgpKeys(signingKey, signingPassword) + sign publishing.publications +} + +publishing { + publications { + maven(MavenPublication) { + from components.java + artifacts = [jar, javadocJar, sourcesJar] + pom { + name = "DJL Serving WorkLoadManager" + description = "DJL Serving WorkLoadManager" + url = "http://www.djl.ai/" + + packaging = "jar" + + licenses { + license { + name = 'The Apache License, Version 2.0' + url = 'https://www.apache.org/licenses/LICENSE-2.0' + } + } + + scm { + connection = "scm:git:git@github.com:deepjavalibrary/djl-serving.git" + developerConnection = "scm:git:git@github.com:deepjavalibrary/djl-serving.git" + url = "https://github.com/deepjavalibrary/djl-serving" + tag = "HEAD" + } + + developers { + developer { + name = "DJL.AI Team" + email = "djl-dev@amazon.com" + organization = "Amazon AI" + organizationUrl = "https://amazon.com" + } + } + } + } + } + + repositories { + maven { + if (project.hasProperty("snapshot")) { + name = "snapshot" + url = "https://oss.sonatype.org/content/repositories/snapshots/" + credentials { + username = findProperty("ossrhUsername") + password = findProperty("ossrhPassword") + } + } else if (project.hasProperty("staging")) { + name = "staging" + url = "https://oss.sonatype.org/service/local/staging/deploy/maven2/" + credentials { + username = findProperty("ossrhUsername") + password = findProperty("ossrhPassword") + } + } else { + name = "local" + url = "build/repo" + } + } + } +} diff --git a/wlm/gradlew b/wlm/gradlew new file mode 100755 index 0000000000..b70597b686 --- /dev/null +++ b/wlm/gradlew @@ -0,0 +1,201 @@ +#!/usr/bin/env sh + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +########################################################################################## +# Extension to allow automatically downloading the gradle-wrapper.jar +# This allows using the maven wrapper in projects that prohibit checking in binary data. +########################################################################################## +WRAPPER_JAR_PATH="$APP_HOME/gradle/wrapper/gradle-wrapper.jar" +if [ ! -r "${WRAPPER_JAR_PATH}" ]; then + jarUrl="https://raw.githubusercontent.com/gradle/gradle/master/gradle/wrapper/gradle-wrapper.jar" + if command -v wget > /dev/null; then + wget -q "${jarUrl}" -O "${WRAPPER_JAR_PATH}" + elif command -v curl > /dev/null; then + curl -s -o "${WRAPPER_JAR_PATH}" "$jarUrl" + else + javaClass="$APP_HOME/gradle/wrapper/GradleWrapperDownloader.java" + if [ -e "$javaClass" ]; then + if [ ! 
-e "$APP_HOME/gradle/wrapper/GradleWrapperDownloader.class" ]; then + # Compiling the Java class + ("${JAVACMD}c" "$javaClass") + fi + if [ -e "$APP_HOME/gradle/wrapper/GradleWrapperDownloader.class" ]; then + ("$JAVACMD" -cp gradle/wrapper GradleWrapperDownloader "$APP_HOME") + fi + fi + fi +fi +########################################################################################## +# End of extension +########################################################################################## + + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=$(save "$@") + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong +if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then + cd "$(dirname "$0")" +fi + +exec "$JAVACMD" "$@" diff --git a/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java similarity index 75% rename from serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java rename to wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java index 25ed8515d3..2a6bd046d6 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/BatchAggregator.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/BatchAggregator.java @@ -14,7 +14,7 @@ import ai.djl.modality.Input; import ai.djl.modality.Output; -import 
io.netty.handler.codec.http.HttpResponseStatus;
+import ai.djl.serving.wlm.util.WorkerJob;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.LinkedBlockingDeque;
@@ -30,8 +30,8 @@ abstract class BatchAggregator {
 
     protected int batchSize;
     protected int maxBatchDelay;
-    protected List<Job> jobs;
-    protected LinkedBlockingDeque<Job> jobQueue;
+    protected List<WorkerJob> wjs;
+    protected LinkedBlockingDeque<WorkerJob> jobQueue;
 
     /**
      * Constructs a new {@code BatchAggregator} instance.
@@ -39,11 +39,11 @@ abstract class BatchAggregator {
      * @param model the model to use.
      * @param jobQueue the job queue for polling data from.
      */
-    public BatchAggregator(ModelInfo model, LinkedBlockingDeque<Job> jobQueue) {
+    public BatchAggregator(ModelInfo model, LinkedBlockingDeque<WorkerJob> jobQueue) {
         this.batchSize = model.getBatchSize();
         this.maxBatchDelay = model.getMaxBatchDelay();
         this.jobQueue = jobQueue;
-        jobs = new ArrayList<>();
+        wjs = new ArrayList<>();
     }
 
     /**
@@ -54,9 +54,10 @@ public BatchAggregator(ModelInfo model, LinkedBlockingDeque jobQueue) {
      *     queue.
      */
     public List<Input> getRequest() throws InterruptedException {
-        jobs = pollBatch();
-        List<Input> list = new ArrayList<>(jobs.size());
-        for (Job job : jobs) {
+        wjs = pollBatch();
+        List<Input> list = new ArrayList<>(wjs.size());
+        for (WorkerJob wj : wjs) {
+            Job job = wj.getJob();
             job.setScheduled();
             list.add(job.getInput());
         }
@@ -69,30 +70,29 @@ public List getRequest() throws InterruptedException {
      * @param outputs list of model-outputs in same order as the input objects.
      */
     public void sendResponse(List<Output> outputs) {
-        if (jobs.size() != outputs.size()) {
+        if (wjs.size() != outputs.size()) {
             throw new IllegalStateException("Not all jobs get response.");
         }
 
         int i = 0;
         for (Output output : outputs) {
-            Job job = jobs.get(i++);
-            output.setRequestId(job.getRequestId());
-            job.sendOutput(output);
+            WorkerJob wj = wjs.get(i++);
+            output.setRequestId(wj.getJob().getRequestId());
+            wj.getFuture().complete(output);
         }
-        jobs.clear();
+        wjs.clear();
     }
 
     /**
-     * Sends an error response to client.
+     * Completes the job with an error.
      *
-     * @param status the HTTP status
      * @param error the exception
      */
-    public void sendError(HttpResponseStatus status, Throwable error) {
-        for (Job job : jobs) {
-            job.sendError(status, error);
+    public void sendError(Throwable error) {
+        for (WorkerJob wj : wjs) {
+            wj.getFuture().completeExceptionally(error);
         }
-        jobs.clear();
+        wjs.clear();
     }
 
     /**
@@ -101,7 +101,7 @@ public void sendError(HttpResponseStatus status, Throwable error) {
     * @return a list of jobs read by this batch iteration.
     * @throws InterruptedException if interrupted
      */
-    protected abstract List<Job> pollBatch() throws InterruptedException;
+    protected abstract List<WorkerJob> pollBatch() throws InterruptedException;
 
     /**
      * Checks if this {@code BatchAggregator} and the thread can be shutdown or if this aggregator
@@ -112,19 +112,19 @@ public void sendError(HttpResponseStatus status, Throwable error) {
      */
     public abstract boolean isFinished();
 
-    protected void drainTo(List<Job> list, int maxDelay) throws InterruptedException {
+    protected void drainTo(List<WorkerJob> list, int maxDelay) throws InterruptedException {
         long begin = System.currentTimeMillis();
         jobQueue.drainTo(list, batchSize - 1);
         int remain = batchSize - list.size();
         for (int i = 0; i < remain; ++i) {
-            Job job = jobQueue.poll(maxDelay, TimeUnit.MILLISECONDS);
-            if (job == null) {
+            WorkerJob wj = jobQueue.poll(maxDelay, TimeUnit.MILLISECONDS);
+            if (wj == null || wj.getJob() == null) {
                 break;
             }
             long end = System.currentTimeMillis();
             maxDelay -= end - begin;
             begin = end;
-            list.add(job);
+            list.add(wj);
             if (maxDelay <= 0) {
                 break;
             }

diff --git a/wlm/src/main/java/ai/djl/serving/wlm/Job.java b/wlm/src/main/java/ai/djl/serving/wlm/Job.java
new file mode 100644
index 0000000000..04e5263ba1
--- /dev/null
+++ b/wlm/src/main/java/ai/djl/serving/wlm/Job.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ *     http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.serving.wlm;
+
+import ai.djl.modality.Input;
+
+/** A class that represents an inference job. */
+public class Job {
+
+    private ModelInfo modelInfo;
+    private Input input;
+    private long begin;
+    private long scheduled;
+
+    /**
+     * Constructs a new {@code Job} instance.
+     *
+     * @param modelInfo the model to run the job
+     * @param input the input data
+     */
+    public Job(ModelInfo modelInfo, Input input) {
+        this.modelInfo = modelInfo;
+        this.input = input;
+
+        begin = System.currentTimeMillis();
+        scheduled = begin;
+    }
+
+    /**
+     * Returns the request id.
+     *
+     * @return the request id
+     */
+    public String getRequestId() {
+        return input.getRequestId();
+    }
+
+    /**
+     * Returns the model associated with this job.
+     *
+     * @return the model associated with this job
+     */
+    public ModelInfo getModel() {
+        return modelInfo;
+    }
+
+    /**
+     * Returns the input data.
+     *
+     * @return the input data
+     */
+    public Input getInput() {
+        return input;
+    }
+
+    /**
+     * Returns the job begin time.
+     *
+     * @return the job begin time
+     */
+    public long getBegin() {
+        return begin;
+    }
+
+    /**
+     * Returns the job scheduled time.
+     *
+     * @return the job scheduled time
+     */
+    public long getScheduled() {
+        return scheduled;
+    }
+
+    /** Marks that the job has been scheduled.
*/ + public void setScheduled() { + scheduled = System.currentTimeMillis(); + } +} diff --git a/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java similarity index 98% rename from serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java rename to wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index b3ec29bb16..317bf428c6 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -17,7 +17,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.repository.FilenameUtils; import ai.djl.repository.zoo.ZooModel; -import ai.djl.serving.util.ConfigManager; +import ai.djl.serving.wlm.util.WlmConfigManager; import java.net.URI; import java.nio.file.Path; import java.util.Objects; @@ -98,7 +98,7 @@ public ModelInfo configureModelBatch(int batchSize, int maxBatchDelay) { */ public ModelInfo scaleWorkers(int minWorkers, int maxWorkers) { NDManager manager = model.getNDManager(); - ConfigManager configManager = ConfigManager.getInstance(); + WlmConfigManager configManager = WlmConfigManager.getInstance(); this.maxWorkers = configManager.getDefaultWorkers(manager, maxWorkers); this.minWorkers = Math.min(minWorkers, this.maxWorkers); return this; diff --git a/serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java similarity index 83% rename from serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java rename to wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java index b7d5637fed..dec1c36b91 100644 --- a/serving/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/PermanentBatchAggregator.java @@ -12,6 +12,7 @@ */ package ai.djl.serving.wlm; +import ai.djl.serving.wlm.util.WorkerJob; import java.util.ArrayList; import java.util.List; import java.util.concurrent.LinkedBlockingDeque; @@ -34,17 +35,17 @@ public class PermanentBatchAggregator extends BatchAggregator { * @param model the model to use. * @param jobQueue the job queue for polling data from. 
 */
-    public PermanentBatchAggregator(ModelInfo model, LinkedBlockingDeque<Job> jobQueue) {
+    public PermanentBatchAggregator(ModelInfo model, LinkedBlockingDeque<WorkerJob> jobQueue) {
         super(model, jobQueue);
     }
 
     /** {@inheritDoc} */
     @Override
-    protected List<Job> pollBatch() throws InterruptedException {
-        List<Job> list = new ArrayList<>(batchSize);
-        Job job = jobQueue.take();
-        list.add(job);
-        logger.trace("get first job: {}", job.getRequestId());
+    protected List<WorkerJob> pollBatch() throws InterruptedException {
+        List<WorkerJob> list = new ArrayList<>(batchSize);
+        WorkerJob wj = jobQueue.take();
+        list.add(wj);
+        logger.trace("get first job: {}", wj.getJob().getRequestId());
         drainTo(list, maxBatchDelay);
         logger.trace("sending jobs, size: {}", list.size());
         return list;

diff --git a/serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java b/wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java
similarity index 86%
rename from serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java
rename to wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java
index e3dc134859..1482a5f5ad 100644
--- a/serving/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/TemporaryBatchAggregator.java
@@ -12,6 +12,7 @@
 */
 package ai.djl.serving.wlm;
 
+import ai.djl.serving.wlm.util.WorkerJob;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.LinkedBlockingDeque;
@@ -37,7 +38,7 @@ public class TemporaryBatchAggregator extends BatchAggregator {
      * @param model the model to run for.
      * @param jobQueue reference to external job queue for polling.
      */
-    public TemporaryBatchAggregator(ModelInfo model, LinkedBlockingDeque<Job> jobQueue) {
+    public TemporaryBatchAggregator(ModelInfo model, LinkedBlockingDeque<WorkerJob> jobQueue) {
         super(model, jobQueue);
         this.idleSince = System.currentTimeMillis();
         this.maxIdleTime = model.getMaxIdleTime();
@@ -45,11 +46,11 @@ public TemporaryBatchAggregator(ModelInfo model, LinkedBlockingDeque jobQue
     /** {@inheritDoc} */
     @Override
-    protected List<Job> pollBatch() throws InterruptedException {
-        List<Job> list = new ArrayList<>(batchSize);
-        Job job = jobQueue.poll(maxIdleTime, TimeUnit.SECONDS);
-        if (job != null) {
-            list.add(job);
+    protected List<WorkerJob> pollBatch() throws InterruptedException {
+        List<WorkerJob> list = new ArrayList<>(batchSize);
+        WorkerJob wj = jobQueue.poll(maxIdleTime, TimeUnit.SECONDS);
+        if (wj != null && wj.getJob() != null) {
+            list.add(wj);
             drainTo(list, maxBatchDelay);
             logger.trace("sending jobs, size: {}", list.size());
             idleSince = System.currentTimeMillis();

diff --git a/serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java
similarity index 87%
rename from serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java
rename to wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java
index 4e702821f7..155db70f99 100644
--- a/serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java
@@ -12,8 +12,13 @@
 package ai.djl.serving.wlm;
 
+import ai.djl.modality.Output;
+import ai.djl.serving.wlm.util.WlmCapacityException;
+import ai.djl.serving.wlm.util.WlmShutdownException;
+import ai.djl.serving.wlm.util.WorkerJob;
 import java.util.Collections;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.ExecutorService;
@@ -24,19 +29,19 @@
import org.slf4j.LoggerFactory;
 
 /**
- * WorkLoadManager is repsonsible to manage the work load of worker thread. the manage scales
+ * WorkLoadManager is responsible for managing the workload of the worker threads. It scales
  * up/down the required amount of worker threads per model.
  *
  * @author erik.bamberg@web.de
  */
-class WorkLoadManager {
+public class WorkLoadManager {
 
     private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class);
 
     private ExecutorService threadPool;
 
     private ConcurrentHashMap<ModelInfo, WorkerPool> workerPools;
 
-    /** Constructs a {@code WorkLoadManager} instance. */
+    /** Constructs a {@link WorkLoadManager} instance. */
    public WorkLoadManager() {
         threadPool = Executors.newCachedThreadPool();
         workerPools = new ConcurrentHashMap<>();
@@ -69,18 +74,27 @@ public List getWorkers(ModelInfo modelInfo) {
      * @param job an inference job to be executed.
-     * @return {@code true} if submit success, false otherwise.
+     * @return a {@code CompletableFuture} that completes with the inference {@code Output}
      */
-    public boolean addJob(Job job) {
+    public CompletableFuture<Output> runJob(Job job) {
+        CompletableFuture<Output> result = new CompletableFuture<>();
         ModelInfo modelInfo = job.getModel();
         int maxWorkers = modelInfo.getMaxWorkers();
         if (maxWorkers == 0) {
             logger.info("All model workers has been shutdown: {}", modelInfo.getModelName());
-            return false;
+            result.completeExceptionally(
+                    new WlmShutdownException(
+                            "No worker is available to serve request: "
+                                    + modelInfo.getModelName()));
+            return result;
         }
         WorkerPool pool = getWorkerPoolForModel(modelInfo);
-        LinkedBlockingDeque<Job> queue = pool.getJobQueue();
-        if (!queue.offer(job)) {
+        LinkedBlockingDeque<WorkerJob> queue = pool.getJobQueue();
+        if (!queue.offer(new WorkerJob(job, result))) {
             logger.warn("Worker queue capacity exceeded for model: {}", modelInfo.getModelName());
-            return false;
+            result.completeExceptionally(
+                    new WlmCapacityException(
+                            "No worker is available to serve request: "
+                                    + modelInfo.getModelName()));
+            return result;
         }
 
         int currentWorkers = getNumRunningWorkers(modelInfo);
@@ -97,7 +111,7 @@ public boolean addJob(Job job) {
                 }
             }
         }
-        return true;
+        return result;
     }
 
     /**
@@ -208,7 +222,7 @@ private void addThreads(
 
     private static final class WorkerPool {
 
         private List<WorkerThread> workers;
-        private LinkedBlockingDeque<Job> jobQueue;
+        private LinkedBlockingDeque<WorkerJob> jobQueue;
         private String modelName;
 
         /**
@@ -236,7 +250,7 @@ public List getWorkers() {
          *
          * @return the jobQueue
         */
-        public LinkedBlockingDeque<Job> getJobQueue() {
+        public LinkedBlockingDeque<WorkerJob> getJobQueue() {
             return jobQueue;
         }

diff --git a/serving/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java
similarity index 100%
rename from serving/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java
rename to wlm/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java

diff --git a/serving/src/main/java/ai/djl/serving/wlm/WorkerState.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerState.java
similarity index 100%
rename from serving/src/main/java/ai/djl/serving/wlm/WorkerState.java
rename to wlm/src/main/java/ai/djl/serving/wlm/WorkerState.java

diff --git a/serving/src/main/java/ai/djl/serving/wlm/WorkerThread.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java
similarity index 86%
rename from serving/src/main/java/ai/djl/serving/wlm/WorkerThread.java
rename to wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java
index fedc88df3a..7411581b60 100644
--- a/serving/src/main/java/ai/djl/serving/wlm/WorkerThread.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java
@@ -12,14 +12,12 @@
 package
ai.djl.serving.wlm; -import ai.djl.engine.EngineException; import ai.djl.inference.Predictor; import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.repository.zoo.ZooModel; -import ai.djl.serving.http.InternalServerException; -import ai.djl.translate.TranslateException; -import io.netty.handler.codec.http.HttpResponseStatus; +import ai.djl.serving.wlm.util.WlmException; +import ai.djl.serving.wlm.util.WorkerJob; import java.util.List; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.atomic.AtomicBoolean; @@ -27,7 +25,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -final class WorkerThread implements Runnable { +/** The {@link WorkerThread} is the worker managed by the {@link WorkLoadManager}. */ +public final class WorkerThread implements Runnable { private static final Logger logger = LoggerFactory.getLogger(WorkerThread.class); @@ -76,12 +75,9 @@ public void run() { try { List reply = predictor.batchPredict(req); aggregator.sendResponse(reply); - } catch (EngineException e) { + } catch (Exception e) { logger.warn("Failed to predict", e); - aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e); - } catch (TranslateException e) { - logger.warn("Failed to predict", e); - aggregator.sendError(HttpResponseStatus.BAD_REQUEST, e); + aggregator.sendError(e); } } req = null; @@ -96,40 +92,70 @@ public void run() { currentThread.set(null); shutdown(WorkerState.WORKER_STOPPED); if (req != null) { - Exception e = new InternalServerException(errorMessage); - aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e); + Exception e = new WlmException(errorMessage); + aggregator.sendError(e); } } } + /** + * Returns the worker thread ID. + * + * @return the worker thread ID + */ public int getWorkerId() { return workerId; } + /** + * Returns true if the worker thread is running. + * + * @return true if the worker thread is running + */ public boolean isRunning() { return running.get(); } + /** + * Returns the gpu id used by the thread. + * + * @return the gpu id used by the thread + */ public int getGpuId() { return gpuId; } + /** + * Returns the thread start time. + * + * @return the thread start time + */ public long getStartTime() { return startTime; } + /** + * Returns the worker state. + * + * @return the worker state + */ public WorkerState getState() { return state; } + /** + * Shuts down the worker thread. + * + * @param state the state to set the thread to + */ public void shutdown(WorkerState state) { running.set(false); setState(state); Thread thread = currentThread.getAndSet(null); if (thread != null) { thread.interrupt(); - Exception e = new InternalServerException("Worker shutting down"); - aggregator.sendError(HttpResponseStatus.INTERNAL_SERVER_ERROR, e); + Exception e = new WlmException("Worker shutting down"); + aggregator.sendError(e); } predictor.close(); } @@ -175,7 +201,7 @@ public static class Builder { private ModelInfo model; private BatchAggregator aggregator; - private LinkedBlockingDeque jobQueue; + private LinkedBlockingDeque jobQueue; private boolean fixPoolThread; Builder() { @@ -253,7 +279,7 @@ public Builder optAggregator(BatchAggregator aggregator) { * @param jobQueue the jobQueue to set * @return self-reference to this builder. 
*/ - public Builder setJobQueue(LinkedBlockingDeque jobQueue) { + public Builder setJobQueue(LinkedBlockingDeque jobQueue) { this.jobQueue = jobQueue; return self(); } diff --git a/wlm/src/main/java/ai/djl/serving/wlm/package-info.java b/wlm/src/main/java/ai/djl/serving/wlm/package-info.java new file mode 100644 index 0000000000..681d0e86e1 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** + * Contains the model server backend which manages worker threads and executes jobs on models. + * + * @see ai.djl.serving.wlm.WorkLoadManager + */ +package ai.djl.serving.wlm; diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java new file mode 100644 index 0000000000..22042c7746 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmCapacityException.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +/** Thrown to throttle when a job is run but the job queue capacity is exceeded. */ +public class WlmCapacityException extends RuntimeException { + + static final long serialVersionUID = 1L; + + /** + * Constructs a {@link WlmCapacityException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public WlmCapacityException(String message) { + super(message); + } + + /** + * Constructs a {@link WlmCapacityException} with the specified detail message and cause. + * + *
<p>
Note that the detail message associated with {@code cause} is not automatically
+     * incorporated into this exception's detail message.
+     *
+     * @param message The detail message (which is saved for later retrieval by the {@link
+     *     #getMessage()} method)
+     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
+     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
+     *     unknown.)
+     */
+    public WlmCapacityException(String message, Throwable cause) {
+        super(message, cause);
+    }
+}

diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java
new file mode 100644
index 0000000000..181991becd
--- /dev/null
+++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ *     http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.serving.wlm.util;
+
+import ai.djl.Device;
+import ai.djl.ndarray.NDManager;
+
+/** This manages some configurations used by the {@link ai.djl.serving.wlm.WorkLoadManager}. */
+public final class WlmConfigManager {
+
+    private static final WlmConfigManager INSTANCE = new WlmConfigManager();
+
+    private boolean debug;
+
+    /**
+     * Returns the singleton {@code WlmConfigManager} instance.
+     *
+     * @return the singleton {@code WlmConfigManager} instance
+     */
+    public static WlmConfigManager getInstance() {
+        return INSTANCE;
+    }
+
+    /**
+     * Returns if debug is enabled.
+     *
+     * @return {@code true} if debug is enabled
+     */
+    public boolean isDebug() {
+        return debug;
+    }
+
+    /**
+     * Sets debug mode.
+     *
+     * @param debug true to enable debug mode and false to disable
+     * @return this config manager
+     */
+    public WlmConfigManager setDebug(boolean debug) {
+        this.debug = debug;
+        return this;
+    }
+
+    /**
+     * Returns the default number of workers for a newly registered model.
+     *
+     * @param manager the {@code NDManager} the model uses
+     * @param target the target number of workers
+     * @return the default number of workers for a newly registered model
+     */
+    public int getDefaultWorkers(NDManager manager, int target) {
+        if (target == 0) {
+            return 0;
+        } else if (target == -1 && isDebug()) {
+            return 1;
+        }
+        if (Device.Type.GPU.equals(manager.getDevice().getDeviceType())) {
+            if ("MXNet".equals(manager.getEngine().getEngineName())) {
+                // FIXME: MXNet GPU Model doesn't support multi-threading
+                return 1;
+            } else if (target == -1) {
+                target = 2; // default to max 2 workers per GPU
+            }
+            return target;
+        }
+
+        if (target > 0) {
+            return target;
+        }
+        return Runtime.getRuntime().availableProcessors();
+    }
+}

diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmException.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmException.java
new file mode 100644
index 0000000000..7293205456
--- /dev/null
+++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmException.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +/** Thrown when an exception occurs inside the {@link ai.djl.serving.wlm.WorkLoadManager}. */ +public class WlmException extends RuntimeException { + + static final long serialVersionUID = 1L; + + /** + * Constructs a {@link WlmException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public WlmException(String message) { + super(message); + } + + /** + * Constructs a {@link WlmException} with the specified detail message and cause. + * + *
<p>
Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public WlmException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java new file mode 100644 index 0000000000..3527fa1a7a --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmShutdownException.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +/** Thrown when a job is run but all workers are shutdown. */ +public class WlmShutdownException extends RuntimeException { + + static final long serialVersionUID = 1L; + + /** + * Constructs a {@link WlmShutdownException} with the specified detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + */ + public WlmShutdownException(String message) { + super(message); + } + + /** + * Constructs a {@link WlmShutdownException} with the specified detail message and cause. + * + *
<p>
Note that the detail message associated with {@code cause} is not automatically + * incorporated into this exception's detail message. + * + * @param message The detail message (which is saved for later retrieval by the {@link + * #getMessage()} method) + * @param cause The cause (which is saved for later retrieval by the {@link #getCause()} + * method). (A null value is permitted, and indicates that the cause is nonexistent or + * unknown.) + */ + public WlmShutdownException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WorkerJob.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WorkerJob.java new file mode 100644 index 0000000000..dce9c01a30 --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WorkerJob.java @@ -0,0 +1,53 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.wlm.util; + +import ai.djl.modality.Output; +import ai.djl.serving.wlm.Job; +import java.util.concurrent.CompletableFuture; + +/** A {@link Job} containing metadata from the {@link ai.djl.serving.wlm.WorkLoadManager}. */ +public final class WorkerJob { + + private final Job job; + private final CompletableFuture future; + + /** + * Constructs a new {@link WorkerJob}. + * + * @param job the job to execute + * @param future the future containing the job response + */ + public WorkerJob(Job job, CompletableFuture future) { + this.job = job; + this.future = future; + } + + /** + * Returns the {@link Job}. + * + * @return the {@link Job} + */ + public Job getJob() { + return job; + } + + /** + * Returns the future for the job. + * + * @return the future for the job + */ + public CompletableFuture getFuture() { + return future; + } +} diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/package-info.java b/wlm/src/main/java/ai/djl/serving/wlm/util/package-info.java new file mode 100644 index 0000000000..a1d39eb57d --- /dev/null +++ b/wlm/src/main/java/ai/djl/serving/wlm/util/package-info.java @@ -0,0 +1,14 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +/** Contains utilities to support the {@link ai.djl.serving.wlm.WorkLoadManager}. */ +package ai.djl.serving.wlm.util; diff --git a/wlm/src/main/javadoc/overview.html b/wlm/src/main/javadoc/overview.html new file mode 100644 index 0000000000..a7ef6f1731 --- /dev/null +++ b/wlm/src/main/javadoc/overview.html @@ -0,0 +1,14 @@ + + + + + +

+This document is the API specification for the DJL Serving WorkLoadManager.
+
+<p>
+    This module provides the worker and thread management for a high-performance inference server.
+    See here for more details.
+</p>

+</body>
+</html>

diff --git a/serving/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
similarity index 100%
rename from serving/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
rename to wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java

diff --git a/serving/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java b/wlm/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java
similarity index 100%
rename from serving/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java
rename to wlm/src/test/java/ai/djl/serving/wlm/WorkerIdGeneratorTest.java

diff --git a/serving/src/test/java/ai/djl/serving/wlm/package-info.java b/wlm/src/test/java/ai/djl/serving/wlm/package-info.java
similarity index 100%
rename from serving/src/test/java/ai/djl/serving/wlm/package-info.java
rename to wlm/src/test/java/ai/djl/serving/wlm/package-info.java
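A closing note for readers embedding the module: the sketch below shows how a host service might consume `runJob` and the new `WlmShutdownException`/`WlmCapacityException` failure modes. It mirrors the status mapping that `ModelManager.runJob` performs for HTTP in this patch; the `respond` and `fail` callbacks are hypothetical placeholders, not part of the module's API.

```java
import ai.djl.modality.Output;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.WorkLoadManager;
import ai.djl.serving.wlm.util.WlmCapacityException;
import ai.djl.serving.wlm.util.WlmShutdownException;

import java.util.function.Consumer;

public final class WlmErrorHandling {

    /** Submits a job and routes the outcome to the given callbacks. */
    public static void submit(
            WorkLoadManager wlm, Job job, Consumer<Output> respond, Consumer<String> fail) {
        wlm.runJob(job)
                .whenComplete(
                        (output, throwable) -> {
                            if (throwable == null) {
                                respond.accept(output);
                            } else if (throwable instanceof WlmShutdownException) {
                                // All workers for the model were shut down
                                fail.accept("503 Service Unavailable: " + throwable.getMessage());
                            } else if (throwable instanceof WlmCapacityException) {
                                // The model's job queue is full
                                fail.accept("503 Service Unavailable: " + throwable.getMessage());
                            } else {
                                fail.accept("500 Internal Server Error: " + throwable.getMessage());
                            }
                        });
    }
}
```

Because `WorkLoadManager.runJob` completes its future exceptionally with these unwrapped exception types, the `instanceof` checks work directly inside `whenComplete`.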