Skip to content

Commit

Permalink
[serving] Adds chunked encoding support (#551)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Mar 13, 2023
1 parent 502e05e commit 8b22211
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 45 deletions.
65 changes: 54 additions & 11 deletions engines/python/setup/djl_python/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, code=200, message='OK'):
self.message = message
self.properties = dict()
self.content = PairList()
self.stream_content = None

def __str__(self):
d = dict()
Expand Down Expand Up @@ -121,6 +122,9 @@ def add_as_json(self, val, key=None, batch_index=None):
key=key,
batch_index=batch_index)

    def add_stream_content(self, stream_content):
        """Attaches an iterator/generator whose chunks are streamed to the client.

        When set, ``send()`` transmits the response with chunked encoding,
        pulling items from ``stream_content`` one at a time instead of
        sending ``self.content`` in a single message.

        :param stream_content: an iterable yielding str, bytes, bytearray,
            or JSON-serializable objects (one chunk per item)
        """
        self.stream_content = stream_content

@staticmethod
def _encode_json(val) -> bytes:
return bytearray(
Expand All @@ -140,7 +144,7 @@ def write_utf8(msg, val):
msg += struct.pack('>h', len(buf))
msg += buf

def encode(self) -> bytearray:
def send(self, cl_socket):
msg = bytearray()
msg += struct.pack('>h', self.code)
self.write_utf8(msg, self.message)
Expand All @@ -150,13 +154,52 @@ def encode(self) -> bytearray:
self.write_utf8(msg, k)
self.write_utf8(msg, v)

size = self.content.size()
msg += struct.pack('>h', size)
for i in range(size):
k = self.content.key_at(i)
v = self.content.value_at(i)
self.write_utf8(msg, k)
msg += struct.pack('>i', len(v))
msg += v

return msg
if self.stream_content is None:
size = self.content.size()
msg += struct.pack('>h', size)
for i in range(size):
k = self.content.key_at(i)
v = self.content.value_at(i)
self.write_utf8(msg, k)
msg += struct.pack('>i', len(v))
msg += v
cl_socket.sendall(msg)
return

msg += struct.pack('>h', -1)
cl_socket.sendall(msg)

while True:
try:
data = next(self.stream_content)

msg = bytearray()
msg += b'\1'

if type(data) is str:
data = data.encode('utf-8')
elif type(data) is bytearray:
pass
elif type(data) is bytes:
data = bytearray(data)
else:
data = bytearray(self._encode_json(data))

msg += struct.pack('>i', len(data))
msg += data
cl_socket.sendall(msg)
except StopIteration:
msg = bytearray()
msg += b'\0'
msg += struct.pack('>i', 0)
cl_socket.sendall(msg)
break
except Exception as e:
logging.exception("Failed read streaming content from output")
msg = bytearray()
msg += b'\0'
data = str(e).encode('utf-8')
msg += struct.pack('>i', len(data))
msg += data
cl_socket.sendall(msg)
break
4 changes: 2 additions & 2 deletions engines/python/setup/djl_python_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def run_server(self):
logging.exception("Failed invoke service.invoke_handler()")
outputs = Output().error(str(e))

if not cl_socket.sendall(outputs.encode()):
logging.debug("Outputs is sent to DJL engine.")
outputs.send(cl_socket)
logging.debug("Outputs is sent to DJL engine.")


def main():
Expand Down
38 changes: 26 additions & 12 deletions engines/python/src/main/java/ai/djl/python/engine/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.Model;
import ai.djl.engine.EngineException;
import ai.djl.modality.ChunkedBytesSupplier;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
Expand Down Expand Up @@ -353,6 +354,8 @@ protected void encode(ChannelHandlerContext ctx, Input msg, ByteBuf out) {
private static final class OutputDecoder extends ByteToMessageDecoder {

private int maxBufferSize;
private boolean hasMoreChunk;
private ChunkedBytesSupplier data;

OutputDecoder(int maxBufferSize) {
this.maxBufferSize = maxBufferSize;
Expand All @@ -366,19 +369,30 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
in.markReaderIndex();
boolean completed = false;
try {
int code = in.readShort();
String message = CodecUtils.readUtf8(in);
Output output = new Output(code, message);
int size = in.readShort();
for (int i = 0; i < size; ++i) {
output.addProperty(CodecUtils.readUtf8(in), CodecUtils.readUtf8(in));
if (hasMoreChunk) {
hasMoreChunk = in.readByte() == 1;
data.appendContent(CodecUtils.readBytes(in, maxBufferSize), !hasMoreChunk);
} else {
int code = in.readShort();
String message = CodecUtils.readUtf8(in);
Output output = new Output(code, message);
int size = in.readShort();
for (int i = 0; i < size; ++i) {
output.addProperty(CodecUtils.readUtf8(in), CodecUtils.readUtf8(in));
}
int contentSize = in.readShort();
if (contentSize == -1) {
hasMoreChunk = true;
data = new ChunkedBytesSupplier();
output.add(data);
} else {
for (int i = 0; i < contentSize; ++i) {
String key = CodecUtils.readUtf8(in);
output.add(key, CodecUtils.readBytes(in, maxBufferSize));
}
}
out.add(output);
}
int contentSize = in.readShort();
for (int i = 0; i < contentSize; ++i) {
String key = CodecUtils.readUtf8(in);
output.add(key, CodecUtils.readBytes(in, maxBufferSize));
}
out.add(output);
completed = true;
} catch (IndexOutOfBoundsException | NegativeArraySizeException ignore) {
// ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.inference.Predictor;
import ai.djl.modality.ChunkedBytesSupplier;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
Expand All @@ -40,12 +42,14 @@
import java.io.IOException;
import java.lang.reflect.Type;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

public class PyEngineTest {

Expand Down Expand Up @@ -181,6 +185,29 @@ public void testBatchEcho() throws TranslateException, IOException, ModelExcepti
}
}

@Test
public void testStreamEcho()
throws TranslateException, IOException, ModelException, InterruptedException {
Criteria<Input, Output> criteria =
Criteria.builder()
.setTypes(Input.class, Output.class)
.optModelPath(Paths.get("src/test/resources/echo"))
.optEngine("Python")
.build();
try (ZooModel<Input, Output> model = criteria.loadModel();
Predictor<Input, Output> predictor = model.newPredictor()) {
Input input = new Input();
input.add("stream", "true");
Output out = predictor.predict(input);
BytesSupplier supplier = out.getData();
Assert.assertTrue(supplier instanceof ChunkedBytesSupplier);
ChunkedBytesSupplier cbs = (ChunkedBytesSupplier) supplier;
Assert.assertTrue(cbs.hasNext());
byte[] buf = cbs.nextChunk(1, TimeUnit.MINUTES);
Assert.assertEquals(new String(buf, StandardCharsets.UTF_8), "t-0\n");
}
}

@Test
public void testResnet18() throws TranslateException, IOException, ModelException {
if (!Boolean.getBoolean("nightly")) {
Expand Down
13 changes: 12 additions & 1 deletion engines/python/src/test/resources/echo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@

import logging
import sys
import time

from djl_python import Input
from djl_python import Output


def steam_token():
    """Generator that emits five tokens ("t-0\\n" through "t-4\\n").

    Sleeps one second before yielding each token to simulate a model that
    produces output incrementally, so streaming behavior can be exercised
    in tests.
    """
    index = 0
    while index < 5:
        time.sleep(1)
        yield "t-{}\n".format(index)
        index += 1


def handle(inputs: Input):
"""
Default handler function
Expand All @@ -43,6 +51,9 @@ def handle(inputs: Input):
for i, item in enumerate(batch):
outputs.add(item.get_as_bytes(), key="data", batch_index=i)
else:
outputs.add(data, key="data")
if inputs.contains_key("stream"):
outputs.add_stream_content(steam_token())
else:
outputs.add(data, key="data")

return outputs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ private void inference(ChannelHandlerContext ctx, FullHttpRequest req, String[]

modelManager
.runJob(workflow, input)
.whenComplete(
.whenCompleteAsync(
(o, t) -> {
if (o != null) {
responseOutput(response, o, ctx, request.outputs);
Expand Down
2 changes: 0 additions & 2 deletions serving/src/main/java/ai/djl/serving/ServerInitializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.stream.ChunkedWriteHandler;

/**
* A special {@link io.netty.channel.ChannelInboundHandler} which offers an easy way to initialize a
Expand Down Expand Up @@ -65,7 +64,6 @@ public void initChannel(Channel ch) {
}
pipeline.addLast("http", new HttpServerCodec());
pipeline.addLast("aggregator", new HttpObjectAggregator(maxRequestSize, true));
pipeline.addLast(new ChunkedWriteHandler());
switch (connectorType) {
case MANAGEMENT:
pipeline.addLast(new ConfigurableHttpRequestHandler(pluginManager));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.ModelException;
import ai.djl.metric.Metric;
import ai.djl.modality.ChunkedBytesSupplier;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.BytesSupplier;
Expand All @@ -26,20 +27,27 @@
import ai.djl.serving.workflow.Workflow;
import ai.djl.translate.TranslateException;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.QueryStringDecoder;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

/** A class handling inbound HTTP requests for the management API. */
Expand All @@ -52,11 +60,11 @@ public class InferenceRequestHandler extends HttpRequestHandler {
private static final Metric RESPONSE_5_XX = new Metric("5XX", 1);
private static final Metric WLM_ERROR = new Metric("WlmError", 1);
private static final Metric SERVER_ERROR = new Metric("ServerError", 1);
private RequestParser requestParser;

private static final Pattern PATTERN =
Pattern.compile("/(ping|invocations|predictions)([/?].*)?|/models/.+/invoke");

private RequestParser requestParser;

/** default constructor. */
public InferenceRequestHandler() {
this.requestParser = new RequestParser();
Expand Down Expand Up @@ -240,7 +248,7 @@ void runJob(
ModelManager modelManager, ChannelHandlerContext ctx, Workflow workflow, Input input) {
modelManager
.runJob(workflow, input)
.whenComplete(
.whenCompleteAsync(
(o, t) -> {
if (o != null) {
sendOutput(o, ctx);
Expand All @@ -254,6 +262,16 @@ void runJob(
}

void sendOutput(Output output, ChannelHandlerContext ctx) {
/*
* We can load the models based on the configuration file. Since this Job is
* not driven by the external connections, we could have a empty context for
* this job. We shouldn't try to send a response to ctx if this is not triggered
* by external clients.
*/
if (ctx == null) {
return;
}

HttpResponseStatus status;
int code = output.getCode();
if (code == 200) {
Expand All @@ -269,25 +287,37 @@ void sendOutput(Output output, ChannelHandlerContext ctx) {
}
status = new HttpResponseStatus(code, output.getMessage());
}
BytesSupplier data = output.getData();
if (data instanceof ChunkedBytesSupplier) {
HttpResponse resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status, true);
for (Map.Entry<String, String> entry : output.getProperties().entrySet()) {
resp.headers().set(entry.getKey(), entry.getValue());
}
NettyUtils.sendHttpResponse(ctx, resp, true);
ChunkedBytesSupplier supplier = (ChunkedBytesSupplier) data;
try {
while (supplier.hasNext()) {
byte[] buf = supplier.nextChunk(1, TimeUnit.MINUTES);
ByteBuf bb = Unpooled.wrappedBuffer(buf);
ctx.writeAndFlush(new DefaultHttpContent(bb));
}
ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT);
} catch (InterruptedException | IllegalStateException e) {
logger.warn("Chunk reading interrupted", e);
ctx.newFailedFuture(e);
}
return;
}

FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, true);
for (Map.Entry<String, String> entry : output.getProperties().entrySet()) {
resp.headers().set(entry.getKey(), entry.getValue());
}
BytesSupplier data = output.getData();
if (data != null) {
resp.content().writeBytes(data.getAsBytes());
}

/*
* We can load the models based on the configuration file.Since this Job is
* not driven by the external connections, we could have a empty context for
* this job. We shouldn't try to send a response to ctx if this is not triggered
* by external clients.
*/
if (ctx != null) {
NettyUtils.sendHttpResponse(ctx, resp, true);
}
NettyUtils.sendHttpResponse(ctx, resp, true);
}

void onException(Throwable t, ChannelHandlerContext ctx) {
Expand Down
Loading

0 comments on commit 8b22211

Please sign in to comment.