Skip to content

Commit

Permalink
Adds adapter registration options (#1634)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk authored Mar 20, 2024
1 parent abdb7ad commit 2bf6bae
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 21 deletions.
14 changes: 9 additions & 5 deletions engines/python/src/test/resources/adaptecho/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def register_adapter(inputs: Input):
global adapters
name = inputs.get_properties()["name"]
adapters[name] = True
adapters[name] = inputs
return Output().add("Successfully registered adapter")


Expand All @@ -51,13 +51,17 @@ def handle(inputs: Input):
for i, input in enumerate(inputs.get_batches()):
data = input.get_as_string()
if input.contains_key("adapter"):
adapter = input.get_as_string("adapter")
if adapter in adapters:
adapter_name = input.get_as_string("adapter")
if adapter_name in adapters:
adapter = adapters[adapter_name]
option = ""
if adapter.contains_key("echooption"):
option = adapter.get_as_string("echooption")
# Registered adapter
out = adapter + data
out = adapter_name + option + data
else:
# Dynamic adapter
out = "dyn" + adapter + data
out = "dyn" + adapter_name + data
else:
out = data
outputs.add(out, key="data", batch_index=i)
Expand Down
2 changes: 2 additions & 0 deletions serving/docs/adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ For the simple model + adapter case, you can also directly use the adapter [work
With our workflows, multiple workflows sharing models will be de-duplicated.
So, the effect of having multiple adapters can easily be achieved by creating one workflow for each adapter.
This system can be used on [Amazon SageMaker Multi-Model Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html).
More details can be found in the [AdapterWorkflowFunction docs](https://javadoc.io/doc/ai.djl.serving/serving/latest/ai/djl/serving/workflow/function/AdapterWorkflowFunction.html).

```
workflow.json:
Expand Down Expand Up @@ -165,6 +166,7 @@ These can then take the adapters and save the src, pre-download it, cache it in
def register_adapter(inputs: Input):
name = inputs.get_properties()["name"]
src = inputs.get_properties()["src"]
options = inputs.get_content()
# Do adapter registration tasks
return Output().add("Successfully registered adapter")

Expand Down
1 change: 1 addition & 0 deletions serving/docs/adapters_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ This is an extension of the [Management API](management_api.md) and can be acces

* name - The adapter name.
* src - The adapter src. It currently requires a file, but eventually an id or URL can be supported depending on the model handler.
* Any additional query parameters are treated as model-specific options and are passed to the model during adapter registration.

```bash
curl -X POST "http://localhost:8080/models/adaptecho/adapters?name=a1&src=..."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

/** A class handling inbound HTTP requests to the management API for adapters. */
Expand Down Expand Up @@ -154,7 +156,15 @@ private void handleRegisterAdapter(
if (wp == null) {
throw new BadRequestException("The model " + modelName + " was not found");
}
Adapter adapter = Adapter.newInstance(wp.getWpc(), adapterName, src);

Map<String, String> options = new ConcurrentHashMap<>();
for (Map.Entry<String, List<String>> entry : decoder.parameters().entrySet()) {
if (entry.getValue().size() == 1) {
options.put(entry.getKey(), entry.getValue().get(0));
}
}

Adapter adapter = Adapter.newInstance(wp.getWpc(), adapterName, src, options);
adapter.register(wp);

String msg = "Adapter " + adapterName + " registered";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
*
* <p>To use this workflow function, you must pre-specify the adapted functions in the configs. In
* the configs, create an object "adapters" with keys as reference names and values as objects. The
* adapter reference objects should have three properties:
* adapter reference objects should have the following properties:
*
* <ul>
* <li>model - the model name
* <li>name - the adapter name
* <li>src - the adapter src
* <li>options (optional) - an object containing additional string options
* </ul>
*
* <p>Calling this workflow function requires two arguments. The first is the adapter config
Expand Down Expand Up @@ -69,8 +70,18 @@ public void prepare(WorkLoadManager wlm, Map<String, Map<String, Object>> config
String adapterName = entry.getKey();
String src = (String) config.get("src");

Map<String, String> options = new ConcurrentHashMap<>();
if (config.containsKey("options") && config.get("options") instanceof Map) {
for (Map.Entry<String, Object> option :
((Map<String, Object>) config.get("options")).entrySet()) {
if (option.getValue() instanceof String) {
options.put(option.getKey(), (String) option.getValue());
}
}
}

WorkerPool<?, ?> wp = wlm.getWorkerPoolById(modelName);
Adapter adapter = Adapter.newInstance(wp.getWpc(), adapterName, src);
Adapter adapter = Adapter.newInstance(wp.getWpc(), adapterName, src, options);
adapters.put(adapterName, new AdapterReference(modelName, adapter));
}
}
Expand Down
12 changes: 6 additions & 6 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ public void testAdapterWorkflows()
assertTrue(server.isRunning());
Channel channel = initTestChannel();

testAdapterWorkflowPredict(channel, "adapter1", "a1");
testAdapterWorkflowPredict(channel, "adapter1", "a1weo");
testAdapterWorkflowPredict(channel, "adapter2", "a2");
testRegisterAdapterWorkflowTemplate(channel);

Expand Down Expand Up @@ -868,7 +868,7 @@ private void testAdapterRegister(Channel channel, boolean registerModel, boolean
testAdapterMissing();

String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
url = strModelPrefix + "/adapters?name=" + "adaptable" + "&src=" + "src";
url = strModelPrefix + "/adapters?name=adaptable&src=src&echooption=opt";
request(channel, HttpMethod.POST, url);
assertHttpOk();
}
Expand Down Expand Up @@ -926,7 +926,7 @@ private void testAdapterPredict(Channel channel) throws InterruptedException {
req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN);
request(channel, req);
assertHttpOk();
assertEquals(result, "adaptabletestPredictAdapter");
assertEquals(result, "adaptableopttestPredictAdapter");
}

private void testAdapterDirPredict(Channel channel) throws InterruptedException {
Expand All @@ -942,7 +942,7 @@ private void testAdapterDirPredict(Channel channel) throws InterruptedException
assertEquals(result, "myBuiltinAdaptertestPredictBuiltinAdapter");
}

private void testAdapterWorkflowPredict(Channel channel, String workflow, String adapter)
private void testAdapterWorkflowPredict(Channel channel, String workflow, String prefix)
throws InterruptedException {
logTestFunction();
String url = "/predictions/" + workflow;
Expand All @@ -953,7 +953,7 @@ private void testAdapterWorkflowPredict(Channel channel, String workflow, String
req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN);
request(channel, req);
assertHttpOk();
assertEquals(result, adapter + "testAWP");
assertEquals(result, prefix + "testAWP");
}

private void testRegisterAdapterWorkflowTemplate(Channel channel) throws InterruptedException {
Expand All @@ -980,7 +980,7 @@ private void testAdapterInvoke(Channel channel) throws InterruptedException {
request(channel, req);

assertHttpOk();
assertEquals(result, "adaptabletestInvokeAdapter");
assertEquals(result, "adaptableopttestInvokeAdapter");
}

private void testAdapterList(Channel channel, boolean modelPrefix) throws InterruptedException {
Expand Down
5 changes: 4 additions & 1 deletion serving/src/test/resources/adapterWorkflows/w1/workflow.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
"adapters": {
"a1": {
"model": "m",
"src": "url1"
"src": "url1",
"options": {
"echooption": "weo"
}
}
}
},
Expand Down
12 changes: 9 additions & 3 deletions wlm/src/main/java/ai/djl/serving/wlm/Adapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

/**
Expand All @@ -27,16 +28,19 @@ public abstract class Adapter {

protected String name;
protected String src;
protected Map<String, String> options;

/**
* Constructs an {@link Adapter}.
*
* @param name the adapter name
* @param src the adapter source
* @param options additional adapter options
*/
protected Adapter(String name, String src) {
protected Adapter(String name, String src, Map<String, String> options) {
this.name = name;
this.src = src;
this.options = options;
}

/**
Expand All @@ -48,9 +52,11 @@ protected Adapter(String name, String src) {
* @param wpc the worker pool config for the new adapter
* @param name the adapter name
* @param src the adapter source
* @param options additional adapter options
* @return the new adapter
*/
public static Adapter newInstance(WorkerPoolConfig<?, ?> wpc, String name, String src) {
public static Adapter newInstance(
WorkerPoolConfig<?, ?> wpc, String name, String src, Map<String, String> options) {
if (!(wpc instanceof ModelInfo)) {
String modelName = wpc.getId();
throw new IllegalArgumentException("The worker " + modelName + " is not a model");
Expand All @@ -69,7 +75,7 @@ public static Adapter newInstance(WorkerPoolConfig<?, ?> wpc, String name, Strin
ModelInfo<?, ?> modelInfo = (ModelInfo<?, ?>) wpc;
// TODO Replace usage of class name with creating adapters by Engine.newPatch(name, src)
if ("PyEngine".equals(modelInfo.getEngine().getClass().getSimpleName())) {
return new PyAdapter(name, src);
return new PyAdapter(name, src, options);
} else {
throw new IllegalArgumentException(
"Adapters are only currently supported for Python models");
Expand Down
4 changes: 3 additions & 1 deletion wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -282,7 +283,8 @@ public void load(Device device) throws ModelException, IOException {
Adapter.newInstance(
this,
adapterName,
adapterDir.toAbsolutePath().toString());
adapterDir.toAbsolutePath().toString(),
Collections.emptyMap());
registerAdapter(adapter);
long d = (System.nanoTime() - start) / 1000;
Metric me =
Expand Down
10 changes: 8 additions & 2 deletions wlm/src/main/java/ai/djl/serving/wlm/PyAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;

import java.util.Map;

/** An implementation of {@link Adapter} for the Python engine. */
public class PyAdapter extends Adapter {

Expand All @@ -25,9 +27,10 @@ public class PyAdapter extends Adapter {
*
* @param name the adapter name
* @param src the adapter src
* @param options additional adapter options
*/
protected PyAdapter(String name, String src) {
super(name, src);
protected PyAdapter(String name, String src, Map<String, String> options) {
super(name, src, options);
}

@SuppressWarnings("unchecked")
Expand All @@ -38,6 +41,9 @@ protected void registerPredictor(Predictor<?, ?> predictor) {
input.addProperty("handler", "register_adapter");
input.addProperty("name", name);
input.addProperty("src", src);
for (Map.Entry<String, String> entry : options.entrySet()) {
input.add(entry.getKey(), entry.getValue());
}
try {
p.predict(input);
} catch (TranslateException e) {
Expand Down

0 comments on commit 2bf6bae

Please sign in to comment.