diff --git a/benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java b/benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java
index 9b395a1248..b01ccb2207 100644
--- a/benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java
+++ b/benchmark/src/main/java/ai/djl/benchmark/WlmBenchmark.java
@@ -70,7 +70,8 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration) {
         Device device = Device.fromName(devices[0], engine);
         WorkLoadManager wlm = new WorkLoadManager();
         Criteria criteria = loadModelCriteria(arguments, device);
-        ModelInfo modelInfo = new ModelInfo<>("model", criteria);
+        ModelInfo modelInfo =
+                new ModelInfo<>("model", arguments.getModelUrl(), criteria);
         WorkerPool wp = wlm.registerModel(modelInfo);
         int workersPerDevice = numOfWorkers / devices.length;
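The benchmark change above shows the new three-argument `ModelInfo` constructor: the model URL now travels alongside the `Criteria` so that `ModelInfo.initialize()` can download artifacts on its own. Below is a minimal sketch of that registration flow outside the benchmark; the model URL and the worker counts are illustrative assumptions, not values from this patch.

```java
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager;

public final class WlmRegistrationSketch {

    public static void main(String[] args) {
        // Hypothetical model URL; anything Criteria.optModelUrls() accepts works here.
        String modelUrl = "https://resources.djl.ai/test-models/mlp.zip";
        Criteria<Input, Output> criteria =
                Criteria.builder()
                        .setTypes(Input.class, Output.class)
                        .optModelUrls(modelUrl)
                        .build();

        WorkLoadManager wlm = new WorkLoadManager();
        // The URL is passed alongside the Criteria so the ModelInfo can
        // re-download the artifacts by itself when initialize() runs.
        ModelInfo<Input, Output> modelInfo = new ModelInfo<>("model", modelUrl, criteria);
        // "-1" is the default-device placeholder used elsewhere in this patch.
        wlm.registerModel(modelInfo).initWorkers("-1", 1, 2);
    }
}
```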
diff --git a/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/src/main/java/ai/djl/serving/ModelServer.java
index a50de97e18..67dd3413b1 100644
--- a/serving/src/main/java/ai/djl/serving/ModelServer.java
+++ b/serving/src/main/java/ai/djl/serving/ModelServer.java
@@ -13,7 +13,6 @@
 package ai.djl.serving;
 
 import ai.djl.engine.Engine;
-import ai.djl.engine.EngineException;
 import ai.djl.metric.Dimension;
 import ai.djl.metric.Metric;
 import ai.djl.metric.Unit;
@@ -31,8 +30,6 @@
 import ai.djl.serving.workflow.BadWorkflowException;
 import ai.djl.serving.workflow.Workflow;
 import ai.djl.serving.workflow.WorkflowDefinition;
-import ai.djl.util.NeuronUtils;
-import ai.djl.util.Pair;
 import ai.djl.util.Utils;
 
 import io.netty.bootstrap.ServerBootstrap;
@@ -65,7 +62,6 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
-import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
@@ -73,7 +69,6 @@
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
-import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 /** The main entry point for model server. */
@@ -90,7 +85,6 @@ public class ModelServer {
 
     private ConfigManager configManager;
     private FolderScanPluginManager pluginManager;
-    private DependencyManager dependencyManager;
 
     /**
      * Creates a new {@code ModelServer} instance.
@@ -101,7 +95,6 @@ public ModelServer(ConfigManager configManager) {
         this.configManager = configManager;
         this.pluginManager = new FolderScanPluginManager(configManager);
         serverGroups = new ServerGroups(configManager);
-        dependencyManager = DependencyManager.getInstance();
     }
 
     /**
@@ -361,7 +354,7 @@ private void initModelStore() throws IOException {
             String modelUrl = matcher.group(3);
             String version = null;
             String engineName = null;
-            String deviceMapping = "*";
+            String deviceMapping = null;
             String modelName;
             if (endpoint != null) {
                 String[] tokens = endpoint.split(":", -1);
@@ -378,17 +371,6 @@
             } else {
                 modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
             }
-            Pair pair = ModelInfo.downloadModel(modelUrl);
-            if (engineName == null) {
-                engineName = ModelInfo.inferEngine(pair.getValue(), pair.getKey());
-                if (engineName == null) {
-                    logger.warn("Failed to infer engine, skip url: {}", url);
-                    continue;
-                }
-            }
-            dependencyManager.installEngine(engineName);
-            Engine engine = Engine.getEngine(engineName);
-            String[] devices = parseDevices(deviceMapping, engine, pair.getValue());
 
             ModelInfo modelInfo =
                     new ModelInfo<>(
@@ -396,22 +378,19 @@
                             modelUrl,
                             version,
                             engineName,
+                            deviceMapping,
                             Input.class,
                             Output.class,
                             -1,
                             -1,
                             -1,
+                            -1,
+                            -1,
                             -1);
             Workflow workflow = new Workflow(modelInfo);
             CompletableFuture f =
                     modelManager
                             .registerWorkflow(workflow)
-                            .thenAccept(
-                                    v -> {
-                                        for (String deviceName : devices) {
-                                            modelManager.initWorkers(workflow, deviceName, -1, -1);
-                                        }
-                                    })
                             .exceptionally(
                                     t -> {
                                         logger.error("Failed register workflow", t);
@@ -455,18 +434,10 @@ private void initWorkflows() throws IOException, URISyntaxException, BadWorkflow
         }
         String endpoint = matcher.group(2);
         String workflowUrlString = matcher.group(3);
-        String[] devices = {null};
         String workflowName;
         if (endpoint != null) {
             String[] tokens = endpoint.split(":", -1);
             workflowName = tokens[0];
-            if (tokens.length > 1) {
-                Pair pair = ModelInfo.downloadModel(workflowUrlString);
-                String engineName = ModelInfo.inferEngine(pair.getValue(), pair.getKey());
-                dependencyManager.installEngine(engineName);
-                Engine engine = Engine.getEngine(engineName);
-                devices = parseDevices(tokens[1], engine, pair.getValue());
-            }
         } else {
             workflowName = ModelInfo.inferModelNameFromUrl(workflowUrlString);
         }
@@ -476,16 +447,9 @@
                 WorkflowDefinition.parse(workflowUrl.toURI(), workflowUrl.openStream())
                         .toWorkflow();
 
-        String[] finalDevices = devices;
         CompletableFuture f =
                 modelManager
                         .registerWorkflow(workflow)
-                        .thenAccept(
-                                v -> {
-                                    for (String deviceName : finalDevices) {
-                                        modelManager.initWorkers(workflow, deviceName, -1, -1);
-                                    }
-                                })
                         .exceptionally(
                                 t -> {
                                     logger.error("Failed register workflow", t);
@@ -518,23 +482,7 @@ String mapModelUrl(Path path) {
             path = Utils.getNestedModelDir(path);
             String url = path.toUri().toURL().toString();
             String modelName = ModelInfo.inferModelNameFromUrl(url);
-            String engine;
-            if (Files.isDirectory(path)) {
-                engine = ModelInfo.inferEngine(path, path.toFile().getName());
-            } else {
-                // .zip file
-                engine = ModelInfo.inferEngineFromUrl(url);
-                Pair pair = ModelInfo.downloadModel(url);
-                path = pair.getValue();
-            }
-            if (engine == null) {
-                return null;
-            }
-            String loadOnDevices = ModelInfo.inferDeviceName(url);
-            if (loadOnDevices == null) {
-                loadOnDevices = configManager.getLoadOnDevices();
-            }
-            return modelName + "::" + engine + ':' + loadOnDevices + '=' + url;
+            return modelName + '=' + url;
         } catch (MalformedURLException e) {
             throw new AssertionError("Invalid path: " + path, e);
         } catch (IOException e) {
@@ -543,60 +491,6 @@ String mapModelUrl(Path path) {
         }
     }
 
-    private String[] parseDevices(String devices, Engine engine, Path modelDir) {
-        if ("*".equals(devices)) {
-            int gpuCount = engine.getGpuCount();
-            if (gpuCount > 0) {
-                String engineName = engine.getEngineName();
-                if ("Python".equals(engineName)) {
-                    Properties prop = ModelInfo.getServingProperties(modelDir);
-                    String v = Utils.getenv("TENSOR_PARALLEL_DEGREE", "-1");
-                    v = prop.getProperty("option.tensor_parallel_degree", v);
-                    int tensorParallelDegree = Integer.parseInt(v);
-                    if (tensorParallelDegree > 0) {
-                        int procs = gpuCount / tensorParallelDegree;
-                        if (procs == 0) {
-                            throw new EngineException(
-                                    "GPU devices are not enough to run "
-                                            + tensorParallelDegree
-                                            + " partitions.");
-                        }
-                        gpuCount = procs;
-                    }
-                } else if ("DeepSpeed".equals(engineName)
-                        || "FasterTransformer".equals(engineName)) {
-                    return new String[] {"0"};
-                }
-
-                return IntStream.range(0, gpuCount)
-                        .mapToObj(String::valueOf)
-                        .toArray(String[]::new);
-            } else if (NeuronUtils.hasNeuron()) {
-                int neurons = NeuronUtils.getNeuronCores();
-                Properties prop = ModelInfo.getServingProperties(modelDir);
-                String v = Utils.getenv("TENSOR_PARALLEL_DEGREE", "-1");
-                v = prop.getProperty("option.tensor_parallel_degree", v);
-                int tensorParallelDegree = Integer.parseInt(v);
-                if (tensorParallelDegree > 0) {
-                    // Assume user understand TP only works on inf2
-                    int procs = neurons / tensorParallelDegree;
-                    if (procs == 0) {
-                        throw new EngineException(
-                                "Neuron devices are not enough to run "
-                                        + tensorParallelDegree
-                                        + " partitions. Please refer to: "
-                                        + "https://github.com/aws-neuron/transformers-neuronx#tensor-parallelism-support");
-                    }
-                    neurons = procs;
-                }
-                return IntStream.range(0, neurons).mapToObj(i -> "nc" + i).toArray(String[]::new);
-            }
-        } else if (!devices.isEmpty()) {
-            return devices.split(";");
-        }
-        return new String[] {null};
-    }
-
     private static void printHelp(String msg, Options options) {
         HelpFormatter formatter = new HelpFormatter();
         formatter.setLeftPadding(1);
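With engine and device inference moved into `ModelInfo.initialize()`, a model-store entry no longer needs the `name::Engine:devices=url` form that `mapModelUrl` used to emit; a bare `name=url` pair is enough, and the omitted fields are resolved at registration time. A small illustrative sketch of the before/after output (the paths are hypothetical):

```java
import java.nio.file.Path;
import java.nio.file.Paths;

public final class ModelStoreUrlSketch {

    public static void main(String[] args) {
        // Hypothetical model directory inside the model store.
        Path modelDir = Paths.get("build/models/test_model");
        String url = modelDir.toUri().toString();

        // Old form emitted by mapModelUrl: engine and devices resolved eagerly,
        // which forced a model download during the model-store scan, e.g.
        //   test_model::PyTorch:*=file:/.../build/models/test_model/
        // New form: name and URL only; ModelInfo.initialize() resolves the rest.
        String mapped = "test_model" + '=' + url;
        System.out.println(mapped);
    }
}
```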
diff --git a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java
index c8ace3f236..52113e5918 100644
--- a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java
+++ b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java
@@ -235,18 +235,26 @@ private void predict(
                             modelUrl,
                             version,
                             engineName,
+                            deviceName,
                             Input.class,
                             Output.class,
                             -1,
                             -1,
                             -1,
+                            -1,
+                            -1,
                             -1);
             Workflow wf = new Workflow(modelInfo);
             modelManager
                     .registerWorkflow(wf)
-                    .thenApply(p -> modelManager.initWorkers(wf, deviceName, -1, -1))
-                    .thenAccept(p -> runJob(modelManager, ctx, p, input));
+                    .thenAccept(p -> runJob(modelManager, ctx, wf, input))
+                    .exceptionally(
+                            t -> {
+                                logger.error("Failed register workflow", t);
+                                NettyUtils.sendError(ctx, t.getCause());
+                                return null;
+                            });
             return;
         }
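The handler now funnels registration failures through `exceptionally` instead of chaining a separate `initWorkers` stage whose errors could go unreported. The sketch below reproduces just that `CompletableFuture` shape with plain JDK types so the error path is easy to see; the messages and the simulated failure are made up for illustration.

```java
import java.util.concurrent.CompletableFuture;

public final class AsyncRegisterSketch {

    public static void main(String[] args) {
        CompletableFuture<Void> f =
                CompletableFuture.runAsync(AsyncRegisterSketch::register)
                        .thenAccept(v -> System.out.println("workflow ready, job submitted"))
                        .exceptionally(
                                t -> {
                                    // Equivalent of NettyUtils.sendError(ctx, t.getCause()):
                                    // the async failure is reported to the client instead of
                                    // being dropped between chained stages.
                                    System.err.println("Failed register workflow: " + t.getCause());
                                    return null;
                                });
        f.join();
    }

    private static void register() {
        // Stand-in for modelManager.registerWorkflow(wf); it may fail asynchronously.
        throw new IllegalStateException("engine not installed");
    }
}
```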
diff --git a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java
index 4903d6e4f9..ddbb83c726 100644
--- a/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java
+++ b/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java
@@ -198,27 +198,20 @@ private void handleRegisterModel(
                         req.getModelUrl(),
                         req.getVersion(),
                         req.getEngineName(),
+                        req.getDeviceName(),
                         Input.class,
                         Output.class,
                         req.getJobQueueSize(),
                         req.getMaxIdleSeconds(),
                         req.getMaxBatchDelayMillis(),
-                        req.getBatchSize());
+                        req.getBatchSize(),
+                        req.getMinWorkers(),
+                        req.getMaxWorkers());
         Workflow workflow = new Workflow(modelInfo);
         final ModelManager modelManager = ModelManager.getInstance();
         CompletableFuture f =
                 modelManager
                         .registerWorkflow(workflow)
-                        .thenAccept(
-                                v -> {
-                                    for (ModelInfo m : workflow.getModels()) {
-                                        modelManager.initWorkers(
-                                                m,
-                                                req.getDeviceName(),
-                                                req.getMinWorkers(),
-                                                req.getMaxWorkers());
-                                    }
-                                })
                         .exceptionally(
                                 t -> {
                                     NettyUtils.sendError(ctx, t.getCause());
@@ -244,9 +237,6 @@ private void handleRegisterWorkflow(
             throw new BadRequestException("Parameter url is required.");
         }
 
-        String deviceName = NettyUtils.getParameter(decoder, LoadModelRequest.DEVICE, null);
-        int minWorkers = NettyUtils.getIntParameter(decoder, LoadModelRequest.MIN_WORKER, -1);
-        int maxWorkers = NettyUtils.getIntParameter(decoder, LoadModelRequest.MAX_WORKER, -1);
         boolean synchronous =
                 Boolean.parseBoolean(
                         NettyUtils.getParameter(decoder, LoadModelRequest.SYNCHRONOUS, "true"));
@@ -261,10 +251,6 @@ private void handleRegisterWorkflow(
         CompletableFuture f =
                 modelManager
                         .registerWorkflow(workflow)
-                        .thenAccept(
-                                v ->
-                                        modelManager.initWorkers(
-                                                workflow, deviceName, minWorkers, maxWorkers))
                         .exceptionally(
                                 t -> {
                                     NettyUtils.sendError(ctx, t.getCause());
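Since device, minimum workers, and maximum workers are now plain registration parameters carried by `ModelInfo`, a management call can supply them in one request. A hedged sketch with `java.net.http` follows; the endpoint shape and the query-parameter names (`url`, `device`, `min_worker`, `max_worker`) are assumptions to be checked against the actual `LoadModelRequest` constants.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public final class RegisterModelSketch {

    public static void main(String[] args) throws Exception {
        // Assumed management endpoint and parameter names; adjust to your deployment.
        String uri =
                "http://localhost:8080/models?url=djl://ai.djl.zoo/mlp"
                        + "&device=gpu0&min_worker=1&max_worker=4";
        HttpRequest req =
                HttpRequest.newBuilder(URI.create(uri))
                        .POST(HttpRequest.BodyPublishers.noBody())
                        .build();
        HttpResponse<String> res =
                HttpClient.newHttpClient().send(req, HttpResponse.BodyHandlers.ofString());
        System.out.println(res.statusCode() + ": " + res.body());
    }
}
```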
diff --git a/serving/src/main/java/ai/djl/serving/models/ModelManager.java b/serving/src/main/java/ai/djl/serving/models/ModelManager.java
index ee9ff32673..33d342e992 100644
--- a/serving/src/main/java/ai/djl/serving/models/ModelManager.java
+++ b/serving/src/main/java/ai/djl/serving/models/ModelManager.java
@@ -83,14 +83,21 @@ public CompletableFuture registerWorkflow(Workflow workflow) {
                 () -> {
                     for (ModelInfo model : workflow.getModels()) {
                         try {
+                            // download model and configure per model settings
+                            model.initialize();
+
+                            // Install engine if necessary
                             String engine = model.getEngineName();
-                            if (engine != null) {
-                                DependencyManager dm = DependencyManager.getInstance();
-                                dm.installEngine(engine);
-                            }
+                            DependencyManager dm = DependencyManager.getInstance();
+                            dm.installEngine(engine);
                             wlm.registerModel(model);
-                        } catch (IOException e) {
+
+                            for (String deviceName : model.getLoadOnDevices()) {
+                                int minWorkers = model.getMinWorkers();
+                                int maxWorkers = model.getMaxWorkers();
+                                initWorkers(model, deviceName, minWorkers, maxWorkers);
+                            }
+                        } catch (IOException | ModelNotFoundException e) {
                             throw new CompletionException(e);
                         }
                     }
@@ -145,24 +152,6 @@ public boolean unregisterWorkflow(String workflowName, String version) {
         return true;
     }
 
-    /**
-     * Initializes the workers for each model in a workflow.
-     *
-     * @param workflow the workflow to scale workers for
-     * @param deviceName the device for the model
-     * @param minWorkers the min workers
-     * @param maxWorkers the max workers
-     * @return the info about the scaled workflow
-     * @see WorkerPool#initWorkers(String, int, int)
-     */
-    public Workflow initWorkers(
-            Workflow workflow, String deviceName, int minWorkers, int maxWorkers) {
-        for (ModelInfo model : workflow.getModels()) {
-            initWorkers(model, deviceName, minWorkers, maxWorkers);
-        }
-        return workflow;
-    }
-
     /**
      * Initializes the workers for a model.
      *
@@ -264,7 +253,7 @@ public Set getStartupWorkflows() {
      *
      * @param workflow the workflow to run
      * @param input the input to the task
-     * @return {@code true} if submit success, false otherwise.
+     * @return the {@code CompletableFuture}
      */
     public CompletableFuture runJob(Workflow workflow, Input input) {
         return workflow.execute(wlm, input);
diff --git a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java
index 48993adfdb..cd58dd1d2e 100644
--- a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java
+++ b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java
@@ -160,6 +160,7 @@ public static void init(Arguments args) {
         instance.withIntProperty(BATCH_SIZE, wlmc::setBatchSize);
         instance.withIntProperty(MAX_BATCH_DELAY, wlmc::setMaxBatchDelayMillis);
         instance.withIntProperty(RESERVED_MEMORY_MB, wlmc::setReservedMemoryMb);
+        wlmc.setLoadOnDevices(instance.getLoadOnDevices());
     }
 
     /**
diff --git a/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java b/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java
index 71d06eb604..92cedee144 100644
--- a/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java
+++ b/serving/src/main/java/ai/djl/serving/workflow/WorkflowDefinition.java
@@ -201,7 +201,7 @@ public ModelInfo deserialize(
                 model.hasInputOutputClass(Input.class, Output.class);
                 return model;
             } else if (json.isJsonPrimitive()) {
-                return new ModelInfo<>(json.getAsString(), Input.class, Output.class);
+                return new ModelInfo<>(json.getAsString());
             }
             throw new JsonParseException(
                     "Unexpected type of model definition: should be Criteria object or URI string");
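`registerWorkflow` is now the single place where a model becomes servable: initialize, install the engine, register with the `WorkLoadManager`, then start workers per device. The compile-only sketch below restates that ordering; the printout stands in for the real `initWorkers` call and the engine-install step is reduced to a comment, so treat it as a reading aid rather than the actual `ModelManager` code.

```java
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.wlm.ModelInfo;

import java.io.IOException;
import java.util.List;

final class RegistrationFlowSketch {

    /** Mirrors the order of operations now performed inside registerWorkflow. */
    static void register(List<ModelInfo<Input, Output>> models)
            throws IOException, ModelNotFoundException {
        for (ModelInfo<Input, Output> model : models) {
            // 1. download artifacts, read serving.properties, infer the engine
            model.initialize();
            // 2. install the engine and register with the WorkLoadManager, then
            //    3. start workers on every device the model asks for; callers no
            //    longer need a separate initWorkers() round-trip
            for (String deviceName : model.getLoadOnDevices()) {
                System.out.printf(
                        "would start %d-%d workers on %s%n",
                        model.getMinWorkers(), model.getMaxWorkers(), deviceName);
            }
        }
    }
}
```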
modelDir.resolve("test_model.uff"); - Files.createFile(tensorRt); - url = server.mapModelUrl(modelDir); - assertEquals(url, "test_model::TensorRT:*=" + expected); - - Path onnx = modelDir.resolve("test_model.onnx"); - Files.createFile(onnx); - url = server.mapModelUrl(modelDir); - assertEquals(url, "test_model::OnnxRuntime:*=" + expected); - - Path mxnet = modelDir.resolve("test_model-symbol.json"); - Files.createFile(mxnet); - url = server.mapModelUrl(modelDir); - assertEquals(url, "test_model::MXNet:*=" + expected); - - Path tensorflow = modelDir.resolve("saved_model.pb"); - Files.createFile(tensorflow); - url = server.mapModelUrl(modelDir); - assertEquals(url, "test_model::TensorFlow:*=" + expected); - - Path pytorch = modelDir.resolve("test_model.pt"); - Files.createFile(pytorch); - url = server.mapModelUrl(modelDir); - assertEquals(url, "test_model::PyTorch:*=" + expected); - - Path prop = modelDir.resolve("serving.properties"); - try (BufferedWriter writer = Files.newBufferedWriter(prop)) { - writer.write("engine=MyEngine"); - } - url = server.mapModelUrl(modelDir); - assertEquals(url, "test_model::MyEngine:*=" + expected); - - Path mar = modelStore.resolve("torchServe.mar"); - Path torchServe = modelStore.resolve("torchServe"); - Files.createDirectories(torchServe.resolve("MAR-INF")); - Files.createDirectories(torchServe.resolve("code")); - ZipUtils.zip(torchServe, mar, false); - - url = server.mapModelUrl(mar); - assertEquals(url, "torchServe::Python:*=" + mar.toUri().toURL()); - - Path root = modelStore.resolve("models.pt"); - Files.createFile(root); - url = server.mapModelUrl(modelStore); - assertEquals(url, "models::PyTorch:*=" + modelStore.toUri().toURL()); - } finally { - server.stop(); - } - } - public static void main(String[] args) throws ReflectiveOperationException, ServerStartupException, GeneralSecurityException, ErrorDataEncoderException, IOException, ParseException, InterruptedException { @@ -270,6 +184,10 @@ public void test() ReflectiveOperationException, ServerStartupException { ModelServer server = initTestServer("src/test/resources/config.properties"); try { + Path notModel = Paths.get("build/non-model"); + String url = server.mapModelUrl(notModel); // not a model dir + assertNull(url); + assertTrue(server.isRunning()); Channel channel = initTestChannel(); diff --git a/serving/src/test/java/ai/djl/serving/WorkflowTest.java b/serving/src/test/java/ai/djl/serving/WorkflowTest.java index 850c36892c..ba3668a5a9 100644 --- a/serving/src/test/java/ai/djl/serving/WorkflowTest.java +++ b/serving/src/test/java/ai/djl/serving/WorkflowTest.java @@ -109,7 +109,7 @@ private Input runWorkflow(Path workflowFile, Input input) Workflow workflow = WorkflowDefinition.parse(workflowFile).toWorkflow(); WorkLoadManager wlm = new WorkLoadManager(); for (ModelInfo model : workflow.getModels()) { - wlm.registerModel(model).initWorkers(null, -1, 1); + wlm.registerModel(model).initWorkers("-1", -1, 1); } Output output = workflow.execute(wlm, input).join(); diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index 59a8e5c2f1..24e89de6c1 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -17,6 +17,9 @@ import ai.djl.Model; import ai.djl.ModelException; import ai.djl.engine.Engine; +import ai.djl.engine.EngineException; +import ai.djl.modality.Input; +import ai.djl.modality.Output; import ai.djl.repository.Artifact; import 
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
index 59a8e5c2f1..24e89de6c1 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
@@ -17,6 +17,9 @@
 import ai.djl.Model;
 import ai.djl.ModelException;
 import ai.djl.engine.Engine;
+import ai.djl.engine.EngineException;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
 import ai.djl.repository.Artifact;
 import ai.djl.repository.FilenameUtils;
 import ai.djl.repository.MRL;
@@ -29,7 +32,7 @@
 import ai.djl.translate.ServingTranslator;
 import ai.djl.translate.Translator;
 import ai.djl.translate.TranslatorFactory;
-import ai.djl.util.Pair;
+import ai.djl.util.NeuronUtils;
 import ai.djl.util.Utils;
 import ai.djl.util.cuda.CudaUtils;
 
@@ -51,6 +54,7 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
+import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 /** A class represent a loaded model and it's metadata. */
@@ -64,11 +68,14 @@ public final class ModelInfo {
     private String version;
     private String modelUrl;
     private String engineName;
+    private String loadOnDevices;
 
     private int queueSize;
     private int batchSize;
     private int maxBatchDelayMillis;
     private int maxIdleSeconds;
+    private int minWorkers = -1;
+    private int maxWorkers = -1;
 
     private Map filters;
     private Map arguments;
@@ -77,6 +84,11 @@ public final class ModelInfo {
     private String modelName;
     private String translatorFactory;
     private String translator;
+
+    private transient Path modelDir;
+    private transient String artifactName;
+
+    private transient Properties prop;
     private transient Status status;
 
     private transient Class inputClass;
@@ -87,25 +99,26 @@
     /**
      * Constructs a new {@code ModelInfo} instance.
      *
-     * @param inputClass the model input class
-     * @param outputClass the model output class
      * @param modelUrl the model Url
      */
-    public ModelInfo(String modelUrl, Class inputClass, Class outputClass) {
+    @SuppressWarnings("unchecked")
+    public ModelInfo(String modelUrl) {
         this.id = modelUrl;
         this.modelUrl = modelUrl;
-        this.inputClass = inputClass;
-        this.outputClass = outputClass;
+        this.inputClass = (Class) Input.class;
+        this.outputClass = (Class) Output.class;
     }
 
     /**
      * Constructs a {@link ModelInfo} based on a {@link Criteria}.
      *
      * @param id the id for the created {@link ModelInfo}
+     * @param modelUrl the model Url
      * @param criteria the model criteria
      */
-    public ModelInfo(String id, Criteria criteria) {
+    public ModelInfo(String id, String modelUrl, Criteria criteria) {
         this.id = id;
+        this.modelUrl = modelUrl;
         this.criteria = criteria;
         inputClass = criteria.getInputClass();
         outputClass = criteria.getOutputClass();
@@ -118,34 +131,43 @@
      * @param modelUrl the model url
      * @param version the version of the model
      * @param engineName the engine to load the model
+     * @param loadOnDevices the devices to load the model on
      * @param inputClass the model input class
     * @param outputClass the model output class
     * @param queueSize the maximum request queue size
-     * @param maxIdleSeconds the initial maximum idle time for workers.
-     * @param maxBatchDelayMillis the initial maximum delay when scaling up before giving up.
-     * @param batchSize the batch size for this model.
+     * @param maxIdleSeconds the initial maximum idle time for workers
+     * @param maxBatchDelayMillis the initial maximum delay when scaling up before giving up
+     * @param batchSize the batch size for this model
+     * @param minWorkers the minimum number of workers
+     * @param maxWorkers the maximum number of workers
     */
     public ModelInfo(
             String id,
             String modelUrl,
             String version,
             String engineName,
+            String loadOnDevices,
             Class inputClass,
             Class outputClass,
             int queueSize,
             int maxIdleSeconds,
             int maxBatchDelayMillis,
-            int batchSize) {
+            int batchSize,
+            int minWorkers,
+            int maxWorkers) {
         this.id = id;
         this.modelUrl = modelUrl;
         this.version = version;
         this.engineName = engineName;
+        this.loadOnDevices = loadOnDevices;
         this.inputClass = inputClass;
         this.outputClass = outputClass;
         this.maxBatchDelayMillis = maxBatchDelayMillis;
         this.maxIdleSeconds = maxIdleSeconds; // default max idle time 60s
         this.queueSize = queueSize;
         this.batchSize = batchSize;
+        this.minWorkers = minWorkers;
+        this.maxWorkers = maxWorkers;
     }
 
@@ -161,21 +183,19 @@ public void load(Device device) throws ModelException, IOException {
             return;
         }
 
+        try {
+            // Download the model again if the model files are deleted
+            initialize();
+            checkAvailableMemory(device);
+        } catch (IOException e) {
+            throw new ModelNotFoundException(e);
+        }
+
         try {
             Criteria.Builder builder;
             if (criteria != null) {
                 builder = criteria.toBuilder();
             } else {
-                // Download the model first, and get model specific configuration
-                // batchSize is required before model loading in dynamic batching case
-                try {
-                    Pair pair = downloadModel(modelUrl);
-                    checkAvailableMemory(device, pair.getValue());
-                    configPerModelSettings(pair.getValue());
-                } catch (IOException e) {
-                    throw new ModelNotFoundException(e);
-                }
-
                 builder =
                         Criteria.builder()
                                 .setTypes(inputClass, outputClass)
@@ -216,10 +236,6 @@
             }
 
             ZooModel m = builder.build().loadModel();
-            if (criteria != null) {
-                // TODO: user has to manually configure batchifier if using dynamic batch
-                configPerModelSettings(m.getModelPath());
-            }
             models.put(device, m);
             status = Status.READY;
         } finally {
@@ -425,6 +441,39 @@ public int getQueueSize() {
         return queueSize;
     }
 
+    /**
+     * Returns the minimum number of workers.
+     *
+     * @return the minimum number of workers
+     */
+    public int getMinWorkers() {
+        return minWorkers;
+    }
+
+    /**
+     * Returns the maximum number of workers.
+     *
+     * @return the maximum number of workers
+     */
+    public int getMaxWorkers() {
+        return maxWorkers;
+    }
+
+    /**
+     * Initializes the model.
+     *
+     * @throws IOException if failed to download model
+     * @throws ModelNotFoundException if model not found
+     */
+    public void initialize() throws IOException, ModelNotFoundException {
+        downloadModel();
+        loadServingProperties();
+        if (engineName == null) {
+            engineName = inferEngine();
+        }
+        configPerModelSettings();
+    }
+
     /** Close all loaded models. */
     public void close() {
         if (!getModels().isEmpty()) {
@@ -473,152 +522,92 @@ public static String inferModelNameFromUrl(String url) {
     }
 
     /**
-     * Infers engine name from model URL.
-     *
-     * @param modelUrl the model URL
-     * @return the engine name
-     */
-    public static String inferEngineFromUrl(String modelUrl) {
-        try {
-            Pair pair = downloadModel(modelUrl);
-            return inferEngine(pair.getValue(), pair.getKey());
-        } catch (IOException e) {
-            logger.warn("Failed to extract model: " + modelUrl, e);
-            return null;
-        }
-    }
-
-    /**
-     * Infers which device to load.
+     * Returns the default device for this model if device is null.
      *
-     * @param modelUrl the model URL
-     * @return the device name
+     * @param deviceName the device to use if it is not null
+     * @return a non-null device
     */
-    public static String inferDeviceName(String modelUrl) {
-        try {
-            Pair pair = downloadModel(modelUrl);
-            Properties prop = getServingProperties(pair.getValue());
-            return prop.getProperty("load_on_devices");
-        } catch (IOException e) {
-            logger.warn("Failed to extract model: " + modelUrl, e);
-            return null;
-        }
+    public Device withDefaultDevice(String deviceName) {
+        return Device.fromName(deviceName, Engine.getEngine(engineName));
     }
 
-    /**
-     * Infers engine name from model directory.
-     *
-     * @param modelDir the model directory
-     * @param modelName the model name
-     * @return the engine name
-     */
-    public static String inferEngine(Path modelDir, String modelName) {
-        modelDir = Utils.getNestedModelDir(modelDir);
-
-        Properties prop = getServingProperties(modelDir);
+    private String inferEngine() throws ModelNotFoundException {
         String engine = prop.getProperty("engine");
         if (engine != null) {
             return engine;
         }
 
-        modelName = prop.getProperty("option.modelName", modelName);
+        String prefix = prop.getProperty("option.modelName", artifactName);
         if (Files.isDirectory(modelDir.resolve("MAR-INF"))
                 || Files.isRegularFile(modelDir.resolve("model.py"))
-                || Files.isRegularFile(modelDir.resolve(modelName + ".py"))) {
+                || Files.isRegularFile(modelDir.resolve(prefix + ".py"))) {
             // MMS/TorchServe
             return "Python";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".pt"))
+        } else if (Files.isRegularFile(modelDir.resolve(prefix + ".pt"))
                 || Files.isRegularFile(modelDir.resolve("model.pt"))) {
             return "PyTorch";
         } else if (Files.isRegularFile(modelDir.resolve("config.pbtxt"))) {
             return "TritonServer";
         } else if (Files.isRegularFile(modelDir.resolve("saved_model.pb"))) {
             return "TensorFlow";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + "-symbol.json"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(prefix + "-symbol.json"))) {
             return "MXNet";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".onnx"))
+        } else if (Files.isRegularFile(modelDir.resolve(prefix + ".onnx"))
                 || Files.isRegularFile(modelDir.resolve("model.onnx"))) {
             return "OnnxRuntime";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".trt"))
-                || Files.isRegularFile(modelDir.resolve(modelName + ".uff"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(prefix + ".trt"))
+                || Files.isRegularFile(modelDir.resolve(prefix + ".uff"))) {
             return "TensorRT";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".tflite"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(prefix + ".tflite"))) {
             return "TFLite";
         } else if (Files.isRegularFile(modelDir.resolve("model"))
                 || Files.isRegularFile(modelDir.resolve("__model__"))
                 || Files.isRegularFile(modelDir.resolve("inference.pdmodel"))) {
             return "PaddlePaddle";
-        } else if (Files.isRegularFile(modelDir.resolve(modelName + ".json"))) {
+        } else if (Files.isRegularFile(modelDir.resolve(prefix + ".json"))) {
             return "XGBoost";
+        } else {
+            try {
+                if (Utils.getCurrentEpoch(modelDir, prefix) >= 0) {
+                    // Assume this is DJL model
+                    return Engine.getDefaultEngineName();
+                }
+            } catch (IOException e) {
+                logger.warn("Failed search parameter files in folder: " + modelDir, e);
+            }
         }
-        logger.warn("Failed to detect engine of the model: {}", modelDir);
-        return null;
-    }
-
-    /**
-     * Returns the default device for this model if device is null.
-     *
-     * @param deviceName the device to use if it is not null
-     * @return a non-null device
-     */
-    public Device withDefaultDevice(String deviceName) {
-        if (engineName == null && modelUrl != null) {
-            engineName = inferEngineFromUrl(modelUrl);
-        }
-        if (deviceName == null && modelUrl != null) {
-            deviceName = inferDeviceName(modelUrl);
-        }
-        Engine engine = engineName != null ? Engine.getEngine(engineName) : Engine.getInstance();
-        // TODO: Load model API doesn't support * or multiple devices
-        if (deviceName == null || "*".equals(deviceName)) {
-            return engine.defaultDevice();
-        }
-        String[] devices = deviceName.split(";");
-        return Device.fromName(devices[0], engine);
+        throw new ModelNotFoundException("Failed to detect engine of the model: " + modelDir);
     }
 
-    /**
-     * Downloads model from the model URL.
-     *
-     * @param modelUrl the model URL
-     * @return model name and downloaded model path
-     * @throws IOException if failed to download the model
-     */
-    public static Pair downloadModel(String modelUrl) throws IOException {
+    private void downloadModel() throws ModelNotFoundException, IOException {
         Repository repository = Repository.newInstance("modelStore", modelUrl);
         List mrls = repository.getResources();
         if (mrls.isEmpty()) {
-            throw new IOException("Invalid model url: " + modelUrl);
+            throw new ModelNotFoundException("Invalid model url: " + modelUrl);
         }
         Artifact artifact = mrls.get(0).getDefaultArtifact();
         repository.prepare(artifact);
-        Path modelDir = Utils.getNestedModelDir(repository.getResourceDirectory(artifact));
-        return new Pair<>(artifact.getName(), modelDir);
+        modelDir = Utils.getNestedModelDir(repository.getResourceDirectory(artifact));
+        artifactName = artifact.getName();
     }
 
-    /**
-     * Loads the serving properties from model folder.
-     *
-     * @param modelDir model directory
-     * @return the serving properties
-     */
-    public static Properties getServingProperties(Path modelDir) {
-        Path file = modelDir.resolve("serving.properties");
-        Properties prop = new Properties();
-        if (Files.isRegularFile(file)) {
-            try (InputStream is = Files.newInputStream(file)) {
-                prop.load(is);
-            } catch (IOException e) {
-                logger.warn("Failed read serving.properties file", e);
+    private void loadServingProperties() {
+        if (prop == null) {
+            Path file = modelDir.resolve("serving.properties");
+            prop = new Properties();
+            if (Files.isRegularFile(file)) {
+                try (InputStream is = Files.newInputStream(file)) {
+                    prop.load(is);
+                } catch (IOException e) {
+                    logger.warn("Failed read serving.properties file", e);
+                }
             }
         }
-        return prop;
     }
 
-    private void configPerModelSettings(Path modelDir) throws IOException {
+    private void configPerModelSettings() {
         // per model settings can only be configured once
-        Properties prop = getServingProperties(modelDir);
         WlmConfigManager wlmc = WlmConfigManager.getInstance();
         if (queueSize <= 0) {
             queueSize = intValue(prop, "job_queue_size", wlmc.getJobQueueSize());
@@ -632,14 +621,24 @@
         if (maxIdleSeconds <= 0) {
             maxIdleSeconds = intValue(prop, "max_idle_time", wlmc.getMaxIdleSeconds());
         }
+        if (loadOnDevices == null) {
+            loadOnDevices = prop.getProperty("load_on_devices", wlmc.getLoadOnDevices());
+        }
+        logger.info(
+                "Apply per model settings:\n\tqueueSize: {}\n\tbatchSize: {}"
+                        + "\n\tmaxBatchDelay: {}\n\tmaxIdle: {}\n\tloadOnDevices: {}",
+                queueSize,
+                batchSize,
+                maxBatchDelayMillis,
+                maxIdleSeconds,
+                loadOnDevices);
     }
 
-    void checkAvailableMemory(Device device, Path modelDir) throws IOException {
+    void checkAvailableMemory(Device device) throws IOException {
         if (Boolean.getBoolean("skip_oom_check")) {
             return;
         }
-        Properties prop = getServingProperties(modelDir);
         long requiredMemory = intValue(prop, "required_memory_mb", 0) * 1024L * 1024;
         WlmConfigManager wlmc = WlmConfigManager.getInstance();
         int defMemory = wlmc.getReservedMemoryMb();
@@ -688,6 +687,63 @@
         }
     }
 
+    /**
+     * Returns the devices the model will be loaded on at startup.
+     *
+     * @return the devices the model will be loaded on at startup
+     */
+    public String[] getLoadOnDevices() {
+        Engine engine = Engine.getEngine(engineName);
+        if ("*".equals(loadOnDevices)) {
+            int gpuCount = engine.getGpuCount();
+            if (gpuCount > 0) {
+                if ("Python".equals(engineName)) {
+                    String v = Utils.getenv("TENSOR_PARALLEL_DEGREE", "-1");
+                    v = prop.getProperty("option.tensor_parallel_degree", v);
+                    int tensorParallelDegree = Integer.parseInt(v);
+                    if (tensorParallelDegree > 0) {
+                        int procs = gpuCount / tensorParallelDegree;
+                        if (procs == 0) {
+                            throw new EngineException(
+                                    "GPU devices are not enough to run "
+                                            + tensorParallelDegree
+                                            + " partitions.");
+                        }
+                        gpuCount = procs;
+                    }
+                } else if ("DeepSpeed".equals(engineName)
+                        || "FasterTransformer".equals(engineName)) {
+                    return new String[] {"0"};
+                }
+
+                return IntStream.range(0, gpuCount)
+                        .mapToObj(String::valueOf)
+                        .toArray(String[]::new);
+            } else if (NeuronUtils.hasNeuron()) {
+                int neurons = NeuronUtils.getNeuronCores();
+                String v = Utils.getenv("TENSOR_PARALLEL_DEGREE", "-1");
+                v = prop.getProperty("option.tensor_parallel_degree", v);
+                int tensorParallelDegree = Integer.parseInt(v);
+                if (tensorParallelDegree > 0) {
+                    // Assume user understand TP only works on inf2
+                    int procs = neurons / tensorParallelDegree;
+                    if (procs == 0) {
+                        throw new EngineException(
+                                "Neuron devices are not enough to run "
+                                        + tensorParallelDegree
+                                        + " partitions. Please refer to: "
+                                        + "https://github.com/aws-neuron/transformers-neuronx#tensor-parallelism-support");
+                    }
+                    neurons = procs;
+                }
+                return IntStream.range(0, neurons).mapToObj(i -> "nc" + i).toArray(String[]::new);
+            }
+        } else if (!loadOnDevices.isEmpty()) {
+            return loadOnDevices.split(";");
+        }
+        return new String[] {"-1"};
+    }
+
     private static long getFileSize(Path path) {
         try {
             return Files.size(path);
@@ -697,7 +753,7 @@
         return 0L;
     }
 
-    private static long getAvailableCpuMemory() {
+    private long getAvailableCpuMemory() {
         if (System.getProperty("os.name").startsWith("Linux")) {
             try (Scanner scanner = new Scanner(Paths.get("/proc/meminfo"))) {
                 while (scanner.hasNext()) {
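The device expansion that used to live in `ModelServer.parseDevices` is now `ModelInfo.getLoadOnDevices()`. The standalone sketch below reproduces only its core arithmetic (the `"*"` wildcard divided by `tensor_parallel_degree`, explicit `;`-separated lists, and the `"-1"` default); the Neuron and DeepSpeed/FasterTransformer branches are omitted for brevity, so this is a reading aid, not the real method.

```java
/** Standalone illustration of the device-expansion arithmetic in getLoadOnDevices(). */
public final class DeviceExpansionSketch {

    public static void main(String[] args) {
        // "*" with 8 GPUs and option.tensor_parallel_degree=4 on the Python engine:
        // each partition spans 4 GPUs, so only 8 / 4 = 2 worker devices are exposed.
        System.out.println(String.join(",", expand("*", 8, 4))); // 0,1
        // An explicit list is split on ';' and used verbatim.
        System.out.println(String.join(",", expand("gpu0;gpu1", 8, -1))); // gpu0,gpu1
        // No value at all falls back to the single default-device placeholder "-1".
        System.out.println(String.join(",", expand("", 8, -1))); // -1
    }

    static String[] expand(String loadOnDevices, int gpuCount, int tensorParallelDegree) {
        if ("*".equals(loadOnDevices)) {
            if (tensorParallelDegree > 0) {
                // The real code throws EngineException when this quotient is 0.
                gpuCount = gpuCount / tensorParallelDegree;
            }
            String[] devices = new String[gpuCount];
            for (int i = 0; i < gpuCount; ++i) {
                devices[i] = String.valueOf(i);
            }
            return devices;
        } else if (!loadOnDevices.isEmpty()) {
            return loadOnDevices.split(";");
        }
        return new String[] {"-1"};
    }
}
```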
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java b/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java
index ef0896e942..8319abf57b 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/WorkerPool.java
@@ -124,10 +124,13 @@ public boolean isFullyScaled() {
      * @param maxWorkers maximum amount of workers.
      */
     public void initWorkers(String deviceName, int minWorkers, int maxWorkers) {
-        Device device = model.withDefaultDevice(deviceName);
-        logger.info("initWorkers for {} ({}): {}, {}", model, device, minWorkers, maxWorkers);
+        Device device;
         synchronized (model) {
             try {
+                model.initialize();
+                device = model.withDefaultDevice(deviceName);
+                logger.info(
+                        "initWorkers for {} ({}): {}, {}", model, device, minWorkers, maxWorkers);
                 model.load(device);
             } catch (ModelException | IOException e) {
                 throw new CompletionException(e);
diff --git a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java
index e89c607bfd..46824d6bbd 100644
--- a/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java
+++ b/wlm/src/main/java/ai/djl/serving/wlm/util/WlmConfigManager.java
@@ -24,6 +24,7 @@ public final class WlmConfigManager {
     private int batchSize = 1;
     private int maxBatchDelayMillis = 100;
     private int reservedMemoryMb = 500;
+    private String loadOnDevices;
 
     private static final WlmConfigManager INSTANCE = new WlmConfigManager();
 
@@ -137,6 +138,24 @@ public void setReservedMemoryMb(int reservedMemoryMb) {
         this.reservedMemoryMb = reservedMemoryMb;
     }
 
+    /**
+     * Returns the devices the model will be loaded on at startup.
+     *
+     * @return the devices the model will be loaded on at startup
+     */
+    public String getLoadOnDevices() {
+        return loadOnDevices;
+    }
+
+    /**
+     * Sets the devices the model will be loaded on at startup.
+     *
+     * @param loadOnDevices the devices the model will be loaded on at startup
+     */
+    public void setLoadOnDevices(String loadOnDevices) {
+        this.loadOnDevices = loadOnDevices;
+    }
+
     /**
      * Returns the default minimum number of workers for a new registered model.
      *
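`WlmConfigManager` now carries the server-wide `load_on_devices` default that `ConfigManager.init()` forwards, and `configPerModelSettings()` falls back to it when a model's `serving.properties` has no `load_on_devices` entry. A tiny usage sketch, assuming the usual `getInstance()` singleton accessor:

```java
import ai.djl.serving.wlm.util.WlmConfigManager;

public final class LoadOnDevicesDefaultSketch {

    public static void main(String[] args) {
        // ConfigManager.init() performs the equivalent of this at startup;
        // models without their own load_on_devices setting inherit the value.
        WlmConfigManager wlmc = WlmConfigManager.getInstance();
        wlmc.setLoadOnDevices("*");
        System.out.println(wlmc.getLoadOnDevices()); // *
    }
}
```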
diff --git a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
index 0ac5b3555d..9a5130357d 100644
--- a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
+++ b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
@@ -12,6 +12,8 @@
  */
 package ai.djl.serving.wlm;
 
+import static org.testng.Assert.assertEquals;
+
 import ai.djl.Device;
 import ai.djl.ModelException;
 import ai.djl.engine.Engine;
@@ -19,13 +21,18 @@
 import ai.djl.modality.Input;
 import ai.djl.modality.Output;
 import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
 import ai.djl.repository.zoo.ZooModel;
 import ai.djl.translate.TranslateException;
 import ai.djl.util.Utils;
+import ai.djl.util.ZipUtils;
 
 import org.testng.Assert;
+import org.testng.annotations.AfterSuite;
+import org.testng.annotations.BeforeSuite;
 import org.testng.annotations.Test;
 
+import java.io.BufferedWriter;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.Writer;
@@ -37,11 +44,39 @@
 
 public class ModelInfoTest {
 
+    @BeforeSuite
+    public void beforeSuite() throws IOException {
+        Path modelStore = Paths.get("build/models");
+        Utils.deleteQuietly(modelStore);
+        Files.createDirectories(modelStore);
+        String engineCacheDir = Utils.getEngineCacheDir().toString();
+        System.setProperty("DJL_CACHE_DIR", "build/cache");
+        System.setProperty("ENGINE_CACHE_DIR", engineCacheDir);
+    }
+
+    @AfterSuite
+    public void afterSuite() {
+        System.clearProperty("DJL_CACHE_DIR");
+        System.clearProperty("ENGINE_CACHE_DIR");
+    }
+
     @Test
     public void testQueueSizeIsSet() {
         ModelInfo modelInfo =
                 new ModelInfo<>(
-                        "", null, null, "PyTorch", Input.class, Output.class, 4711, 1, 300, 1);
+                        "",
+                        null,
+                        null,
+                        "PyTorch",
+                        null,
+                        Input.class,
+                        Output.class,
+                        4711,
+                        1,
+                        300,
+                        1,
+                        -1,
+                        -1);
         Assert.assertEquals(4711, modelInfo.getQueueSize());
         Assert.assertEquals(1, modelInfo.getMaxIdleSeconds());
         Assert.assertEquals(300, modelInfo.getMaxBatchDelayMillis());
@@ -56,7 +91,7 @@ public void testCriteriaModelInfo() throws ModelException, IOException, Translat
                         .setTypes(Input.class, Output.class)
                         .optModelUrls(modelUrl)
                         .build();
-        ModelInfo modelInfo = new ModelInfo<>("model", criteria);
+        ModelInfo modelInfo = new ModelInfo<>("model", modelUrl, criteria);
         modelInfo.load(Device.cpu());
         try (ZooModel model = modelInfo.getModel(Device.cpu());
                 Predictor predictor = model.newPredictor()) {
@@ -70,33 +105,40 @@ public void testCriteriaModelInfo() throws ModelException, IOException, Translat
     }
 
     @Test
-    public void testOutOfMemory() throws IOException {
+    public void testOutOfMemory() throws IOException, ModelNotFoundException {
         Path modelDir = Paths.get("build/oom_model");
         Utils.deleteQuietly(modelDir);
         Files.createDirectories(modelDir);
-        ModelInfo modelInfo =
+
+        ModelInfo modelInfo =
                 new ModelInfo<>(
                         "",
                         "build/oom_model",
                         null,
                         "PyTorch",
+                        "nc1,nc2",
                         Input.class,
                         Output.class,
                         4711,
                         1,
                         300,
-                        1);
-
+                        1,
+                        -1,
+                        -1);
+        modelInfo.initialize();
         Device device = Engine.getInstance().defaultDevice();
-        modelInfo.checkAvailableMemory(device, modelDir);
+        modelInfo.checkAvailableMemory(device);
 
         Path file = modelDir.resolve("serving.properties");
         Properties prop = new Properties();
         prop.setProperty("reserved_memory_mb", String.valueOf(Integer.MAX_VALUE));
+        prop.setProperty("engine", "PyTorch");
         try (Writer writer = Files.newBufferedWriter(file)) {
             prop.store(writer, "");
         }
-        Assert.assertThrows(() -> modelInfo.checkAvailableMemory(Device.cpu(), modelDir));
+        ModelInfo m1 = new ModelInfo<>("build/oom_model");
+        m1.initialize();
+        Assert.assertThrows(() -> m1.checkAvailableMemory(Device.cpu()));
 
         if (device.isGpu()) {
             prop.setProperty("required_memory_mb", "1");
@@ -107,7 +149,100 @@ public void testOutOfMemory() throws IOException {
                 prop.store(writer, "");
             }
 
-            Assert.assertThrows(() -> modelInfo.checkAvailableMemory(device, modelDir));
+            ModelInfo m2 = new ModelInfo<>("build/oom_model");
+            m2.initialize();
+            Assert.assertThrows(() -> m2.checkAvailableMemory(device));
         }
     }
+
+    @Test
+    public void testInitModel() throws IOException, ModelNotFoundException {
+        Path modelStore = Paths.get("build/models");
+        Path modelDir = modelStore.resolve("test_model");
+        Files.createDirectories(modelDir);
+        Path notModel = modelStore.resolve("non-model");
+
+        ModelInfo model = new ModelInfo<>(notModel.toUri().toURL().toString());
+        Assert.assertThrows(model::initialize);
+
+        model = new ModelInfo<>("build/models/test_model");
+        Assert.assertThrows(model::initialize);
+
+        Path xgb = modelDir.resolve("test_model.json");
+        Files.createFile(xgb);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "XGBoost");
+
+        Path paddle = modelDir.resolve("__model__");
+        Files.createFile(paddle);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "PaddlePaddle");
+
+        Path tflite = modelDir.resolve("test_model.tflite");
+        Files.createFile(tflite);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "TFLite");
+
+        Path tensorRt = modelDir.resolve("test_model.uff");
+        Files.createFile(tensorRt);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "TensorRT");
+
+        Path onnx = modelDir.resolve("test_model.onnx");
+        Files.createFile(onnx);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "OnnxRuntime");
+
+        Path mxnet = modelDir.resolve("test_model-symbol.json");
+        Files.createFile(mxnet);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "MXNet");
+
+        Path tensorflow = modelDir.resolve("saved_model.pb");
+        Files.createFile(tensorflow);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "TensorFlow");
+
+        Path triton = modelDir.resolve("config.pbtxt");
+        Files.createFile(triton);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "TritonServer");
+
+        Path pytorch = modelDir.resolve("test_model.pt");
+        Files.createFile(pytorch);
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "PyTorch");
+
+        Path prop = modelDir.resolve("serving.properties");
+        try (BufferedWriter writer = Files.newBufferedWriter(prop)) {
+            writer.write("engine=MyEngine");
+        }
+        model = new ModelInfo<>("build/models/test_model");
+        model.initialize();
+        assertEquals(model.getEngineName(), "MyEngine");
+
+        Path mar = modelStore.resolve("torchServe.mar");
+        Path torchServe = modelStore.resolve("torchServe");
+        Files.createDirectories(torchServe.resolve("MAR-INF"));
+        Files.createDirectories(torchServe.resolve("code"));
+        ZipUtils.zip(torchServe, mar, false);
+        model = new ModelInfo<>(mar.toUri().toURL().toString());
+        model.initialize();
+        assertEquals(model.getEngineName(), "Python");
+
+        Path root = modelStore.resolve("models.pt");
+        Files.createFile(root);
+        model = new ModelInfo<>("build/models");
+        model.initialize();
+        assertEquals(model.getEngineName(), "PyTorch");
+    }
 }
diff --git a/wlm/src/test/java/ai/djl/serving/wlm/WorkLoadManagerTest.java b/wlm/src/test/java/ai/djl/serving/wlm/WorkLoadManagerTest.java
index 103b5e4089..f569471a6c 100644
--- a/wlm/src/test/java/ai/djl/serving/wlm/WorkLoadManagerTest.java
+++ b/wlm/src/test/java/ai/djl/serving/wlm/WorkLoadManagerTest.java
@@ -37,7 +37,7 @@ public void testFromCriteria() throws IOException {
                         .setTypes(Input.class, Output.class)
                         .optModelUrls(modelUrl)
                         .build();
-        ModelInfo modelInfo = new ModelInfo<>("model", criteria);
+        ModelInfo modelInfo = new ModelInfo<>("model", modelUrl, criteria);
         wlm.registerModel(modelInfo).initWorkers(null, 1, 2);
         Input input = new Input();
         URL url = new URL("https://resources.djl.ai/images/0.png");