[serving] Download model while initialize multi-node cluster #2198

Merged: 1 commit, Jul 19, 2024
[serving] Download model while initialize multi-node cluster
frankfliu committed Jul 19, 2024
commit f602919c1fd83d984316abe419bdcec291193d52
226 changes: 23 additions & 203 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -13,28 +13,26 @@
package ai.djl.serving;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.metric.Dimension;
import ai.djl.metric.Metric;
import ai.djl.metric.Unit;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.FilenameUtils;
import ai.djl.serving.http.ServerStartupException;
import ai.djl.serving.models.ModelManager;
import ai.djl.serving.plugins.DependencyManager;
import ai.djl.serving.plugins.FolderScanPluginManager;
import ai.djl.serving.util.ClusterConfig;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.Connector;
import ai.djl.serving.util.ModelStore;
import ai.djl.serving.util.ServerGroups;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkerPoolConfig;
import ai.djl.serving.workflow.BadWorkflowException;
import ai.djl.serving.workflow.Workflow;
import ai.djl.serving.workflow.WorkflowDefinition;
import ai.djl.util.RandomUtils;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;

import io.netty.bootstrap.ServerBootstrap;
@@ -55,37 +53,21 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedWriter;
import java.io.IOException;
import java.lang.management.MemoryUsage;
import java.net.MalformedURLException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** The main entry point for model server. */
public class ModelServer {

private static final Logger logger = LoggerFactory.getLogger(ModelServer.class);
private static final Logger SERVER_METRIC = LoggerFactory.getLogger("server_metric");
private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[?([^?]+?)]?=)?(.+)");

private ServerGroups serverGroups;
private List<ChannelFuture> futures = new ArrayList<>(2);
@@ -204,11 +186,16 @@ public List<ChannelFuture> start()

pluginManager.loadPlugins(true);

initMultiNode();

try {
initModelStore();
} catch (BadWorkflowException | CompletionException e) {
ModelStore modelStore = ModelStore.getInstance();
modelStore.initialize();

List<Workflow> workflows = modelStore.getWorkflows();

initMultiNode(workflows);

loadModels(workflows);
} catch (BadWorkflowException | ModelException | CompletionException e) {
throw new ServerStartupException(
"Failed to initialize startup models and workflows", e);
}
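
After this change, workflow resolution happens once, up front: ModelStore turns the configured model store into Workflow objects, and the same list feeds both the multi-node download step and the local load step. A sketch of the ModelStore surface this diff relies on, inferred only from the call sites above (the real class in ai.djl.serving.util may expose more members and different exceptions):

import java.io.IOException;
import java.util.List;

import ai.djl.ModelException;
import ai.djl.serving.workflow.Workflow;

// Inferred from the call sites in start(); illustrative only.
interface ModelStoreSurface {
    // Resolves configured model URLs (model store directory, HF_MODEL_ID,
    // etc.) into workflows; this is the logic deleted from ModelServer below.
    void initialize() throws IOException, ModelException;

    // Returns the workflows resolved by initialize(); consumed by both
    // initMultiNode() and loadModels().
    List<Workflow> getWorkflows();
}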
@@ -276,11 +263,12 @@ public void stop() {
serverGroups.reset();
}

private void initMultiNode()
private void initMultiNode(List<Workflow> workflows)
throws GeneralSecurityException,
IOException,
InterruptedException,
ServerStartupException {
ServerStartupException,
ModelException {
ClusterConfig cc = ClusterConfig.getInstance();
int clusterSize = cc.getClusterSize();
if (clusterSize > 1) {
@@ -293,7 +281,12 @@ private void initMultiNode()

ChannelFuture future = initializeServer(multiNodeConnector, serverGroup, workerGroup);

// start download model here
// download the models
for (Workflow workflow : workflows) {
for (WorkerPoolConfig<Input, Output> model : workflow.getWpcs()) {
model.initialize();
}
}
cc.countDown();

logger.info("Waiting for all worker nodes ready ...");
@@ -369,116 +362,9 @@ private ChannelFuture initializeServer(
return f;
}

private void initModelStore() throws IOException, BadWorkflowException {
Set<String> startupModels = ModelManager.getInstance().getStartupWorkflows();

String loadModels = configManager.getLoadModels();
Path modelStore = configManager.getModelStore();
if (loadModels == null || loadModels.isEmpty()) {
loadModels = "ALL";
}

private void loadModels(List<Workflow> workflows) {
ModelManager modelManager = ModelManager.getInstance();
Set<String> urls = new HashSet<>();
if ("NONE".equalsIgnoreCase(loadModels)) {
// to disable load all models from model store
return;
} else if ("ALL".equalsIgnoreCase(loadModels)) {
if (modelStore == null) {
logger.warn("Model store is not configured.");
return;
}

if (Files.isDirectory(modelStore)) {
// contains only directory or archive files
boolean isMultiModelsDirectory =
Files.list(modelStore)
.filter(p -> !p.getFileName().toString().startsWith("."))
.allMatch(
p ->
Files.isDirectory(p)
|| FilenameUtils.isArchiveFile(
p.toString()));

if (isMultiModelsDirectory) {
// Check folders to see if they can be models as well
try (Stream<Path> stream = Files.list(modelStore)) {
urls.addAll(
stream.map(this::mapModelUrl)
.filter(Objects::nonNull)
.collect(Collectors.toList()));
}
} else {
// Check if root model store folder contains a model
String url = mapModelUrl(modelStore);
if (url != null) {
urls.add(url);
}
}
} else {
logger.warn("Model store path is not found: {}", modelStore);
}
} else {
String[] modelsUrls = loadModels.split("[, ]+");
urls.addAll(Arrays.asList(modelsUrls));
}

String huggingFaceModelId = Utils.getEnvOrSystemProperty("HF_MODEL_ID");
if (huggingFaceModelId != null) {
urls.add(createHuggingFaceModel(huggingFaceModelId));
}

for (String url : urls) {
logger.info("Initializing model: {}", url);
Matcher matcher = MODEL_STORE_PATTERN.matcher(url);
if (!matcher.matches()) {
throw new AssertionError("Invalid model store url: " + url);
}
String endpoint = matcher.group(2);
String modelUrl = matcher.group(3);
String version = null;
String engineName = null;
String deviceMapping = null;
String modelName = null;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
modelName = tokens[0];
if (tokens.length > 1) {
version = tokens[1].isEmpty() ? null : tokens[1];
}
if (tokens.length > 2) {
engineName = tokens[2].isEmpty() ? null : tokens[2];
}
if (tokens.length > 3) {
deviceMapping = tokens[3];
}
}

Workflow workflow;
URI uri = WorkflowDefinition.toWorkflowUri(modelUrl);
if (uri != null) {
workflow = WorkflowDefinition.parse(modelName, uri).toWorkflow();
} else {
if (modelName == null) {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}
ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
modelName,
modelUrl,
version,
engineName,
deviceMapping,
Input.class,
Output.class,
-1,
-1,
-1,
-1,
-1,
-1);
workflow = new Workflow(modelInfo);
}
for (Workflow workflow : workflows) {
CompletableFuture<Void> f = modelManager.registerWorkflow(workflow);
f.exceptionally(
t -> {
@@ -499,33 +385,6 @@ private void initModelStore() throws IOException, BadWorkflowException {
if (configManager.waitModelLoading()) {
f.join();
}
startupModels.add(modelName);
}
}

String mapModelUrl(Path path) {
try {
if (!Files.exists(path)
|| Files.isHidden(path)
|| (!Files.isDirectory(path)
&& !FilenameUtils.isArchiveFile(path.toString()))) {
return null;
}

if (Files.list(path).findFirst().isEmpty()) {
return null;
}

path = Utils.getNestedModelDir(path);
String url = path.toUri().toURL().toString();
String modelName = ModelInfo.inferModelNameFromUrl(url);
logger.info("Found model {}={}", modelName, url);
return modelName + '=' + url;
} catch (MalformedURLException e) {
throw new AssertionError("Invalid path: " + path, e);
} catch (IOException e) {
logger.warn("Failed to access file: {}", path, e);
return null;
}
}

@@ -535,43 +394,4 @@ private static void printHelp(String msg, Options options) {
formatter.setWidth(120);
formatter.printHelp(msg, options);
}

private String createHuggingFaceModel(String modelId) throws IOException {
if (modelId.startsWith("djl://") || modelId.startsWith("s3://")) {
return modelId;
}
Path path = Paths.get(modelId);
if (Files.exists(path)) {
// modelId point to a local file
return mapModelUrl(path);
}

// TODO: Download the full model from HF
String hash = Utils.hash(modelId);
String downloadDir = Utils.getenv("SERVING_DOWNLOAD_DIR", null);
Path parent = downloadDir == null ? Utils.getCacheDir() : Paths.get(downloadDir);
Path huggingFaceModelDir = parent.resolve(hash);
String modelName = modelId.replaceAll("(\\W|^_)", "_");
if (Files.exists(huggingFaceModelDir)) {
logger.warn("HuggingFace Model {} already exists, use random model name", modelId);
return modelName + '_' + RandomUtils.nextInt() + '=' + huggingFaceModelDir;
}
String huggingFaceModelRevision = Utils.getEnvOrSystemProperty("HF_REVISION");
Properties huggingFaceProperties = new Properties();
huggingFaceProperties.put("option.model_id", modelId);
if (huggingFaceModelRevision != null) {
huggingFaceProperties.put("option.revision", huggingFaceModelRevision);
}
String task = Utils.getEnvOrSystemProperty("HF_TASK");
if (task != null) {
huggingFaceProperties.put("option.task", task);
}
Files.createDirectories(huggingFaceModelDir);
Path propertiesFile = huggingFaceModelDir.resolve("serving.properties");
try (BufferedWriter writer = Files.newBufferedWriter(propertiesFile)) {
huggingFaceProperties.store(writer, null);
}
logger.debug("Created serving.properties for model at path {}", propertiesFile);
return modelName + '=' + huggingFaceModelDir;
}
}
serving/src/main/java/ai/djl/serving/http/ClusterRequestHandler.java
@@ -13,8 +13,14 @@
package ai.djl.serving.http;

import ai.djl.ModelException;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.serving.util.ClusterConfig;
import ai.djl.serving.util.ModelStore;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkerPoolConfig;
import ai.djl.serving.workflow.Workflow;
import ai.djl.util.Utils;

import io.netty.channel.ChannelHandlerContext;
@@ -30,6 +36,8 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** A class handling inbound HTTP requests for the cluster management API. */
public class ClusterRequestHandler extends HttpRequestHandler {
@@ -65,6 +73,18 @@ protected void handleRequest(
}
NettyUtils.sendFile(ctx, file, false);
return;
case "models":
ModelStore modelStore = ModelStore.getInstance();
List<Workflow> workflows = modelStore.getWorkflows();
Map<String, String> map = new ConcurrentHashMap<>();
for (Workflow workflow : workflows) {
for (WorkerPoolConfig<Input, Output> wpc : workflow.getWpcs()) {
ModelInfo<Input, Output> model = (ModelInfo<Input, Output>) wpc;
map.put(model.getId(), model.getModelUrl());
}
}
NettyUtils.sendJsonResponse(ctx, map);
return;
case "status":
List<String> messages = decoder.parameters().get("message");
if (messages.size() != 1) {
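
Taken together, the two files let a worker node discover which models the leader resolved: ModelServer downloads them during multi-node initialization, and ClusterRequestHandler's new "models" case serves an id-to-URL map as JSON. The handler's mount path and port are not shown in this diff; assuming a hypothetical http://leader:8080/cluster/models endpoint, a worker could fetch the map with the JDK's HttpClient (Java 11+):

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Hypothetical client for the new "models" case above. The host, port, and
// /cluster/models path are assumptions, not confirmed by this diff.
public final class ClusterModelsClient {
    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://leader:8080/cluster/models"))
                .GET()
                .build();
        HttpResponse<String> response =
                client.send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.body()); // e.g. {"mymodel":"s3://bucket/mymodel"}
    }
}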