Refactor workflow parsing
1. Removed global preferences for workflow definition
2. Updated README
frankfliu committed Apr 25, 2023
1 parent 1ad4b15 commit fcf6808
Showing 8 changed files with 30 additions and 103 deletions.
12 changes: 3 additions & 9 deletions serving/docs/configuration.md
@@ -25,7 +25,6 @@ usage: djl-serving [OPTIONS]
-h,--help Print this help.
-m,--models <MODELS> Models to be loaded at startup.
-s,--model-store <MODELS-STORE> Model store location where models can be loaded.
-w,--workflows <WORKFLOWS> Workflows to be loaded at startup.
```

Details about the models, model-store, and workflows can be found in the equivalent configuration properties.
@@ -65,7 +64,7 @@ model_store=build/models

**Load Models**

The `load_models` config property can be used to define a list of models to be loaded.
The `load_models` config property can be used to define a list of models (or workflows) to be loaded.
The list should be defined as a comma separated list of urls to load models from.

Each model can be defined either as a URL directly or optionally with prepended endpoint data like `[EndpointData]=modelUrl`.
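An illustrative entry in that form, reusing the model URL already shown on this page; the four tokens appear to correspond to name, version, engine, and device, matching the variables parsed in `ModelServer.initModelStore` further down:

```properties
# Hypothetical endpoint-data entry: [name:version:engine:device]=modelUrl
load_models=[mlp:v1:MXNet:*]=https://resources.djl.ai/test-models/mlp.tar.gz
```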
@@ -89,15 +88,10 @@ load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:*]=htt

**Workflows**

Use the `load_workflows` config property to define initial workflows that should be loaded on startup.
It should be a comma separated list of workflow URLs.

You can also specify the device that the model should be loaded on by using `modelUrl:deviceNames`.
The `deviceNames` matches the format used in the `load_models` property described above.
An example is shown below:
Use the `load_models` config property to define initial workflows that should be loaded on startup.

```properties
load_workflows=https://resources.djl.ai/test-models/basic-serving-workflow.json
load_models=https://resources.djl.ai/test-models/basic-serving-workflow.json
```

View the [workflow documentation](workflows.md) to see more information about workflows and their configuration format.
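With the dedicated `load_workflows` property no longer documented here, a single `load_models` list appears to cover both cases; a minimal illustrative sketch combining the two URLs shown above:

```properties
# Hypothetical mixed list: one model archive plus one workflow definition
load_models=https://resources.djl.ai/test-models/mlp.tar.gz,https://resources.djl.ai/test-models/basic-serving-workflow.json
```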
9 changes: 0 additions & 9 deletions serving/docs/workflows.md
@@ -42,15 +42,6 @@ As the system is built in YAML, the overall structure is a configuration object
name: "MyWorkflow"
version: "1.2.0"
# Default model properties based on https://github.com/pytorch/serve/blob/master/docs/workflows.md#workflow-model-properties
# optional
minWorkers: 1
maxWorkers: 4
batchSize: 3
maxBatchDelayMillis: 5000
retryAttempts: 3
timeout: 5000
# Defined below
models: ...
functions: ...
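Since the global block above is removed without a top-level replacement, such tuning presumably now lives on the individual model entries. A hypothetical sketch — the field names and values are assumed from the per-model properties exercised by `localPerf.json` and `ModelInfo`, not taken from the docs:

```yaml
models:
  my_model:                    # hypothetical model id
    modelUrls: "https://resources.djl.ai/test-models/mlp.tar.gz"
    queueSize: 101             # assumed per-model counterparts of the
    maxIdleSeconds: 61         # removed global preferences
    maxBatchDelayMillis: 301
    batchSize: 2
```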
9 changes: 5 additions & 4 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -353,7 +353,7 @@ private void initModelStore() throws IOException, BadWorkflowException {
String version = null;
String engineName = null;
String deviceMapping = null;
String modelName;
String modelName = null;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
modelName = tokens[0];
@@ -366,15 +366,16 @@ private void initModelStore() throws IOException, BadWorkflowException {
if (tokens.length > 3) {
deviceMapping = tokens[3];
}
} else {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}

Workflow workflow;
URI uri = WorkflowDefinition.toWorkflowUri(modelUrl);
if (uri != null) {
workflow = WorkflowDefinition.parse(uri, uri.toURL().openStream()).toWorkflow();
workflow = WorkflowDefinition.parse(modelName, uri).toWorkflow();
} else {
if (modelName == null) {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}
ModelInfo<Input, Output> modelInfo =
new ModelInfo<>(
modelName,
@@ -36,8 +36,6 @@

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -195,7 +193,7 @@ private void handleRegisterModel(
URI uri = WorkflowDefinition.toWorkflowUri(req.getModelUrl());
if (uri != null) {
try {
workflow = WorkflowDefinition.parse(uri, uri.toURL().openStream()).toWorkflow();
workflow = WorkflowDefinition.parse(req.getModelName(), uri).toWorkflow();
} catch (IOException | BadWorkflowException e) {
NettyUtils.sendError(ctx, e.getCause());
return;
@@ -252,9 +250,8 @@ private void handleRegisterWorkflow(
NettyUtils.getParameter(decoder, LoadModelRequest.SYNCHRONOUS, "true"));

try {
URL url = new URL(workflowUrl);
Workflow workflow =
WorkflowDefinition.parse(url.toURI(), url.openStream()).toWorkflow();
URI uri = URI.create(workflowUrl);
Workflow workflow = WorkflowDefinition.parse(null, uri).toWorkflow();
String workflowName = workflow.getName();

final ModelManager modelManager = ModelManager.getInstance();
@@ -275,8 +272,7 @@ NettyUtils.sendJsonResponse(
NettyUtils.sendJsonResponse(
ctx, new StatusResponse(msg), HttpResponseStatus.ACCEPTED);
}

} catch (URISyntaxException | IOException | BadWorkflowException e) {
} catch (IOException | BadWorkflowException e) {
NettyUtils.sendError(ctx, e.getCause());
}
}
@@ -14,9 +14,9 @@

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.FilenameUtils;
import ai.djl.serving.util.MutableClassLoader;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.util.WlmConfigManager;
import ai.djl.serving.workflow.WorkflowExpression.Item;
import ai.djl.serving.workflow.function.WorkflowFunction;
import ai.djl.util.ClassLoaderUtils;
@@ -70,11 +70,6 @@ public class WorkflowDefinition {
@SerializedName("configs")
Map<String, Map<String, Object>> configs;

int queueSize;
int maxIdleSeconds;
int maxBatchDelayMillis;
int batchSize;

public static final Gson GSON =
JsonUtils.builder()
.registerTypeAdapter(ModelInfo.class, new ModelDefinitionDeserializer())
@@ -90,30 +85,31 @@ public class WorkflowDefinition {
* @throws IOException if it fails to load the file for parsing
*/
public static WorkflowDefinition parse(Path path) throws IOException {
return parse(path.toUri(), Files.newBufferedReader(path));
return parse(null, path.toUri());
}

/**
* Parses a new {@link WorkflowDefinition} from an input stream.
*
* @param name the workflow name
* @param uri the uri of the file
* @param input the input
* @return the parsed {@link WorkflowDefinition}
* @throws IOException if read from uri failed
*/
public static WorkflowDefinition parse(URI uri, InputStream input) {
return parse(uri, new InputStreamReader(input, StandardCharsets.UTF_8));
public static WorkflowDefinition parse(String name, URI uri) throws IOException {
String type = FilenameUtils.getFileExtension(Objects.requireNonNull(uri.toString()));
try (InputStream is = uri.toURL().openStream();
Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
WorkflowDefinition wd = parse(type, reader);
if (name != null) {
wd.name = name;
}
return wd;
}
}

/**
* Parses a new {@link WorkflowDefinition} from a reader.
*
* @param uri the uri of the file
* @param input the input
* @return the parsed {@link WorkflowDefinition}
*/
public static WorkflowDefinition parse(URI uri, Reader input) {
String fileName = Objects.requireNonNull(uri.toString());
if (fileName.endsWith(".yml") || fileName.endsWith(".yaml")) {
private static WorkflowDefinition parse(String type, Reader input) {
if ("yml".equalsIgnoreCase(type) || "yaml".equalsIgnoreCase(type)) {
try {
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
Class<?> clazz = Class.forName("org.yaml.snakeyaml.Yaml", true, cl);
@@ -131,10 +127,10 @@ public static WorkflowDefinition parse(URI uri, Reader input) {
+ " build.gradle.",
e);
}
} else if (fileName.endsWith(".json")) {
} else if ("json".equalsIgnoreCase(type)) {
return GSON.fromJson(input, WorkflowDefinition.class);
} else {
throw new IllegalArgumentException("Unexpected file type in workflow file: " + uri);
throw new IllegalArgumentException("Unexpected file type: " + type);
}
}

@@ -194,23 +190,9 @@ public static URI toWorkflowUri(String link) {
*/
public Workflow toWorkflow() throws BadWorkflowException {
if (models != null) {
WlmConfigManager wlmc = WlmConfigManager.getInstance();
for (Entry<String, ModelInfo<Input, Output>> emd : models.entrySet()) {
ModelInfo<Input, Output> md = emd.getValue();
md.setId(emd.getKey());
md.setQueueSize(firstValid(md.getQueueSize(), queueSize, wlmc.getJobQueueSize()));
md.setMaxIdleSeconds(
firstValid(
md.getMaxIdleSeconds(), maxIdleSeconds, wlmc.getMaxIdleSeconds()));
md.setMaxBatchDelayMillis(
firstValid(
md.getMaxBatchDelayMillis(),
maxBatchDelayMillis,
wlmc.getMaxBatchDelayMillis()));
md.setBatchSize(firstValid(md.getBatchSize(), batchSize, wlmc.getBatchSize()));
if (name == null) {
name = emd.getKey();
}
}
}

@@ -230,15 +212,6 @@ public Workflow toWorkflow() throws BadWorkflowException {
return new Workflow(name, version, models, expressions, configs, loadedFunctions);
}

private int firstValid(int... inputs) {
for (int input : inputs) {
if (input > 0) {
return input;
}
}
return 0;
}

private static final class ModelDefinitionDeserializer
implements JsonDeserializer<ModelInfo<Input, Output>> {

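As a usage reference, here is a minimal sketch of calling the refactored entry point. The `parse(String name, URI uri)` signature, `toWorkflow()`, and `getName()` come from the diff above; the wrapper class, workflow name, and import paths (inferred from the `ai.djl.serving.workflow` package shown in the imports) are assumptions:

```java
import ai.djl.serving.workflow.BadWorkflowException;
import ai.djl.serving.workflow.Workflow;
import ai.djl.serving.workflow.WorkflowDefinition;

import java.io.IOException;
import java.net.URI;

public final class ParseWorkflowSketch {

    public static void main(String[] args) throws IOException, BadWorkflowException {
        // The URI both locates the definition and determines its type by extension (json, yml, yaml)
        URI uri = URI.create("https://resources.djl.ai/test-models/basic-serving-workflow.json");

        // A non-null name overrides the name inside the definition; null keeps the definition's own name
        Workflow workflow = WorkflowDefinition.parse("my-workflow", uri).toWorkflow();
        System.out.println("Loaded workflow: " + workflow.getName());
    }
}
```

Both `ModelServer.initModelStore` and the management handler in this commit follow the same pattern, passing either the endpoint-derived model name or `null`.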
12 changes: 0 additions & 12 deletions serving/src/test/java/ai/djl/serving/WorkflowTest.java
@@ -74,18 +74,6 @@ public void testFunctions() throws IOException, BadWorkflowException {
runWorkflow(workflowFile, zeroInput);
}

@Test
public void testGlobalPerf() throws IOException, BadWorkflowException {
Path workflowFile = Paths.get("src/test/resources/workflows/globalPerf.json");
Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow();
ModelInfo<Input, Output> m = workflow.getModels().stream().findFirst().get();

Assert.assertEquals(m.getQueueSize(), 101);
Assert.assertEquals(m.getMaxIdleSeconds(), 61);
Assert.assertEquals(m.getMaxBatchDelayMillis(), 301);
Assert.assertEquals(m.getBatchSize(), 2);
}

@Test
public void testLocalPerf() throws IOException, BadWorkflowException {
Path workflowFile = Paths.get("src/test/resources/workflows/localPerf.json");
12 changes: 0 additions & 12 deletions serving/src/test/resources/workflows/globalPerf.json

This file was deleted.

4 changes: 0 additions & 4 deletions serving/src/test/resources/workflows/localPerf.json
@@ -1,8 +1,4 @@
{
"queueSize": 101,
"maxIdleSeconds": 61,
"maxBatchDelayMillis": 301,
"batchSize": 2,
"models": {
"m": {
"modelUrls": "https://resources.djl.ai/test-models/mlp.tar.gz",
