Skip to content

Commit

Permalink
Creates Auto Increment ID for worker threads and models
Browse files Browse the repository at this point in the history
This creates a new auto-increment ID to differentiate models with the same name
and adds the ID to the toString and logging.
  • Loading branch information
zachgk committed Jan 24, 2024
1 parent 24b8de4 commit 73885ac
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public DescribeWorkflowResponse(ai.djl.serving.workflow.Workflow workflow) {
targetWorker += group.getMinWorkers();

for (WorkerThread<Input, Output> worker : workers) {
int workerId = worker.getWorkerId();
String workerId = worker.getWorkerId();
long startTime = worker.getStartTime();
boolean isRunning = worker.isRunning();
g.addWorker(workerId, startTime, isRunning);
Expand Down Expand Up @@ -369,7 +369,7 @@ public int getMaxWorkers() {
* @param startTime the worker's start time
* @param isRunning {@code true} if worker is running
*/
public void addWorker(int id, long startTime, boolean isRunning) {
public void addWorker(String id, long startTime, boolean isRunning) {
Worker worker = new Worker();
worker.setId(id);
worker.setStartTime(new Date(startTime));
Expand All @@ -390,7 +390,7 @@ public List<Worker> getWorkers() {
/** A class that holds workers information. */
public static final class Worker {

private int id;
private String id;
private Date startTime;
private String status;

Expand All @@ -399,7 +399,7 @@ public static final class Worker {
*
* @return the worker's ID
*/
public int getId() {
public String getId() {
return id;
}

Expand All @@ -408,7 +408,7 @@ public int getId() {
*
     * @param id the worker's ID
*/
public void setId(int id) {
public void setId(String id) {
this.id = id;
}

Expand Down
6 changes: 4 additions & 2 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
Expand Down Expand Up @@ -684,7 +685,7 @@ private void testDescribeModel(Channel channel) throws InterruptedException {
assertTrue(workers.size() > 1);

DescribeWorkflowResponse.Worker worker = workers.get(0);
assertTrue(worker.getId() > 0);
assertNotEquals(worker.getId(), "");
assertNotNull(worker.getStartTime());
assertNotNull(worker.getStatus());

Expand Down Expand Up @@ -1226,7 +1227,8 @@ private void testServiceUnavailable() throws InterruptedException {
if (!System.getProperty("os.name").startsWith("Win")) {
ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
assertEquals(resp.getCode(), HttpResponseStatus.SERVICE_UNAVAILABLE.code());
assertEquals(resp.getMessage(), "All model workers has been shutdown: mlp_2 (READY)");
assertTrue(resp.getMessage().contains("All model workers has been shutdown"));
assertTrue(resp.getMessage().contains("mlp_2"));
}
}

Expand Down
40 changes: 24 additions & 16 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ public final class ModelInfo<I, O> extends WorkerPoolConfig<I, O> {
private transient Engine engine;
private transient boolean initialize;

private ModelInfo() {}

/**
* Constructs a new {@code ModelInfo} instance.
*
Expand Down Expand Up @@ -233,10 +235,10 @@ public void load(Device device) throws ModelException, IOException {
builder.optArgument("batchifier", "stack");
}
}
logger.info("Loading model {} on {}", id, device);
logger.info("Loading model {} {} on {}", id, uid, device);
if ("nc".equals(device.getDeviceType()) && "PyTorch".equals(engineName)) {
// assume neuron only support PyTorch
logger.info("Bypass NC core allocation");
logger.info("{}: Bypass NC core allocation", uid);
} else {
builder.optDevice(device);
}
Expand Down Expand Up @@ -339,7 +341,10 @@ public Status getStatus() {
String def = Utils.getenv("SERVING_RETRY_THRESHOLD", "10");
int threshold = Integer.parseInt(m.getProperty("retry_threshold", def));
if (failures > threshold) {
logger.info("exceed retry threshold: {}, mark model as failed.", threshold);
logger.info(
"{}: exceed retry threshold: {}, mark model as failed.",
uid,
threshold);
return Status.FAILED;
}
}
Expand Down Expand Up @@ -550,7 +555,7 @@ public Adapter getAdapter(String name) {
@Override
public void close() {
if (!getModels().isEmpty() && !Boolean.getBoolean("ai.djl.serving.keep_cache")) {
logger.info("Unloading model: {}{}", id, version == null ? "" : '/' + version);
logger.info("Unloading model: {}", this);
if (downloadDir != null) {
Utils.deleteQuietly(downloadDir);
}
Expand Down Expand Up @@ -653,7 +658,7 @@ private String inferEngine() throws ModelException {
return Engine.getDefaultEngineName();
}
} catch (IOException e) {
logger.warn("Failed search parameter files in folder: " + modelDir, e);
logger.warn(uid + ": Failed search parameter files in folder: {}", modelDir, e);
}
}
throw new ModelNotFoundException("Failed to detect engine of the model: " + modelDir);
Expand Down Expand Up @@ -688,7 +693,7 @@ private void loadServingProperties() throws ModelException {
try (InputStream is = Files.newInputStream(file)) {
prop.load(is);
} catch (IOException e) {
logger.warn("Failed read serving.properties file", e);
logger.warn(uid + ": Failed read serving.properties file", e);
}
}
configPerModelSettings();
Expand Down Expand Up @@ -749,9 +754,10 @@ private void configPerModelSettings() throws ModelException {
}

logger.info(
"Apply per model settings:\n\tjob_queue_size: {}\n\tbatch_size: {}"
"{}: Apply per model settings:\n\tjob_queue_size: {}\n\tbatch_size: {}"
+ "\n\tmax_batch_delay: {}\n\tmax_idle_time: {}\n\tload_on_devices: {}"
+ "\n\tengine: {}\n\tmpi_mode: {}\n\toption.entryPoint: {}{}",
uid,
queueSize,
batchSize,
maxBatchDelayMillis,
Expand Down Expand Up @@ -792,7 +798,7 @@ void checkAvailableMemory(Device device) throws IOException {
// 1. handle LMI use case in future
// 2. if huggingface model_id is specified, the model is downloaded
// in the python process, current file size based estimation doesn't work
logger.warn("No reserved_memory_mb defined, estimating memory usage ...");
logger.warn("{}: No reserved_memory_mb defined, estimating memory usage ...", uid);
try (Stream<Path> walk = Files.walk(modelDir)) {
requiredMemory = walk.mapToLong(ModelInfo::getFileSize).sum();
}
Expand All @@ -807,7 +813,8 @@ void checkAvailableMemory(Device device) throws IOException {
// Assume requires the same amount of CPU memory when load on GPU
long free = getAvailableCpuMemory();
logger.info(
"Available CPU memory: {} MB, required: {} MB, reserved: {} MB",
"{}: Available CPU memory: {} MB, required: {} MB, reserved: {} MB",
uid,
free / 1024 / 1024,
requiredMemory / 1024 / 1024,
reservedMemory / 1024 / 1024);
Expand All @@ -827,7 +834,8 @@ void checkAvailableMemory(Device device) throws IOException {
requiredMemory = gpuMem;
}
logger.info(
"Available GPU memory: {} MB, required: {} MB, reserved: {} MB",
"{}: Available GPU memory: {} MB, required: {} MB, reserved: {} MB",
uid,
free / 1024 / 1024,
requiredMemory / 1024 / 1024,
reservedMemory / 1024 / 1024);
Expand Down Expand Up @@ -940,9 +948,9 @@ private long getAvailableCpuMemory() {
return Long.parseLong(m.group(1)) * 1024;
}
}
logger.warn("Failed to read free memory from /proc/meminfo");
logger.warn("{}: Failed to read free memory from /proc/meminfo", uid);
} catch (IOException e) {
logger.warn("Failed open /proc/meminfo file", e);
logger.warn(uid + ": Failed open /proc/meminfo file", e);
}
}
return Integer.MAX_VALUE * 1024L;
Expand All @@ -962,28 +970,28 @@ void downloadS3() throws ModelException, IOException {
return;
}
if (modelId.startsWith("s3://")) {
logger.info("S3 url found, start downloading from {}", modelId);
logger.info("{}: S3 url found, start downloading from {}", uid, modelId);
// Use fixed download path to avoid repeat download
String hash = Utils.hash(modelId);
String download = Utils.getenv("SERVING_DOWNLOAD_DIR", null);
Path parent = download == null ? Utils.getCacheDir() : Paths.get(download);
parent = parent.resolve("download");
this.downloadDir = parent.resolve(hash);
if (Files.exists(this.downloadDir)) {
logger.info("artifacts has been downloaded already: {}", this.downloadDir);
logger.info("{}: artifacts has been downloaded already: {}", uid, this.downloadDir);
return;
}
Files.createDirectories(parent);
Path tmp = Files.createTempDirectory(parent, "tmp");
try {
downloadS3(modelId, tmp.toAbsolutePath().toString());
Utils.moveQuietly(tmp, this.downloadDir);
logger.info("Download completed! Files saved to {}", this.downloadDir);
logger.info("{}: Download completed! Files saved to {}", uid, this.downloadDir);
} finally {
Utils.deleteQuietly(tmp);
}
} else if (modelId.startsWith("djl://")) {
logger.info("djl model zoo url found: {}", modelId);
logger.info("{}: djl model zoo url found: {}", uid, modelId);
modelUrl = modelId;
// download real model from model zoo
downloadModel();
Expand Down
34 changes: 0 additions & 34 deletions wlm/src/main/java/ai/djl/serving/wlm/WorkerIdGenerator.java

This file was deleted.

29 changes: 23 additions & 6 deletions wlm/src/main/java/ai/djl/serving/wlm/WorkerPoolConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.wlm.util.AutoIncIdGenerator;
import ai.djl.translate.TranslateException;

import java.io.IOException;
Expand All @@ -34,7 +35,10 @@
*/
public abstract class WorkerPoolConfig<I, O> {

private static final AutoIncIdGenerator ID_GEN = new AutoIncIdGenerator("M-");

protected transient String id;
protected transient String uid;
protected String version;
protected String modelUrl;
protected int queueSize;
Expand All @@ -44,6 +48,10 @@ public abstract class WorkerPoolConfig<I, O> {
protected Integer minWorkers; // Integer so it becomes null when parsed from JSON
protected Integer maxWorkers; // Integer so it becomes null when parsed from JSON

protected WorkerPoolConfig() {
uid = ID_GEN.generate();
}

/**
* Loads the worker type to the specified device.
*
Expand Down Expand Up @@ -104,18 +112,27 @@ public Device withDefaultDevice(String deviceName) {
public abstract String[] getLoadOnDevices();

/**
* Sets the worker type ID.
* Sets the worker configs ID.
*
* @param id the worker type ID
* @param id the worker configs ID
*/
public void setId(String id) {
this.id = id;
}

/**
* Returns the worker type ID.
* Returns the worker configs unique ID.
*
* @return the worker configs unique ID
*/
public String getUid() {
return uid;
}

/**
* Returns the worker configs ID.
*
* @return the worker type ID
* @return the worker configs ID
*/
public String getId() {
return id;
Expand Down Expand Up @@ -308,9 +325,9 @@ public int hashCode() {
@Override
public String toString() {
if (version != null) {
return id + ':' + version + " (" + getStatus() + ')';
return id + ':' + version + " (" + uid + ", " + getStatus() + ')';
}
return id + " (" + getStatus() + ')';
return id + " (" + uid + ", " + getStatus() + ')';
}

/** An enum represents state of a worker type. */
Expand Down
Loading

0 comments on commit 73885ac

Please sign in to comment.