Skip to content

Commit

Permalink
[serving] Detects engine to avoid uncessarily download MXNet engine (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Feb 15, 2023
1 parent 20bb6b2 commit 01b3648
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 55 deletions.
59 changes: 4 additions & 55 deletions serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ private void initModelStore() throws IOException {
}
Pair<String, Path> pair = ModelInfo.downloadModel(modelUrl);
if (engineName == null) {
engineName = inferEngine(pair.getValue(), pair.getKey());
engineName = ModelInfo.inferEngine(pair.getValue(), pair.getKey());
if (engineName == null) {
logger.warn("Failed to infer engine, skip url: {}", url);
continue;
Expand Down Expand Up @@ -460,7 +460,7 @@ private void initWorkflows() throws IOException, URISyntaxException, BadWorkflow
workflowName = tokens[0];
if (tokens.length > 1) {
Pair<String, Path> pair = ModelInfo.downloadModel(workflowUrlString);
String engineName = inferEngine(pair.getValue(), pair.getKey());
String engineName = ModelInfo.inferEngine(pair.getValue(), pair.getKey());
DependencyManager.getInstance().installEngine(engineName);
Engine engine = Engine.getEngine(engineName);
devices = parseDevices(tokens[1], engine, pair.getValue());
Expand Down Expand Up @@ -518,10 +518,10 @@ String mapModelUrl(Path path) {
String modelName = ModelInfo.inferModelNameFromUrl(url);
String engine;
if (Files.isDirectory(path)) {
engine = inferEngine(path, path.toFile().getName());
engine = ModelInfo.inferEngine(path, path.toFile().getName());
} else {
// .zip file
engine = inferEngineFromUrl(url);
engine = ModelInfo.inferEngineFromUrl(url);
}
if (engine == null) {
return null;
Expand All @@ -536,57 +536,6 @@ String mapModelUrl(Path path) {
}
}

private String inferEngineFromUrl(String modelUrl) {
try {
Pair<String, Path> pair = ModelInfo.downloadModel(modelUrl);
return inferEngine(pair.getValue(), pair.getKey());
} catch (IOException e) {
logger.warn("Failed to extract model: " + modelUrl, e);
return null;
}
}

private String inferEngine(Path modelDir, String modelName) {
modelDir = Utils.getNestedModelDir(modelDir);

Properties prop = ModelInfo.getServingProperties(modelDir);
String engine = prop.getProperty("engine");
if (engine != null) {
return engine;
}

modelName = prop.getProperty("option.modelName", modelName);
if (Files.isDirectory(modelDir.resolve("MAR-INF"))
|| Files.isRegularFile(modelDir.resolve("model.py"))
|| Files.isRegularFile(modelDir.resolve(modelName + ".py"))) {
// MMS/TorchServe
return "Python";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".pt"))
|| Files.isRegularFile(modelDir.resolve("model.pt"))) {
return "PyTorch";
} else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) {
return "TensorFlow";
} else if (Files.isRegularFile(modelDir.resolve(modelName + "-symbol.json"))) {
return "MXNet";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".onnx"))
|| Files.isRegularFile(modelDir.resolve("model.onnx"))) {
return "OnnxRuntime";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".trt"))
|| Files.isRegularFile(modelDir.resolve(modelName + ".uff"))) {
return "TensorRT";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".tflite"))) {
return "TFLite";
} else if (Files.isRegularFile(modelDir.resolve("model"))
|| Files.isRegularFile(modelDir.resolve("__model__"))
|| Files.isRegularFile(modelDir.resolve("inference.pdmodel"))) {
return "PaddlePaddle";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".json"))) {
return "XGBoost";
}
logger.warn("Failed to detect engine of the model: {}", modelDir);
return null;
}

private String[] parseDevices(String devices, Engine engine, Path modelDir) {
if ("*".equals(devices)) {
int gpuCount = engine.getGpuCount();
Expand Down
67 changes: 67 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,80 @@ public static String inferModelNameFromUrl(String url) {
return modelName;
}

/**
* Infers engine name from model URL.
*
* @param modelUrl the model URL
* @return the engine name
*/
public static String inferEngineFromUrl(String modelUrl) {
try {
Pair<String, Path> pair = ModelInfo.downloadModel(modelUrl);
return ModelInfo.inferEngine(pair.getValue(), pair.getKey());
} catch (IOException e) {
logger.warn("Failed to extract model: " + modelUrl, e);
return null;
}
}

/**
* Infers engine name from model directory.
*
* @param modelDir the model directory
* @param modelName the model name
* @return the engine name
*/
public static String inferEngine(Path modelDir, String modelName) {
modelDir = Utils.getNestedModelDir(modelDir);

Properties prop = getServingProperties(modelDir);
String engine = prop.getProperty("engine");
if (engine != null) {
return engine;
}

modelName = prop.getProperty("option.modelName", modelName);
if (Files.isDirectory(modelDir.resolve("MAR-INF"))
|| Files.isRegularFile(modelDir.resolve("model.py"))
|| Files.isRegularFile(modelDir.resolve(modelName + ".py"))) {
// MMS/TorchServe
return "Python";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".pt"))
|| Files.isRegularFile(modelDir.resolve("model.pt"))) {
return "PyTorch";
} else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) {
return "TensorFlow";
} else if (Files.isRegularFile(modelDir.resolve(modelName + "-symbol.json"))) {
return "MXNet";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".onnx"))
|| Files.isRegularFile(modelDir.resolve("model.onnx"))) {
return "OnnxRuntime";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".trt"))
|| Files.isRegularFile(modelDir.resolve(modelName + ".uff"))) {
return "TensorRT";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".tflite"))) {
return "TFLite";
} else if (Files.isRegularFile(modelDir.resolve("model"))
|| Files.isRegularFile(modelDir.resolve("__model__"))
|| Files.isRegularFile(modelDir.resolve("inference.pdmodel"))) {
return "PaddlePaddle";
} else if (Files.isRegularFile(modelDir.resolve(modelName + ".json"))) {
return "XGBoost";
}
logger.warn("Failed to detect engine of the model: {}", modelDir);
return null;
}

/**
* Returns the default device for this model if device is null.
*
* @param deviceName the device to use if it is not null
* @return a non-null device
*/
public Device withDefaultDevice(String deviceName) {
if (engineName == null && modelUrl != null) {
engineName = inferEngineFromUrl(modelUrl);
}
Engine engine = engineName != null ? Engine.getEngine(engineName) : Engine.getInstance();
if (deviceName == null) {
return engine.defaultDevice();
Expand Down

0 comments on commit 01b3648

Please sign in to comment.