Skip to content

Commit

Permalink
[serving] Adds workflow model loading for SageMaker
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 25, 2023
1 parent 8199a6f commit ffbb357
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 127 deletions.
18 changes: 0 additions & 18 deletions serving/src/main/java/ai/djl/serving/Arguments.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ public final class Arguments {
private String configFile;
private String modelStore;
private String[] models;
private String[] workflows;
private boolean help;

/**
Expand All @@ -40,7 +39,6 @@ public Arguments(CommandLine cmd) {
configFile = cmd.getOptionValue("config-file");
modelStore = cmd.getOptionValue("model-store");
models = cmd.getOptionValues("models");
workflows = cmd.getOptionValues("workflows");
help = cmd.hasOption("help");
}

Expand Down Expand Up @@ -74,13 +72,6 @@ public static Options getOptions() {
.argName("MODELS-STORE")
.desc("Model store location where models can be loaded.")
.build());
options.addOption(
Option.builder("w")
.longOpt("workflows")
.hasArgs()
.argName("WORKFLOWS")
.desc("Workflows to be loaded at startup.")
.build());
options.addOption(
Option.builder("i")
.longOpt("install")
Expand Down Expand Up @@ -140,15 +131,6 @@ public String[] getModels() {
return models;
}

/**
 * Returns the workflow urls that were specified on the command line.
 *
 * <p>NOTE(review): the value comes straight from the {@code workflows} command line option and
 * is presumably {@code null} when that option was not supplied; confirm before dereferencing.
 *
 * @return the workflow urls that were specified on the command line
 */
public String[] getWorkflows() {
return workflows;
}

/**
* Returns if the command line has help option.
*
Expand Down
100 changes: 25 additions & 75 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.GeneralSecurityException;
Expand Down Expand Up @@ -186,8 +185,7 @@ public List<ChannelFuture> start()

try {
initModelStore();
initWorkflows();
} catch (URISyntaxException | BadWorkflowException e) {
} catch (BadWorkflowException e) {
throw new ServerStartupException(
"Failed to initialize startup models and workflows", e);
}
Expand Down Expand Up @@ -298,7 +296,7 @@ private ChannelFuture initializeServer(
return f;
}

private void initModelStore() throws IOException {
private void initModelStore() throws IOException, BadWorkflowException {
Set<String> startupModels = ModelManager.getInstance().getStartupWorkflows();

String loadModels = configManager.getLoadModels();
Expand Down Expand Up @@ -372,22 +370,28 @@ private void initModelStore() throws IOException {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}

ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
modelName,
modelUrl,
version,
engineName,
deviceMapping,
Input.class,
Output.class,
-1,
-1,
-1,
-1,
-1,
-1);
Workflow workflow = new Workflow(modelInfo);
Workflow workflow;
if (WorkflowDefinition.isWorkflow(modelUrl)) {
URI uri = URI.create(modelUrl);
workflow = WorkflowDefinition.parse(uri, uri.toURL().openStream()).toWorkflow();
} else {
ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
modelName,
modelUrl,
version,
engineName,
deviceMapping,
Input.class,
Output.class,
-1,
-1,
-1,
-1,
-1,
-1);
workflow = new Workflow(modelInfo);
}
CompletableFuture<Void> f =
modelManager
.registerWorkflow(workflow)
Expand Down Expand Up @@ -416,60 +420,6 @@ private void initModelStore() throws IOException {
}
}

/**
 * Registers the workflows configured to be loaded at startup.
 *
 * <p>The value of {@code configManager.getLoadWorkflows()} is split on commas and spaces. Each
 * entry must match {@code MODEL_STORE_PATTERN}: group 2 is an optional endpoint whose first
 * {@code ':'}-separated token is used as the workflow name, and group 3 is the url of the
 * workflow definition. When no endpoint is given the name is inferred from the url.
 *
 * <p>If a registration fails, the error is logged and the server is stopped after a short
 * delay. When {@code configManager.waitModelLoading()} is true, this method blocks until each
 * registration has finished.
 *
 * @throws IOException if a workflow definition cannot be opened or read
 * @throws URISyntaxException if a workflow url cannot be converted to a {@code URI}
 * @throws BadWorkflowException if a workflow definition is invalid
 */
private void initWorkflows() throws IOException, URISyntaxException, BadWorkflowException {
    Set<String> startupWorkflows = ModelManager.getInstance().getStartupWorkflows();
    String loadWorkflows = configManager.getLoadWorkflows();
    if (loadWorkflows == null || loadWorkflows.isEmpty()) {
        return;
    }

    ModelManager modelManager = ModelManager.getInstance();
    String[] urls = loadWorkflows.split("[, ]+");

    for (String url : urls) {
        logger.info("Initializing workflow: {}", url);
        Matcher matcher = MODEL_STORE_PATTERN.matcher(url);
        if (!matcher.matches()) {
            throw new AssertionError("Invalid model store url: " + url);
        }
        String endpoint = matcher.group(2);
        String workflowUrlString = matcher.group(3);
        String workflowName;
        if (endpoint != null) {
            String[] tokens = endpoint.split(":", -1);
            workflowName = tokens[0];
        } else {
            workflowName = ModelInfo.inferModelNameFromUrl(workflowUrlString);
        }

        URL workflowUrl = new URL(workflowUrlString);
        Workflow workflow;
        // Close the stream opened here ourselves; this method must not rely on the parser
        // to release it, and a second close is harmless if the parser already did.
        try (java.io.InputStream is = workflowUrl.openStream()) {
            workflow = WorkflowDefinition.parse(workflowUrl.toURI(), is).toWorkflow();
        }

        CompletableFuture<Void> f =
                modelManager
                        .registerWorkflow(workflow)
                        .exceptionally(
                                t -> {
                                    logger.error("Failed register workflow", t);
                                    // delay 3 seconds, allows REST API to send PING
                                    // response (health check)
                                    try {
                                        Thread.sleep(3000);
                                    } catch (InterruptedException ignore) {
                                        // ignore
                                    }
                                    stop();
                                    return null;
                                });
        if (configManager.waitModelLoading()) {
            f.join();
        }
        startupWorkflows.add(workflowName);
    }
}

String mapModelUrl(Path path) {
try {
logger.info("Found file in model_store: {}", path);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.netty.util.CharsetUtil;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
Expand Down Expand Up @@ -190,22 +191,33 @@ private void handleRegisterModel(
req = new LoadModelRequest(decoder);
}

ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
req.getModelName(),
req.getModelUrl(),
req.getVersion(),
req.getEngineName(),
req.getDeviceName(),
Input.class,
Output.class,
req.getJobQueueSize(),
req.getMaxIdleSeconds(),
req.getMaxBatchDelayMillis(),
req.getBatchSize(),
req.getMinWorkers(),
req.getMaxWorkers());
Workflow workflow = new Workflow(modelInfo);
Workflow workflow;
if (WorkflowDefinition.isWorkflow(req.getModelUrl())) {
try {
URI uri = URI.create(req.getModelUrl());
workflow = WorkflowDefinition.parse(uri, uri.toURL().openStream()).toWorkflow();
} catch (IOException | BadWorkflowException e) {
NettyUtils.sendError(ctx, e.getCause());
return;
}
} else {
ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
req.getModelName(),
req.getModelUrl(),
req.getVersion(),
req.getEngineName(),
req.getDeviceName(),
Input.class,
Output.class,
req.getJobQueueSize(),
req.getMaxIdleSeconds(),
req.getMaxBatchDelayMillis(),
req.getBatchSize(),
req.getMinWorkers(),
req.getMaxWorkers());
workflow = new Workflow(modelInfo);
}
final ModelManager modelManager = ModelManager.getInstance();
CompletableFuture<Void> f =
modelManager
Expand Down
16 changes: 0 additions & 16 deletions serving/src/main/java/ai/djl/serving/util/ConfigManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ public final class ConfigManager {
private static final String INFERENCE_ADDRESS = "inference_address";
private static final String MANAGEMENT_ADDRESS = "management_address";
private static final String LOAD_MODELS = "load_models";
private static final String LOAD_WORKFLOWS = "load_workflows";
private static final String WAIT_MODEL_LOADING = "wait_model_loading";
private static final String ALLOW_MULTI_STATUS = "allow_multi_status";
private static final String NUMBER_OF_NETTY_THREADS = "number_of_netty_threads";
Expand Down Expand Up @@ -110,10 +109,6 @@ private ConfigManager(Arguments args) {
if (models != null) {
prop.setProperty(LOAD_MODELS, String.join(",", models));
}
String[] workflows = args.getWorkflows();
if (workflows != null) {
prop.setProperty(LOAD_WORKFLOWS, String.join(",", workflows));
}
for (Map.Entry<String, String> env : Utils.getenv().entrySet()) {
String key = env.getKey();
if (key.startsWith("SERVING_")) {
Expand Down Expand Up @@ -266,15 +261,6 @@ public String getLoadModels() {
return prop.getProperty(LOAD_MODELS);
}

/**
 * Returns the workflow urls that are to be loaded at startup.
 *
 * <p>The value is the raw {@code load_workflows} property, or {@code null} if that property is
 * not set.
 *
 * @return the workflow urls that are to be loaded at startup
 */
public String getLoadWorkflows() {
return prop.getProperty(LOAD_WORKFLOWS);
}

/**
* Returns the devices the default model will be loaded on at startup.
*
Expand Down Expand Up @@ -440,8 +426,6 @@ public String dumpConfigurations() {
+ (getModelStore() == null ? "N/A" : getModelStore())
+ "\nInitial Models: "
+ (getLoadModels() == null ? "N/A" : getLoadModels())
+ "\nInitial Workflows: "
+ (getLoadWorkflows() == null ? "N/A" : getLoadWorkflows())
+ "\nNetty threads: "
+ getNettyThreads()
+ "\nMaximum Request Size: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -137,6 +138,34 @@ public static WorkflowDefinition parse(URI uri, Reader input) {
}
}

/**
 * Returns true if the url points to a workflow definition file.
 *
 * <p>A url is considered a workflow when it ends with {@code .json}, {@code .yml} or {@code
 * .yaml}, or when it refers to a local directory that contains a {@code workflow.json}, {@code
 * workflow.yml} or {@code workflow.yaml} file. Only {@code file:} urls and plain filesystem
 * paths are checked against the local filesystem; the path of a remote url is never probed
 * locally.
 *
 * @param url the workflow url, {@code null} and empty values are not workflows
 * @return true if the url points to a workflow definition file
 */
public static boolean isWorkflow(String url) {
    if (url == null || url.isEmpty()) {
        return false;
    }
    if (url.endsWith(".json") || url.endsWith(".yml") || url.endsWith(".yaml")) {
        return true;
    }

    String localPath;
    try {
        URI uri = URI.create(url);
        String scheme = uri.getScheme();
        if (scheme != null && !"file".equalsIgnoreCase(scheme)) {
            // A remote location: its path says nothing about the local filesystem.
            return false;
        }
        localPath = uri.getPath();
        if (localPath == null) {
            // Opaque uri (e.g. "file:relative/dir") has no path component.
            localPath = uri.getSchemeSpecificPart();
        }
        if (localPath.startsWith("/") && System.getProperty("os.name").startsWith("Win")) {
            // On Windows the uri path of "file:///C:/dir" is "/C:/dir".
            localPath = localPath.substring(1);
        }
    } catch (IllegalArgumentException e) {
        // Not a legal URI (e.g. a path containing spaces): use it as a raw filesystem path.
        localPath = url;
    }

    try {
        Path path = Paths.get(localPath);
        return Files.isRegularFile(path.resolve("workflow.json"))
                || Files.isRegularFile(path.resolve("workflow.yml"))
                || Files.isRegularFile(path.resolve("workflow.yaml"));
    } catch (IllegalArgumentException e) {
        // InvalidPathException: the string cannot name a local directory at all.
        return false;
    }
}

/**
* Converts the {@link WorkflowDefinition} into a workflow.
*
Expand Down
3 changes: 1 addition & 2 deletions serving/src/test/resources/workflow.config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
inference_address=https://127.0.0.1:8443
management_address=https://127.0.0.1:8443
model_store=build/models
load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:*]=https://resources.djl.ai/test-models/mlp.tar.gz
load_workflows=https://resources.djl.ai/test-models/basic-serving-workflow.json
load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:*]=https://resources.djl.ai/test-models/mlp.tar.gz,https://resources.djl.ai/test-models/basic-serving-workflow.json
private_key_file=src/test/resources/key.pem
certificate_file=src/test/resources/certs.pem
max_request_size=10485760

0 comments on commit ffbb357

Please sign in to comment.