diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java
index a5951e9d0..cae0ddf77 100644
--- a/serving/src/main/java/ai/djl/serving/ModelServer.java
+++ b/serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -378,7 +378,7 @@ private void initModelStore() throws IOException {
             }
             Pair<String, Path> pair = ModelInfo.downloadModel(modelUrl);
             if (engineName == null) {
-                engineName = inferEngine(pair.getValue(), pair.getKey());
+                engineName = ModelInfo.inferEngine(pair.getValue(), pair.getKey());
                 if (engineName == null) {
                     logger.warn("Failed to infer engine, skip url: {}", url);
                     continue;
@@ -460,7 +460,7 @@ private void initWorkflows() throws IOException, URISyntaxException, BadWorkflow
             workflowName = tokens[0];
             if (tokens.length > 1) {
                 Pair<String, Path> pair = ModelInfo.downloadModel(workflowUrlString);
-                String engineName = inferEngine(pair.getValue(), pair.getKey());
+                String engineName = ModelInfo.inferEngine(pair.getValue(), pair.getKey());
                 DependencyManager.getInstance().installEngine(engineName);
                 Engine engine = Engine.getEngine(engineName);
                 devices = parseDevices(tokens[1], engine, pair.getValue());
@@ -518,10 +518,10 @@ String mapModelUrl(Path path) {
         String modelName = ModelInfo.inferModelNameFromUrl(url);
         String engine;
         if (Files.isDirectory(path)) {
-            engine = inferEngine(path, path.toFile().getName());
+            engine = ModelInfo.inferEngine(path, path.toFile().getName());
         } else {
             // .zip file
-            engine = inferEngineFromUrl(url);
+            engine = ModelInfo.inferEngineFromUrl(url);
         }
         if (engine == null) {
             return null;
@@ -536,57 +536,6 @@ String mapModelUrl(Path path) {
         }
     }
 
-    private String inferEngineFromUrl(String modelUrl) {
-        try {
-            Pair<String, Path> pair = ModelInfo.downloadModel(modelUrl);
-            return inferEngine(pair.getValue(), pair.getKey());
-        } catch (IOException e) {
-            logger.warn("Failed to extract model: " + modelUrl, e);
-            return null;
-        }
-    }
-
-    private String inferEngine(Path modelDir, String modelName) {
-        modelDir = Utils.getNestedModelDir(modelDir);
-
-        Properties prop = ModelInfo.getServingProperties(modelDir);
-        String engine = prop.getProperty("engine");
-        if (engine != null) {
-            return engine;
-        }
-
-        modelName = prop.getProperty("option.modelName", modelName);
-        if (Files.isDirectory(modelDir.resolve("MAR-INF"))
-                || Files.isRegularFile(modelDir.resolve("model.py"))
-                || Files.isRegularFile(modelDir.resolve(modelName + ".py"))) {
-            // MMS/TorchServe
-            return "Python";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".pt"))
-                || Files.isRegularFile(modelDir.resolve("model.pt"))) {
-            return "PyTorch";
-        } else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) {
-            return "TensorFlow";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + "-symbol.json"))) {
-            return "MXNet";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".onnx"))
-                || Files.isRegularFile(modelDir.resolve("model.onnx"))) {
-            return "OnnxRuntime";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".trt"))
-                || Files.isRegularFile(modelDir.resolve(modelName + ".uff"))) {
-            return "TensorRT";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".tflite"))) {
-            return "TFLite";
-        } else if (Files.isRegularFile(modelDir.resolve("model"))
-                || Files.isRegularFile(modelDir.resolve("__model__"))
-                || Files.isRegularFile(modelDir.resolve("inference.pdmodel"))) {
-            return "PaddlePaddle";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".json"))) {
-            return "XGBoost";
-        }
-        logger.warn("Failed to detect engine of the model: {}", modelDir);
-        return null;
-    }
-
     private String[] parseDevices(String devices, Engine engine, Path modelDir) {
         if ("*".equals(devices)) {
             int gpuCount = engine.getGpuCount();
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
index 5738bdda8..1b46e8e18 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
@@ -461,6 +461,70 @@ public static String inferModelNameFromUrl(String url) {
         return modelName;
     }
 
+    /**
+     * Infers engine name from model URL.
+     *
+     * @param modelUrl the model URL
+     * @return the engine name
+     */
+    public static String inferEngineFromUrl(String modelUrl) {
+        try {
+            Pair<String, Path> pair = ModelInfo.downloadModel(modelUrl);
+            return ModelInfo.inferEngine(pair.getValue(), pair.getKey());
+        } catch (IOException e) {
+            logger.warn("Failed to extract model: " + modelUrl, e);
+            return null;
+        }
+    }
+
+    /**
+     * Infers engine name from model directory.
+     *
+     * @param modelDir the model directory
+     * @param modelName the model name
+     * @return the engine name
+     */
+    public static String inferEngine(Path modelDir, String modelName) {
+        modelDir = Utils.getNestedModelDir(modelDir);
+
+        Properties prop = getServingProperties(modelDir);
+        String engine = prop.getProperty("engine");
+        if (engine != null) {
+            return engine;
+        }
+
+        modelName = prop.getProperty("option.modelName", modelName);
+        if (Files.isDirectory(modelDir.resolve("MAR-INF"))
+                || Files.isRegularFile(modelDir.resolve("model.py"))
+                || Files.isRegularFile(modelDir.resolve(modelName + ".py"))) {
+            // MMS/TorchServe
+            return "Python";
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".pt"))
+                || Files.isRegularFile(modelDir.resolve("model.pt"))) {
+            return "PyTorch";
+        } else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) {
+            return "TensorFlow";
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + "-symbol.json"))) {
+            return "MXNet";
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".onnx"))
+                || Files.isRegularFile(modelDir.resolve("model.onnx"))) {
+            return "OnnxRuntime";
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".trt"))
+                || Files.isRegularFile(modelDir.resolve(modelName + ".uff"))) {
+            return "TensorRT";
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".tflite"))) {
+            return "TFLite";
+        } else if (Files.isRegularFile(modelDir.resolve("model"))
+                || Files.isRegularFile(modelDir.resolve("__model__"))
+                || Files.isRegularFile(modelDir.resolve("inference.pdmodel"))) {
+            return "PaddlePaddle";
+        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".json"))) {
+            return "XGBoost";
+        }
+        logger.warn("Failed to detect engine of the model: {}", modelDir);
+        return null;
+    }
+
     /**
      * Returns the default device for this model if device is null.
      *
@@ -468,6 +532,9 @@ public static String inferModelNameFromUrl(String url) {
      * @return a non-null device
      */
    public Device withDefaultDevice(String deviceName) {
+        if (engineName == null && modelUrl != null) {
+            engineName = inferEngineFromUrl(modelUrl);
+        }
         Engine engine = engineName != null ? Engine.getEngine(engineName) : Engine.getInstance();
         if (deviceName == null) {
             return engine.defaultDevice();