Skip to content

Commit

Permalink
[serving] Adds multiple node cluster configuration support (deepjaval…
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jul 18, 2024
1 parent 46d82d2 commit aab6756
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 6 deletions.
1 change: 1 addition & 0 deletions serving/docker/config.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8080
cluster_address=http://0.0.0.0:8888
model_store=/opt/ml/model
load_models=ALL
#model_url_pattern=.*
42 changes: 40 additions & 2 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.plugins.DependencyManager;
import ai.djl.serving.plugins.FolderScanPluginManager;
import ai.djl.serving.util.ClusterConfig;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.Connector;
import ai.djl.serving.util.ServerGroups;
Expand Down Expand Up @@ -194,7 +195,6 @@ public List<ChannelFuture> start()
GeneralSecurityException,
ServerStartupException {
long begin = System.nanoTime();
stopped.set(false);

String version = Engine.getDjlVersion();
logger.info("Starting djl-serving: {} ...", version);
Expand All @@ -204,13 +204,16 @@ public List<ChannelFuture> start()

pluginManager.loadPlugins(true);

initMultiNode();

try {
initModelStore();
} catch (BadWorkflowException | CompletionException e) {
throw new ServerStartupException(
"Failed to initialize startup models and workflows", e);
}

stopped.set(false);
Connector inferenceConnector =
configManager.getConnector(Connector.ConnectorType.INFERENCE);
Connector managementConnector =
Expand Down Expand Up @@ -273,6 +276,41 @@ public void stop() {
serverGroups.reset();
}

/**
 * Bootstraps multi-node cluster coordination before the regular endpoints start.
 *
 * <p>When {@code DJL_CLUSTER_SIZE > 1}, a temporary HTTP server is started on the
 * cluster address so worker nodes can fetch the leader's ssh public key and report
 * their readiness via the cluster API; the server is torn down once all nodes have
 * checked in. A no-op for single-node deployments.
 *
 * @throws GeneralSecurityException if the cluster connector fails SSL setup
 * @throws IOException if the cluster server cannot be initialized
 * @throws InterruptedException if interrupted while waiting for worker nodes
 * @throws ServerStartupException if any worker node reported an error
 */
private void initMultiNode()
        throws GeneralSecurityException,
                IOException,
                InterruptedException,
                ServerStartupException {
    ClusterConfig cc = ClusterConfig.getInstance();
    int clusterSize = cc.getClusterSize();
    if (clusterSize > 1) {
        // Temporary bootstrap server bound to the cluster address (see
        // ClusterRequestHandler for the endpoints it serves).
        Connector multiNodeConnector =
                configManager.getConnector(Connector.ConnectorType.CLUSTER);
        multiNodeConnector.clean();

        EventLoopGroup serverGroup = serverGroups.getServerGroup();
        EventLoopGroup workerGroup = serverGroups.getChildGroup();

        ChannelFuture future = initializeServer(multiNodeConnector, serverGroup, workerGroup);

        // start download model here
        cc.countDown(); // count the leader node itself as ready

        logger.info("Waiting for all worker nodes ready ...");
        // Blocks until every node has reported status (or the latch times out,
        // which records an error on the ClusterConfig).
        cc.await();

        // The bootstrap server is only needed during startup; shut it down and
        // reset the groups so the normal inference/management servers can start.
        future.channel().close();
        serverGroups.shutdown(true);
        serverGroups.reset();

        String status = cc.getError();
        if (status != null) {
            throw new ServerStartupException("Failed to initialize cluster: " + status);
        }
        logger.info("Cluster initialized with {} nodes.", clusterSize);
    }
}

private ChannelFuture initializeServer(
Connector connector, EventLoopGroup serverGroup, EventLoopGroup workerGroup)
throws InterruptedException, IOException, GeneralSecurityException {
Expand Down Expand Up @@ -486,7 +524,7 @@ String mapModelUrl(Path path) {
} catch (MalformedURLException e) {
throw new AssertionError("Invalid path: " + path, e);
} catch (IOException e) {
logger.warn("Failed to access file: " + path, e);
logger.warn("Failed to access file: {}", path, e);
return null;
}
}
Expand Down
4 changes: 4 additions & 0 deletions serving/src/main/java/ai/djl/serving/ServerInitializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.serving;

import ai.djl.serving.http.AdapterManagementRequestHandler;
import ai.djl.serving.http.ClusterRequestHandler;
import ai.djl.serving.http.ConfigurableHttpRequestHandler;
import ai.djl.serving.http.InferenceRequestHandler;
import ai.djl.serving.http.InvalidRequestHandler;
Expand Down Expand Up @@ -74,6 +75,9 @@ public void initChannel(Channel ch) {
case INFERENCE:
pipeline.addLast("inference", new InferenceRequestHandler());
break;
case CLUSTER:
pipeline.addLast("cluster", new ClusterRequestHandler());
break;
case BOTH:
default:
pipeline.addLast(new ConfigurableHttpRequestHandler(pluginManager));
Expand Down
105 changes: 105 additions & 0 deletions serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.http;

import ai.djl.ModelException;
import ai.djl.serving.util.ClusterConfig;
import ai.djl.serving.util.NettyUtils;
import ai.djl.util.Utils;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

/** A class handling inbound HTTP requests for the cluster management API. */
/** A class handling inbound HTTP requests for the cluster management API. */
public class ClusterRequestHandler extends HttpRequestHandler {

    private static final Logger logger = LoggerFactory.getLogger(ClusterRequestHandler.class);

    private ClusterConfig config = ClusterConfig.getInstance();

    /** {@inheritDoc} */
    @Override
    public boolean acceptInboundMessage(Object msg) throws Exception {
        if (super.acceptInboundMessage(msg)) {
            FullHttpRequest req = (FullHttpRequest) msg;
            return req.uri().startsWith("/cluster/");
        }
        return false;
    }

    /**
     * Dispatches cluster API requests.
     *
     * <ul>
     *   <li>{@code /cluster/sshkey} — returns the leader's ssh public key, generating the key
     *       pair on first use.
     *   <li>{@code /cluster/status?message=...} — a worker node reports readiness; any message
     *       other than {@code OK} is recorded as a cluster error.
     * </ul>
     *
     * @param ctx the Netty channel context
     * @param req the full HTTP request
     * @param decoder the query string decoder for the request URI
     * @param segments the URI split on {@code '/'}; {@code segments[2]} selects the action
     * @throws ModelException never thrown directly; declared for the handler contract
     */
    /** {@inheritDoc} */
    @Override
    protected void handleRequest(
            ChannelHandlerContext ctx,
            FullHttpRequest req,
            QueryStringDecoder decoder,
            String[] segments)
            throws ModelException {
        switch (segments[2]) {
            case "sshkey":
                Path home = Paths.get(System.getProperty("user.home")).resolve(".ssh");
                Path file = home.resolve("id_rsa.pub");
                if (Files.notExists(file)) {
                    sshkeygen(home.resolve("id_rsa").toString());
                }
                NettyUtils.sendFile(ctx, file, false);
                return;
            case "status":
                List<String> messages = decoder.parameters().get("message");
                // guard against a missing "message" parameter: get() returns null,
                // which previously caused an NPE on size()
                if (messages == null || messages.size() != 1) {
                    NettyUtils.sendJsonResponse(ctx, new StatusResponse("Invalid request"));
                    return;
                } else if (!"OK".equals(messages.get(0))) {
                    config.setError(messages.get(0));
                }
                config.countDown();
                NettyUtils.sendJsonResponse(ctx, new StatusResponse("OK"));
                return;
            default:
                throw new ResourceNotFoundException();
        }
    }

    /**
     * Generates a passphrase-less RSA key pair at the given path via {@code ssh-keygen}.
     *
     * @param rsaFile the private key file path; the public key is written to {@code rsaFile.pub}
     * @throws IllegalStateException if key generation fails
     */
    private void sshkeygen(String rsaFile) {
        try {
            // ProcessBuilder passes arguments verbatim (no shell quoting), so the empty
            // passphrase must be "" — passing "''" would set the literal passphrase ''
            String[] commands = {"ssh-keygen", "-q", "-t", "rsa", "-N", "", "-f", rsaFile};
            Process exec = new ProcessBuilder(commands).redirectErrorStream(true).start();
            String logOutput;
            try (InputStream is = exec.getInputStream()) {
                logOutput = Utils.toString(is);
            }
            int exitCode = exec.waitFor();
            if (0 != exitCode) {
                logger.error("Generate ssh key failed: {}", logOutput);
                config.setError(logOutput);
                throw new IllegalStateException("Generate ssh key failed");
            } else {
                logger.debug(logOutput);
            }
        } catch (IOException e) {
            config.setError("Generate ssh key failed");
            throw new IllegalStateException("Generate ssh key failed", e);
        } catch (InterruptedException e) {
            // restore the interrupt flag so callers can observe the interruption
            Thread.currentThread().interrupt();
            config.setError("Generate ssh key failed");
            throw new IllegalStateException("Generate ssh key failed", e);
        }
    }
}
87 changes: 87 additions & 0 deletions serving/src/main/java/ai/djl/serving/util/ClusterConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.util;

import ai.djl.util.Utils;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/** A class that holds cluster configurations. */
/**
 * A class that holds cluster configurations.
 *
 * <p>Thread-safety: {@link #setError(String)} is invoked from Netty handler threads while
 * {@link #getError()} is read by the main startup thread, so {@code error} is {@code volatile};
 * the countdown latch provides the happens-before edge for readiness signaling.
 */
public final class ClusterConfig {

    private static final ClusterConfig INSTANCE = new ClusterConfig();

    // size of the cluster, taken from DJL_CLUSTER_SIZE (defaults to 1 = single node)
    private final int clusterSize;
    // one count per node; startup proceeds once every node has checked in
    private final CountDownLatch latch;
    // written by handler threads, read by the main thread -> volatile
    private volatile String error;

    private ClusterConfig() {
        clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE", "1"));
        latch = new CountDownLatch(clusterSize);
    }

    /**
     * Returns the {@code ClusterConfig} singleton object.
     *
     * @return the {@code ClusterConfig} singleton object
     */
    public static ClusterConfig getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the cluster size.
     *
     * @return the cluster size
     */
    public int getClusterSize() {
        return clusterSize;
    }

    /**
     * Returns the error status message.
     *
     * @return the error status message, or {@code null} if no error was reported
     */
    public String getError() {
        return error;
    }

    /**
     * Sets the error status message.
     *
     * @param error the error status message
     */
    public void setError(String error) {
        this.error = error;
    }

    /** Decreases the number of waiting workers. */
    public void countDown() {
        latch.countDown();
    }

    /**
     * Causes current threads to wait until all workers are ready.
     *
     * <p>On timeout (MODEL_LOADING_TIMEOUT seconds, default 240) the error status is set
     * instead of throwing, so the caller must check {@link #getError()} afterwards.
     *
     * @throws InterruptedException if current thread is interrupted
     */
    public void await() throws InterruptedException {
        // TODO: support per model timeout
        int timeout = Integer.parseInt(Utils.getenv("MODEL_LOADING_TIMEOUT", "240"));
        if (!latch.await(timeout, TimeUnit.SECONDS)) {
            error = "Worker nodes timed out";
        }
    }
}
17 changes: 13 additions & 4 deletions serving/src/main/java/ai/djl/serving/util/ConfigManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public final class ConfigManager {

private static final String INFERENCE_ADDRESS = "inference_address";
private static final String MANAGEMENT_ADDRESS = "management_address";
private static final String CLUSTER_ADDRESS = "cluster_address";
private static final String LOAD_MODELS = "load_models";
private static final String WAIT_MODEL_LOADING = "wait_model_loading";
private static final String ALLOW_MULTI_STATUS = "allow_multi_status";
Expand Down Expand Up @@ -236,10 +237,18 @@ private boolean onError(String key) {
*/
public Connector getConnector(Connector.ConnectorType type) {
String binding;
if (type == Connector.ConnectorType.MANAGEMENT) {
binding = prop.getProperty(MANAGEMENT_ADDRESS, "http://127.0.0.1:8080");
} else {
binding = prop.getProperty(INFERENCE_ADDRESS, "http://127.0.0.1:8080");
switch (type) {
case MANAGEMENT:
binding = prop.getProperty(MANAGEMENT_ADDRESS, "http://127.0.0.1:8080");
break;
case CLUSTER:
binding = prop.getProperty(CLUSTER_ADDRESS, "http://127.0.0.1:8888");
break;
case INFERENCE:
case BOTH:
default:
binding = prop.getProperty(INFERENCE_ADDRESS, "http://127.0.0.1:8080");
break;
}
return Connector.parse(binding, type);
}
Expand Down
1 change: 1 addition & 0 deletions serving/src/main/java/ai/djl/serving/util/Connector.java
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ public String toString() {

/** An enum represents type of connector. */
public enum ConnectorType {
CLUSTER,
INFERENCE,
MANAGEMENT,
BOTH
Expand Down

0 comments on commit aab6756

Please sign in to comment.