From 2bf6bae3f9253a124f0b35730c5d634ddbbb82a9 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Wed, 20 Mar 2024 10:40:08 -0700 Subject: [PATCH] Adds adapter registration options (#1634) --- .../python/src/test/resources/adaptecho/model.py | 14 +++++++++----- serving/docs/adapters.md | 2 ++ serving/docs/adapters_api.md | 1 + .../http/AdapterManagementRequestHandler.java | 12 +++++++++++- .../function/AdapterWorkflowFunction.java | 15 +++++++++++++-- .../test/java/ai/djl/serving/ModelServerTest.java | 12 ++++++------ .../resources/adapterWorkflows/w1/workflow.json | 5 ++++- wlm/src/main/java/ai/djl/serving/wlm/Adapter.java | 12 +++++++++--- .../main/java/ai/djl/serving/wlm/ModelInfo.java | 4 +++- .../main/java/ai/djl/serving/wlm/PyAdapter.java | 10 ++++++++-- 10 files changed, 66 insertions(+), 21 deletions(-) diff --git a/engines/python/src/test/resources/adaptecho/model.py b/engines/python/src/test/resources/adaptecho/model.py index e97540fa5..6941cf3b9 100644 --- a/engines/python/src/test/resources/adaptecho/model.py +++ b/engines/python/src/test/resources/adaptecho/model.py @@ -27,7 +27,7 @@ def register_adapter(inputs: Input): global adapters name = inputs.get_properties()["name"] - adapters[name] = True + adapters[name] = inputs return Output().add("Successfully registered adapter") @@ -51,13 +51,17 @@ def handle(inputs: Input): for i, input in enumerate(inputs.get_batches()): data = input.get_as_string() if input.contains_key("adapter"): - adapter = input.get_as_string("adapter") - if adapter in adapters: + adapter_name = input.get_as_string("adapter") + if adapter_name in adapters: + adapter = adapters[adapter_name] + option = "" + if adapter.contains_key("echooption"): + option = adapter.get_as_string("echooption") # Registered adapter - out = adapter + data + out = adapter_name + option + data else: # Dynamic adapter - out = "dyn" + adapter + data + out = "dyn" + adapter_name + data else: out = data outputs.add(out, key="data", batch_index=i) diff --git 
a/serving/docs/adapters.md b/serving/docs/adapters.md index 18a5cb441..4187ce702 100644 --- a/serving/docs/adapters.md +++ b/serving/docs/adapters.md @@ -58,6 +58,7 @@ For the simple model + adapter case, you can also directly use the adapter [work With our workflows, multiple workflows sharing models will be de-duplicated. So, the effect of having multiple adapters can be easily made with having one workflow for each adapter. This system can be used on [Amazon SageMaker Multi-Model Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html). +More details can be found in the [AdapterWorkflowFunction docs](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/workflow/function/AdapterWorkflowFunction.html). ``` workflow.json: @@ -165,6 +166,7 @@ These can then take the adapters and save the src, pre-download it, cache it in def register_adapter(inputs: Input): name = inputs.get_properties()["name"] src = inputs.get_properties()["src"] + options = inputs.get_content() # Do adapter registration tasks return Output().add("Successfully registered adapter") diff --git a/serving/docs/adapters_api.md b/serving/docs/adapters_api.md index 06047ef0d..59039ffad 100644 --- a/serving/docs/adapters_api.md +++ b/serving/docs/adapters_api.md @@ -19,6 +19,7 @@ This is an extension of the [Management API](management_api.md) and can be acces * name - The adapter name. * src - The adapter src. It currently requires a file, but eventually an id or URL can be supported depending on the model handler. +* All additional arguments will be treated as additional model-specific options and will be passed to the model during adapter registration. ```bash curl -X POST "http://localhost:8080/models/adaptecho/adapters?name=a1&src=..." 
diff --git a/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java index c0b3abfe9..e7ff02cbc 100644 --- a/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/AdapterManagementRequestHandler.java @@ -30,6 +30,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Pattern; /** A class handling inbound HTTP requests to the management API for adapters. */ @@ -154,7 +156,15 @@ private void handleRegisterAdapter( if (wp == null) { throw new BadRequestException("The model " + modelName + " was not found"); } - Adapter adapter = Adapter.newInstance(wp.getWpc(), adapterName, src); + + Map options = new ConcurrentHashMap<>(); + for (Map.Entry> entry : decoder.parameters().entrySet()) { + if (entry.getValue().size() == 1) { + options.put(entry.getKey(), entry.getValue().get(0)); + } + } + + Adapter adapter = Adapter.newInstance(wp.getWpc(), adapterName, src, options); adapter.register(wp); String msg = "Adapter " + adapterName + " registered"; diff --git a/serving/src/main/java/ai/djl/serving/workflow/function/AdapterWorkflowFunction.java b/serving/src/main/java/ai/djl/serving/workflow/function/AdapterWorkflowFunction.java index 52a1fdada..e2400f389 100644 --- a/serving/src/main/java/ai/djl/serving/workflow/function/AdapterWorkflowFunction.java +++ b/serving/src/main/java/ai/djl/serving/workflow/function/AdapterWorkflowFunction.java @@ -32,12 +32,13 @@ * *

To use this workflow function, you must pre-specify the adapted functions in the configs. In * the configs, create an object "adapters" with keys as reference names and values as objects. The - * adapter reference objects should have three properties: + * adapter reference objects should have the following properties: * *

    *
  • model - the model name *
  • name - the adapter name *
  • url - the adapter url + *
  • options (optional) - an object containing additional string options *
* *

To call this workflow function, it requires two arguments. The first is the adapter config @@ -69,8 +70,18 @@ public void prepare(WorkLoadManager wlm, Map> config String adapterName = entry.getKey(); String src = (String) config.get("src"); + Map options = new ConcurrentHashMap<>(); + if (config.containsKey("options") && config.get("options") instanceof Map) { + for (Map.Entry option : + ((Map) config.get("options")).entrySet()) { + if (option.getValue() instanceof String) { + options.put(option.getKey(), (String) option.getValue()); + } + } + } + WorkerPool wp = wlm.getWorkerPoolById(modelName); - Adapter adapter = Adapter.newInstance(wp.getWpc(), adapterName, src); + Adapter adapter = Adapter.newInstance(wp.getWpc(), adapterName, src, options); adapters.put(adapterName, new AdapterReference(modelName, adapter)); } } diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index 3188251d1..f032045bc 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -319,7 +319,7 @@ public void testAdapterWorkflows() assertTrue(server.isRunning()); Channel channel = initTestChannel(); - testAdapterWorkflowPredict(channel, "adapter1", "a1"); + testAdapterWorkflowPredict(channel, "adapter1", "a1weo"); testAdapterWorkflowPredict(channel, "adapter2", "a2"); testRegisterAdapterWorkflowTemplate(channel); @@ -868,7 +868,7 @@ private void testAdapterRegister(Channel channel, boolean registerModel, boolean testAdapterMissing(); String strModelPrefix = modelPrefix ? 
"/models/adaptecho" : ""; - url = strModelPrefix + "/adapters?name=" + "adaptable" + "&src=" + "src"; + url = strModelPrefix + "/adapters?name=adaptable&src=src&echooption=opt"; request(channel, HttpMethod.POST, url); assertHttpOk(); } @@ -926,7 +926,7 @@ private void testAdapterPredict(Channel channel) throws InterruptedException { req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); request(channel, req); assertHttpOk(); - assertEquals(result, "adaptabletestPredictAdapter"); + assertEquals(result, "adaptableopttestPredictAdapter"); } private void testAdapterDirPredict(Channel channel) throws InterruptedException { @@ -942,7 +942,7 @@ private void testAdapterDirPredict(Channel channel) throws InterruptedException assertEquals(result, "myBuiltinAdaptertestPredictBuiltinAdapter"); } - private void testAdapterWorkflowPredict(Channel channel, String workflow, String adapter) + private void testAdapterWorkflowPredict(Channel channel, String workflow, String prefix) throws InterruptedException { logTestFunction(); String url = "/predictions/" + workflow; @@ -953,7 +953,7 @@ private void testAdapterWorkflowPredict(Channel channel, String workflow, String req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); request(channel, req); assertHttpOk(); - assertEquals(result, adapter + "testAWP"); + assertEquals(result, prefix + "testAWP"); } private void testRegisterAdapterWorkflowTemplate(Channel channel) throws InterruptedException { @@ -980,7 +980,7 @@ private void testAdapterInvoke(Channel channel) throws InterruptedException { request(channel, req); assertHttpOk(); - assertEquals(result, "adaptabletestInvokeAdapter"); + assertEquals(result, "adaptableopttestInvokeAdapter"); } private void testAdapterList(Channel channel, boolean modelPrefix) throws InterruptedException { diff --git a/serving/src/test/resources/adapterWorkflows/w1/workflow.json b/serving/src/test/resources/adapterWorkflows/w1/workflow.json index 
4004bfd88..4c6c24157 100644 --- a/serving/src/test/resources/adapterWorkflows/w1/workflow.json +++ b/serving/src/test/resources/adapterWorkflows/w1/workflow.json @@ -8,7 +8,10 @@ "adapters": { "a1": { "model": "m", - "src": "url1" + "src": "url1", + "options": { + "echooption": "weo" + } } } }, diff --git a/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java b/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java index 357c42778..81d626c72 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/Adapter.java @@ -18,6 +18,7 @@ import java.net.URI; import java.net.URISyntaxException; +import java.util.Map; import java.util.concurrent.CompletableFuture; /** @@ -27,16 +28,19 @@ public abstract class Adapter { protected String name; protected String src; + protected Map options; /** * Constructs an {@link Adapter}. * * @param name the adapter name * @param src the adapter source + * @param options additional adapter options */ - protected Adapter(String name, String src) { + protected Adapter(String name, String src, Map options) { this.name = name; this.src = src; + this.options = options; } /** @@ -48,9 +52,11 @@ protected Adapter(String name, String src) { * @param wpc the worker pool config for the new adapter * @param name the adapter name * @param src the adapter source + * @param options additional adapter options * @return the new adapter */ - public static Adapter newInstance(WorkerPoolConfig wpc, String name, String src) { + public static Adapter newInstance( + WorkerPoolConfig wpc, String name, String src, Map options) { if (!(wpc instanceof ModelInfo)) { String modelName = wpc.getId(); throw new IllegalArgumentException("The worker " + modelName + " is not a model"); @@ -69,7 +75,7 @@ public static Adapter newInstance(WorkerPoolConfig wpc, String name, Strin ModelInfo modelInfo = (ModelInfo) wpc; // TODO Replace usage of class name with creating adapters by Engine.newPatch(name ,src) if 
("PyEngine".equals(modelInfo.getEngine().getClass().getSimpleName())) { - return new PyAdapter(name, src); + return new PyAdapter(name, src, options); } else { throw new IllegalArgumentException( "Adapters are only currently supported for Python models"); diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index babe551fd..40efb4680 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -51,6 +51,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Locale; import java.util.Map; @@ -282,7 +283,8 @@ public void load(Device device) throws ModelException, IOException { Adapter.newInstance( this, adapterName, - adapterDir.toAbsolutePath().toString()); + adapterDir.toAbsolutePath().toString(), + Collections.emptyMap()); registerAdapter(adapter); long d = (System.nanoTime() - start) / 1000; Metric me = diff --git a/wlm/src/main/java/ai/djl/serving/wlm/PyAdapter.java b/wlm/src/main/java/ai/djl/serving/wlm/PyAdapter.java index 7863ce8b7..ec03a6135 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/PyAdapter.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/PyAdapter.java @@ -17,6 +17,8 @@ import ai.djl.modality.Output; import ai.djl.translate.TranslateException; +import java.util.Map; + /** An overload of {@link Adapter} for the python engine. 
*/ public class PyAdapter extends Adapter { @@ -25,9 +27,10 @@ public class PyAdapter extends Adapter { * * @param name the adapter name * @param src the adapter src + * @param options additional adapter options */ - protected PyAdapter(String name, String src) { - super(name, src); + protected PyAdapter(String name, String src, Map options) { + super(name, src, options); } @SuppressWarnings("unchecked") @@ -38,6 +41,9 @@ protected void registerPredictor(Predictor predictor) { input.addProperty("handler", "register_adapter"); input.addProperty("name", name); input.addProperty("src", src); + for (Map.Entry entry : options.entrySet()) { + input.add(entry.getKey(), entry.getValue()); + } try { p.predict(input); } catch (TranslateException e) {