From a18f738df12f4d7d261ce82c7d92a7685bf9fe13 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 17 Apr 2023 15:37:22 -0700 Subject: [PATCH] [serving] Read x-synchronous and x-starting-token from input payload --- .../djl/serving/ddbcache/DdbCacheEngine.java | 33 ++++++++----------- .../serving/ddbcache/DdbCacheEngineTest.java | 6 ++-- serving/docker/Dockerfile | 2 +- .../djl/serving/cache/MemoryCacheEngine.java | 1 + .../serving/http/InferenceRequestHandler.java | 6 +++- .../ai/djl/serving/http/RequestParser.java | 10 +++++- .../java/ai/djl/serving/util/NettyUtils.java | 10 +++--- 7 files changed, 38 insertions(+), 30 deletions(-) diff --git a/plugins/cache/src/main/java/ai/djl/serving/ddbcache/DdbCacheEngine.java b/plugins/cache/src/main/java/ai/djl/serving/ddbcache/DdbCacheEngine.java index 13e97eb03..98b8c071f 100644 --- a/plugins/cache/src/main/java/ai/djl/serving/ddbcache/DdbCacheEngine.java +++ b/plugins/cache/src/main/java/ai/djl/serving/ddbcache/DdbCacheEngine.java @@ -16,6 +16,7 @@ import ai.djl.modality.Output; import ai.djl.ndarray.BytesSupplier; import ai.djl.serving.cache.CacheEngine; +import ai.djl.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,8 +34,6 @@ import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; import software.amazon.awssdk.services.dynamodb.model.DeleteRequest; import software.amazon.awssdk.services.dynamodb.model.DescribeTableRequest; -import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; -import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; import software.amazon.awssdk.services.dynamodb.model.KeyType; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; @@ -63,7 +62,8 @@ public final class DdbCacheEngine implements CacheEngine { private static final Logger logger = LoggerFactory.getLogger(DdbCacheEngine.class); - private static final String TABLE_NAME = 
"djl-serving-pagination-table"; + private static final String TABLE_NAME = + Utils.getenv("DDB_TABLE_NAME", "djl-serving-pagination-table"); private static final String CACHE_ID = "CACHE_ID"; private static final String INDEX = "INDEX_KEY"; private static final String HEADER = "HEADER"; @@ -85,7 +85,7 @@ public final class DdbCacheEngine implements CacheEngine { private DdbCacheEngine(DynamoDbClient ddbClient) { this.ddbClient = ddbClient; cacheTtl = Duration.ofMillis(30).toMillis(); - writeBatch = 5; + writeBatch = Integer.parseInt(Utils.getenv("SERVING_DDB_BATCH", "5")); } /** @@ -208,7 +208,7 @@ public CompletableFuture put(String key, Output output) { /** {@inheritDoc} */ @Override public Output get(String key, int limit) { - int start = 0; + int start = -1; if (key.length() > 36) { start = Integer.parseInt(key.substring(36)); key = key.substring(0, 36); @@ -222,30 +222,24 @@ public Output get(String key, int limit) { .tableName(TABLE_NAME) .keyConditionExpression(EXPRESSION) .expressionAttributeValues(attrValues) - .limit(limit) + .limit(limit == Integer.MAX_VALUE ? 
limit : limit + 1) .build(); QueryResponse response = ddbClient.query(request); if (response.count() == 0) { - if (start == 0) { - Map map = new ConcurrentHashMap<>(2); - map.put(CACHE_ID, AttributeValue.builder().s(key).build()); - map.put(INDEX, AttributeValue.builder().n("-1").build()); - GetItemRequest get = - GetItemRequest.builder().tableName(TABLE_NAME).key(map).build(); - GetItemResponse resp = ddbClient.getItem(get); - if (resp.hasItem()) { - AttributeValue header = resp.item().get(HEADER); - return decode(header); - } - } return null; } Output output = new Output(); boolean complete = false; + boolean first = true; List list = new ArrayList<>(); for (Map item : response.items()) { + // skip first one + if (first) { + first = false; + continue; + } AttributeValue header = item.get(HEADER); if (header != null) { Output o = decode(header); @@ -261,13 +255,14 @@ public Output get(String key, int limit) { if (lastContent != null) { complete = true; } - start++; + start = Integer.parseInt(item.get(INDEX).n()); } if (!list.isEmpty()) { output.add(join(list)); } if (!complete) { output.addProperty("x-next-token", key + start); + output.addProperty("X-Amzn-SageMaker-Custom-Attributes", "x-next-token=" + key + start); } return output; } diff --git a/plugins/cache/src/test/java/ai/djl/serving/ddbcache/DdbCacheEngineTest.java b/plugins/cache/src/test/java/ai/djl/serving/ddbcache/DdbCacheEngineTest.java index 50bc6a6a0..9632579bc 100644 --- a/plugins/cache/src/test/java/ai/djl/serving/ddbcache/DdbCacheEngineTest.java +++ b/plugins/cache/src/test/java/ai/djl/serving/ddbcache/DdbCacheEngineTest.java @@ -57,14 +57,14 @@ public void testDdbCacheEngine() throws InterruptedException, ExecutionException // query before model generate output o = engine.get(key1, Integer.MAX_VALUE); - Assert.assertEquals(o.getCode(), 202); + Assert.assertEquals(o.getCode(), 200); Assert.assertNull(o.getData()); String nextToken = o.getProperty("x-next-token", null); - 
Assert.assertEquals(nextToken, key1); + Assert.assertEquals(nextToken, key1 + "-1"); // retry before model generate output o = engine.get(nextToken, Integer.MAX_VALUE); - Assert.assertEquals(o.getCode(), 202); + Assert.assertEquals(o.getCode(), 200); // real output from model Output output1 = new Output(); diff --git a/serving/docker/Dockerfile b/serving/docker/Dockerfile index 147f6519a..9054732f5 100644 --- a/serving/docker/Dockerfile +++ b/serving/docker/Dockerfile @@ -61,7 +61,7 @@ RUN scripts/install_python.sh && \ echo "${djl_version} cpufull" > /opt/djl/bin/telemetry && \ djl-serving -i ai.djl.mxnet:mxnet-native-mkl:1.9.1:linux-x86_64 && \ djl-serving -i ai.djl.pytorch:pytorch-native-cpu:$torch_version:linux-x86_64 && \ - djl-serving -i ai.djl.tensorflow:tensorflow-native-cpu:2.7.4:linux-x86_64 && \ + djl-serving -i ai.djl.tensorflow:tensorflow-native-cpu:2.10.1:linux-x86_64 && \ scripts/patch_oss_dlc.sh python && \ rm -rf /opt/djl/logs && \ chown -R djl:djl /opt/djl && \ diff --git a/serving/src/main/java/ai/djl/serving/cache/MemoryCacheEngine.java b/serving/src/main/java/ai/djl/serving/cache/MemoryCacheEngine.java index 638ad410c..5ef72ebea 100644 --- a/serving/src/main/java/ai/djl/serving/cache/MemoryCacheEngine.java +++ b/serving/src/main/java/ai/djl/serving/cache/MemoryCacheEngine.java @@ -126,6 +126,7 @@ public Output get(String key, int limit) { } if (cbs.hasNext()) { o.addProperty("x-next-token", key); + o.addProperty("X-Amzn-SageMaker-Custom-Attributes", "x-next-token=" + key); } else { // clean up cache remove(key); diff --git a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java index f449aba66..0c0e3a109 100644 --- a/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java +++ b/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java @@ -71,6 +71,7 @@ public class InferenceRequestHandler extends HttpRequestHandler { private 
static final String X_STARTING_TOKEN = "x-starting-token"; private static final String X_NEXT_TOKEN = "x-next-token"; private static final String X_MAX_ITEMS = "x-max-items"; + private static final String X_CUSTOM_ATTRIBUTES = "X-Amzn-SageMaker-Custom-Attributes"; private RequestParser requestParser; @@ -292,11 +293,13 @@ void runJob( pending.setMessage("The model result is not yet available"); pending.setCode(202); pending.addProperty(X_NEXT_TOKEN, nextToken); + pending.addProperty(X_CUSTOM_ATTRIBUTES, X_NEXT_TOKEN + '=' + nextToken); cache.put(nextToken, pending); // Send back token to user Output out = new Output(); out.addProperty(X_NEXT_TOKEN, nextToken); + out.addProperty(X_CUSTOM_ATTRIBUTES, X_NEXT_TOKEN + '=' + nextToken); sendOutput(out, ctx); // Run model @@ -325,7 +328,7 @@ private void getCacheResult(ChannelHandlerContext ctx, Input input, String start CacheEngine cache = CacheManager.getCacheEngine(); Output output = cache.get(startingToken, limit); if (output == null) { - throw new BadRequestException("Invalid " + X_STARTING_TOKEN); + throw new BadRequestException("Invalid " + X_STARTING_TOKEN + ": " + startingToken); } sendOutput(output, ctx); } @@ -392,6 +395,7 @@ void sendOutput(Output output, ChannelHandlerContext ctx) { void onException(Throwable t, ChannelHandlerContext ctx) { HttpResponseStatus status; if (t instanceof TranslateException || t instanceof BadRequestException) { + logger.debug(t.getMessage(), t); SERVER_METRIC.info("{}", RESPONSE_4_XX); status = HttpResponseStatus.BAD_REQUEST; } else if (t instanceof WlmException) { diff --git a/serving/src/main/java/ai/djl/serving/http/RequestParser.java b/serving/src/main/java/ai/djl/serving/http/RequestParser.java index 90b4ddde0..595fdfdc0 100644 --- a/serving/src/main/java/ai/djl/serving/http/RequestParser.java +++ b/serving/src/main/java/ai/djl/serving/http/RequestParser.java @@ -60,7 +60,15 @@ public Input parseRequest(FullHttpRequest req, QueryStringDecoder decoder) { for (Map.Entry entry 
: req.headers().entries()) { String key = entry.getKey(); - if (!HttpHeaderNames.CONTENT_TYPE.contentEqualsIgnoreCase(key)) { + if ("X-Amzn-SageMaker-Custom-Attributes".equalsIgnoreCase(key)) { + String[] tokens = entry.getValue().split(";"); + for (String token : tokens) { + String[] pair = token.split("=", 2); + if (pair.length == 2) { + input.addProperty(pair[0].trim(), pair[1].trim()); + } + } + } else if (!HttpHeaderNames.CONTENT_TYPE.contentEqualsIgnoreCase(key)) { input.addProperty(key, entry.getValue()); } } diff --git a/serving/src/main/java/ai/djl/serving/util/NettyUtils.java b/serving/src/main/java/ai/djl/serving/util/NettyUtils.java index 53e621bc7..436ce17d9 100644 --- a/serving/src/main/java/ai/djl/serving/util/NettyUtils.java +++ b/serving/src/main/java/ai/djl/serving/util/NettyUtils.java @@ -242,10 +242,10 @@ public static void sendFile( */ public static void sendError(ChannelHandlerContext ctx, Throwable t) { if (t instanceof ResourceNotFoundException || t instanceof ModelNotFoundException) { - logger.trace("", t); + logger.debug("", t); NettyUtils.sendError(ctx, HttpResponseStatus.NOT_FOUND, t); } else if (t instanceof BadRequestException) { - logger.trace("", t); + logger.debug("", t); BadRequestException e = (BadRequestException) t; HttpResponseStatus status = HttpResponseStatus.valueOf(e.getCode(), e.getMessage()); NettyUtils.sendError(ctx, status, t); @@ -253,13 +253,13 @@ public static void sendError(ChannelHandlerContext ctx, Throwable t) { logger.warn("", t); NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, t); } else if (t instanceof ModelException) { - logger.trace("", t); + logger.debug("", t); NettyUtils.sendError(ctx, HttpResponseStatus.BAD_REQUEST, t); } else if (t instanceof MethodNotAllowedException) { - logger.trace("", t); + logger.debug("", t); NettyUtils.sendError(ctx, HttpResponseStatus.METHOD_NOT_ALLOWED, t); } else if (t instanceof ServiceUnavailableException || t instanceof WlmException) { - 
logger.trace("", t); + logger.warn("", t); NettyUtils.sendError(ctx, HttpResponseStatus.SERVICE_UNAVAILABLE, t); } else { logger.error("", t);