From aab675666dd8cffd7c3cd996ca98fb1273b3346c Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 18 Jul 2024 15:18:45 -0700 Subject: [PATCH] [serving] Adds mutliple node cluster configuration support (#2190) --- serving/docker/config.properties | 1 + .../main/java/ai/djl/serving/ModelServer.java | 42 ++++++- .../ai/djl/serving/ServerInitializer.java | 4 + .../serving/http/ClusterRequestHandler.java | 105 ++++++++++++++++++ .../ai/djl/serving/util/ClusterConfig.java | 87 +++++++++++++++ .../ai/djl/serving/util/ConfigManager.java | 17 ++- .../java/ai/djl/serving/util/Connector.java | 1 + 7 files changed, 251 insertions(+), 6 deletions(-) create mode 100644 serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java create mode 100644 serving/src/main/java/ai/djl/serving/util/ClusterConfig.java diff --git a/serving/docker/config.properties b/serving/docker/config.properties index 13fd32f4e..c0b30ef99 100644 --- a/serving/docker/config.properties +++ b/serving/docker/config.properties @@ -1,5 +1,6 @@ inference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8080 +cluster_address=http://0.0.0.0:8888 model_store=/opt/ml/model load_models=ALL #model_url_pattern=.* diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java index 9e5afc64d..d9bc956a9 100644 --- a/serving/src/main/java/ai/djl/serving/ModelServer.java +++ b/serving/src/main/java/ai/djl/serving/ModelServer.java @@ -25,6 +25,7 @@ import ai.djl.serving.models.ModelManager; import ai.djl.serving.plugins.DependencyManager; import ai.djl.serving.plugins.FolderScanPluginManager; +import ai.djl.serving.util.ClusterConfig; import ai.djl.serving.util.ConfigManager; import ai.djl.serving.util.Connector; import ai.djl.serving.util.ServerGroups; @@ -194,7 +195,6 @@ public List start() GeneralSecurityException, ServerStartupException { long begin = System.nanoTime(); - stopped.set(false); String version = Engine.getDjlVersion(); logger.info("Starting djl-serving: {} ...", version); @@ -204,6 +204,8 @@ public List start() pluginManager.loadPlugins(true); + initMultiNode(); + try { initModelStore(); } catch (BadWorkflowException | CompletionException e) { @@ -211,6 +213,7 @@ public List start() "Failed to initialize startup models and workflows", e); } + stopped.set(false); Connector inferenceConnector = configManager.getConnector(Connector.ConnectorType.INFERENCE); Connector managementConnector = @@ -273,6 +276,41 @@ public void stop() { serverGroups.reset(); } + private void initMultiNode() + throws GeneralSecurityException, + IOException, + InterruptedException, + ServerStartupException { + ClusterConfig cc = ClusterConfig.getInstance(); + int clusterSize = cc.getClusterSize(); + if (clusterSize > 1) { + Connector multiNodeConnector = + configManager.getConnector(Connector.ConnectorType.CLUSTER); + multiNodeConnector.clean(); + + EventLoopGroup serverGroup = serverGroups.getServerGroup(); + EventLoopGroup workerGroup = serverGroups.getChildGroup(); + + ChannelFuture future = initializeServer(multiNodeConnector, serverGroup, workerGroup); + + // start download model here + cc.countDown(); + + logger.info("Waiting for all worker nodes ready ..."); + cc.await(); + + future.channel().close(); + serverGroups.shutdown(true); + serverGroups.reset(); + + String status = cc.getError(); + if (status != null) { + throw new ServerStartupException("Failed to initialize cluster: " + status); + } + logger.info("Cluster initialized with {} nodes.", clusterSize); + } + } + private ChannelFuture initializeServer( Connector connector, EventLoopGroup serverGroup, EventLoopGroup workerGroup) throws InterruptedException, IOException, GeneralSecurityException { @@ -486,7 +524,7 @@ String mapModelUrl(Path path) { } catch (MalformedURLException e) { throw new AssertionError("Invalid path: " + path, e); } catch (IOException e) { - logger.warn("Failed to access file: " + path, e); + logger.warn("Failed to access file: {}", path, e); return null; } } diff --git a/serving/src/main/java/ai/djl/serving/ServerInitializer.java b/serving/src/main/java/ai/djl/serving/ServerInitializer.java index 4393bdabc..dae4349a4 100644 --- a/serving/src/main/java/ai/djl/serving/ServerInitializer.java +++ b/serving/src/main/java/ai/djl/serving/ServerInitializer.java @@ -13,6 +13,7 @@ package ai.djl.serving; import ai.djl.serving.http.AdapterManagementRequestHandler; +import ai.djl.serving.http.ClusterRequestHandler; import ai.djl.serving.http.ConfigurableHttpRequestHandler; import ai.djl.serving.http.InferenceRequestHandler; import ai.djl.serving.http.InvalidRequestHandler; @@ -74,6 +75,9 @@ public void initChannel(Channel ch) { case INFERENCE: pipeline.addLast("inference", new InferenceRequestHandler()); break; + case CLUSTER: + pipeline.addLast("cluster", new ClusterRequestHandler()); + break; case BOTH: default: pipeline.addLast(new ConfigurableHttpRequestHandler(pluginManager)); diff --git a/serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java new file mode 100644 index 000000000..2f7fb0ff7 --- /dev/null +++ b/serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java @@ -0,0 +1,105 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.http; + +import ai.djl.ModelException; +import ai.djl.serving.util.ClusterConfig; +import ai.djl.serving.util.NettyUtils; +import ai.djl.util.Utils; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.QueryStringDecoder; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; + +/** A class handling inbound HTTP requests for the cluster management API. */ +public class ClusterRequestHandler extends HttpRequestHandler { + + private static final Logger logger = LoggerFactory.getLogger(ClusterRequestHandler.class); + + private ClusterConfig config = ClusterConfig.getInstance(); + + /** {@inheritDoc} */ + @Override + public boolean acceptInboundMessage(Object msg) throws Exception { + if (super.acceptInboundMessage(msg)) { + FullHttpRequest req = (FullHttpRequest) msg; + return req.uri().startsWith("/cluster/"); + } + return false; + } + + /** {@inheritDoc} */ + @Override + protected void handleRequest( + ChannelHandlerContext ctx, + FullHttpRequest req, + QueryStringDecoder decoder, + String[] segments) + throws ModelException { + switch (segments[2]) { + case "sshkey": + Path home = Paths.get(System.getProperty("user.home")).resolve(".ssh"); + Path file = home.resolve("id_rsa.pub"); + if (Files.notExists(file)) { + sshkeygen(home.resolve("id_rsa").toString()); + } + NettyUtils.sendFile(ctx, file, false); + return; + case "status": + List messages = decoder.parameters().get("message"); + if (messages.size() != 1) { + NettyUtils.sendJsonResponse(ctx, new StatusResponse("Invalid request")); + return; + } else if (!"OK".equals(messages.get(0))) { + config.setError(messages.get(0)); + } + config.countDown(); + NettyUtils.sendJsonResponse(ctx, new StatusResponse("OK")); + return; + default: + throw new ResourceNotFoundException(); + } + } + + private void sshkeygen(String rsaFile) { + try { + String[] commands = {"ssh-keygen", "-q", "-t", "rsa", "-N", "''", "-f", rsaFile}; + Process exec = new ProcessBuilder(commands).redirectErrorStream(true).start(); + String logOutput; + try (InputStream is = exec.getInputStream()) { + logOutput = Utils.toString(is); + } + int exitCode = exec.waitFor(); + if (0 != exitCode) { + logger.error("Generate ssh key failed: {}", logOutput); + config.setError(logOutput); + throw new IllegalStateException("Generate ssh key failed"); + } else { + logger.debug(logOutput); + } + } catch (IOException | InterruptedException e) { + config.setError("Generate ssh key failed"); + throw new IllegalStateException("Generate ssh key failed", e); + } + } +} diff --git a/serving/src/main/java/ai/djl/serving/util/ClusterConfig.java b/serving/src/main/java/ai/djl/serving/util/ClusterConfig.java new file mode 100644 index 000000000..be09bf246 --- /dev/null +++ b/serving/src/main/java/ai/djl/serving/util/ClusterConfig.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.serving.util; + +import ai.djl.util.Utils; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** A class that holds cluster configurations. */ +public final class ClusterConfig { + + private static final ClusterConfig INSTANCE = new ClusterConfig(); + + private int clusterSize; + private CountDownLatch latch; + private String error; + + private ClusterConfig() { + clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE", "1")); + latch = new CountDownLatch(clusterSize); + } + + /** + * Returns the {@code ClusterConfig} singleton object. + * + * @return the {@code ClusterConfig} singleton object + */ + public static ClusterConfig getInstance() { + return INSTANCE; + } + + /** + * Returns the cluster size. + * + * @return the cluster size + */ + public int getClusterSize() { + return clusterSize; + } + + /** + * Returns the error status message. + * + * @return the error status message + */ + public String getError() { + return error; + } + + /** + * Sets the error status message. + * + * @param error the error status message + */ + public void setError(String error) { + this.error = error; + } + + /** Decreases the number of waiting workers. */ + public void countDown() { + latch.countDown(); + } + + /** + * Causes current threads to wait until all workers are ready. + * + * @throws InterruptedException if current thread is interrupted + */ + public void await() throws InterruptedException { + // TODO: support per model timeout + int timeout = Integer.parseInt(Utils.getenv("MODEL_LOADING_TIMEOUT", "240")); + if (!latch.await(timeout, TimeUnit.SECONDS)) { + error = "Worker nodes timed out"; + } + } +} diff --git a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java index 573aae278..5513e930a 100644 --- a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java +++ b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java @@ -57,6 +57,7 @@ public final class ConfigManager { private static final String INFERENCE_ADDRESS = "inference_address"; private static final String MANAGEMENT_ADDRESS = "management_address"; + private static final String CLUSTER_ADDRESS = "cluster_address"; private static final String LOAD_MODELS = "load_models"; private static final String WAIT_MODEL_LOADING = "wait_model_loading"; private static final String ALLOW_MULTI_STATUS = "allow_multi_status"; @@ -236,10 +237,18 @@ private boolean onError(String key) { */ public Connector getConnector(Connector.ConnectorType type) { String binding; - if (type == Connector.ConnectorType.MANAGEMENT) { - binding = prop.getProperty(MANAGEMENT_ADDRESS, "http://127.0.0.1:8080"); - } else { - binding = prop.getProperty(INFERENCE_ADDRESS, "http://127.0.0.1:8080"); + switch (type) { + case MANAGEMENT: + binding = prop.getProperty(MANAGEMENT_ADDRESS, "http://127.0.0.1:8080"); + break; + case CLUSTER: + binding = prop.getProperty(CLUSTER_ADDRESS, "http://127.0.0.1:8888"); + break; + case INFERENCE: + case BOTH: + default: + binding = prop.getProperty(INFERENCE_ADDRESS, "http://127.0.0.1:8080"); + break; } return Connector.parse(binding, type); } diff --git a/serving/src/main/java/ai/djl/serving/util/Connector.java b/serving/src/main/java/ai/djl/serving/util/Connector.java index c797a590b..ffe4346a4 100644 --- a/serving/src/main/java/ai/djl/serving/util/Connector.java +++ b/serving/src/main/java/ai/djl/serving/util/Connector.java @@ -257,6 +257,7 @@ public String toString() { /** An enum represents type of connector. */ public enum ConnectorType { + CLUSTER, INFERENCE, MANAGEMENT, BOTH