Skip to content

Commit

Permalink
[serving] Emits model inference metrics to log file
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Feb 3, 2023
1 parent 0ae4932 commit 84c6590
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 3 deletions.
1 change: 0 additions & 1 deletion serving/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ run {
systemProperties System.getProperties()
systemProperties.remove("user.dir")
systemProperty("file.encoding", "UTF-8")
systemProperty("ai.djl.pytorch.num_interop_threads", "1")
// systemProperty("ai.djl.logging.level", "debug")
systemProperty("log4j.configurationFile", "${project.projectDir}/src/main/conf/log4j2.xml")
applicationDefaultJvmArgs = ["-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=4000"]
Expand Down
11 changes: 10 additions & 1 deletion serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ private void testKServeV2Infer(Channel channel) throws InterruptedException {
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1,
HttpMethod.POST,
"/models?model_name=identity&url="
"/models?model_name=identity&min_worker=1&url="
+ URLEncoder.encode(url, StandardCharsets.UTF_8)));
assertEquals(httpStatus.code(), HttpResponseStatus.OK.code());

Expand All @@ -1071,12 +1071,21 @@ private void testKServeV2Infer(Channel channel) throws InterruptedException {
data.put("inputs", new Object[] {input});
data.put("outputs", new Object[] {output});

// trigger model metrics logging
req.content().writeCharSequence(JsonUtils.GSON.toJson(data), StandardCharsets.UTF_8);
HttpUtil.setContentLength(req, req.content().readableBytes());
req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
request(channel, req);
assertEquals(httpStatus.code(), HttpResponseStatus.OK.code());

req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/v2/models/identity/infer");
req.content().writeCharSequence(JsonUtils.GSON.toJson(data), StandardCharsets.UTF_8);
HttpUtil.setContentLength(req, req.content().readableBytes());
req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
request(channel, req);

request(
channel,
new DefaultFullHttpRequest(
Expand Down
2 changes: 2 additions & 0 deletions serving/src/test/resources/identity/serving.properties
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ gpu.maxWorkers=3
cpu.minWorkers=2
cpu.maxWorkers=4
job_queue_size=10
log_model_metric=true
metrics_aggregation=1
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
*/
public class PermanentBatchAggregator<I, O> extends BatchAggregator<I, O> {

private static final Logger logger = LoggerFactory.getLogger(TemporaryBatchAggregator.class);
private static final Logger logger = LoggerFactory.getLogger(PermanentBatchAggregator.class);

/**
* Constructs a {@code PermanentBatchAggregator} instance.
Expand Down
20 changes: 20 additions & 0 deletions wlm/src/main/java/ai/djl/serving/wlm/WorkerThread.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.Device;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.serving.wlm.util.WlmException;
import ai.djl.serving.wlm.util.WorkerJob;
Expand All @@ -31,8 +32,10 @@
public final class WorkerThread<I, O> implements Runnable {

private static final Logger logger = LoggerFactory.getLogger(WorkerThread.class);
private static final Logger MODEL_METRIC = LoggerFactory.getLogger("model_metric");

private String workerName;
private String modelName;
private Predictor<I, O> predictor;

private AtomicBoolean running = new AtomicBoolean(true);
Expand All @@ -44,6 +47,8 @@ public final class WorkerThread<I, O> implements Runnable {
private int workerId;
private long startTime;
private boolean fixPoolThread;
private boolean logModelMetric;
private int metricsAggregation;

/**
* Builds a workerThread with this builder.
Expand All @@ -60,6 +65,14 @@ private WorkerThread(Builder<I, O> builder) {
ZooModel<I, O> model = builder.model.getModel(device);

predictor = model.newPredictor();
modelName = model.getName();
logModelMetric = Boolean.parseBoolean(model.getProperty("log_model_metric"));
String value = model.getProperty("metrics_aggregation");
if (value == null || value.isBlank()) {
metricsAggregation = 1000;
} else {
metricsAggregation = Integer.parseInt(value);
}
}

/** {@inheritDoc} */
Expand All @@ -72,6 +85,13 @@ public void run() {
List<I> req = null;
String errorMessage = "Worker shutting down";
try {
if (logModelMetric) {
Metrics metrics = new Metrics();
metrics.setLimit(metricsAggregation);
metrics.setOnLimit(
(m, s) -> MODEL_METRIC.info("{}-{}", modelName, m.percentile(s, 50)));
predictor.setMetrics(metrics);
}
while (isRunning() && !aggregator.isFinished()) {
req = aggregator.getRequest();
if (req != null && !req.isEmpty()) {
Expand Down

0 comments on commit 84c6590

Please sign in to comment.