Skip to content

Commit

Permalink
Minor refactor of the code
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 15, 2024
1 parent 85a6ae8 commit e349a05
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 40 deletions.
21 changes: 10 additions & 11 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ static void configure(
String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
setDynamicBatch(lmiProperties, modelConfig, modelInfo, features);
setRollingBatch(lmiProperties, modelConfig, features);
setMPIMode(lmiProperties, modelConfig, features);
setMpiMode(lmiProperties, modelConfig, features);
setTensorParallelDegree(lmiProperties);
setRollingBatchSize(lmiProperties);
}
Expand All @@ -85,15 +85,15 @@ private static void setRollingBatch(
} else if (!isTextGenerationModel(modelConfig)) {
// Non text-generation use-cases are not compatible with rolling batch
rollingBatch = "disable";
} else if (isVLLMEnabled(features) && isLmiDistEnabled(features)) {
} else if (isVllmEnabled(features) && isLmiDistEnabled(features)) {
rollingBatch = MODEL_TO_ROLLING_BATCH.getOrDefault(modelConfig.getModelType(), "auto");
} else if (LmiUtils.isTrtLLMRollingBatch(lmiProperties)) {
} else if (LmiUtils.isTrtLlmRollingBatch(lmiProperties)) {
rollingBatch = "trtllm";
}
lmiProperties.setProperty("option.rolling_batch", rollingBatch);
}

private static void setMPIMode(
private static void setMpiMode(
Properties lmiProperties,
LmiUtils.HuggingFaceModelConfig modelConfig,
String features) {
Expand All @@ -102,7 +102,7 @@ private static void setMPIMode(
lmiProperties.setProperty("option.mpi_mode", "true");
}
// TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
if (isT5TrtLLM(modelConfig, features)) {
if (isT5TrtLlm(modelConfig, features)) {
lmiProperties.setProperty("option.mpi_mode", "true");
}
}
Expand All @@ -124,8 +124,7 @@ private static void setDynamicBatch(
ModelInfo<?, ?> modelInfo,
String features) {
// TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
if (isT5TrtLLM(modelConfig, features)) {

if (isT5TrtLlm(modelConfig, features)) {
// To do runtime compilation for TensorRT-LLM T5 model.
lmiProperties.setProperty("trtllm_python_backend", String.valueOf(true));
lmiProperties.setProperty("option.rolling_batch", "disable");
Expand Down Expand Up @@ -154,21 +153,21 @@ private static void setRollingBatchSize(Properties lmiProperties) {
"option.max_rolling_batch_size", String.valueOf(rollingBatchSize));
}

private static boolean isVLLMEnabled(String features) {
private static boolean isVllmEnabled(String features) {
return features != null && features.contains("vllm");
}

private static boolean isLmiDistEnabled(String features) {
return features != null && features.contains("lmi-dist");
}

private static boolean isTrtLLMEnabled(String features) {
private static boolean isTrtLlmEnabled(String features) {
return features != null && features.contains("trtllm");
}

private static boolean isT5TrtLLM(
private static boolean isT5TrtLlm(
LmiUtils.HuggingFaceModelConfig modelConfig, String features) {
return isTrtLLMEnabled(features) && "t5".equals(modelConfig.getModelType());
return isTrtLlmEnabled(features) && "t5".equals(modelConfig.getModelType());
}

private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig modelConfig) {
Expand Down
44 changes: 19 additions & 25 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,14 @@ public final class LmiUtils {

private LmiUtils() {}

static void configureLMIModel(ModelInfo<?, ?> modelInfo) throws ModelException {
Properties prop = modelInfo.getProperties();
static void configureLmiModel(ModelInfo<?, ?> modelInfo) throws ModelException {
HuggingFaceModelConfig modelConfig = getHuggingFaceModelConfig(modelInfo);
if (modelConfig == null) {
// Not a LMI model
return;
}

Properties prop = modelInfo.getProperties();
LmiConfigRecommender.configure(modelInfo, prop, modelConfig);
logger.info(
"Detected engine: {}, rolling_batch: {}, tensor_parallel_degree {}, for modelType:"
Expand All @@ -65,12 +70,7 @@ static void configureLMIModel(ModelInfo<?, ?> modelInfo) throws ModelException {
modelConfig.getModelType());
}

static boolean isLMIModel(ModelInfo<?, ?> modelInfo) {
String modelId = modelInfo.getProperties().getProperty("option.model_id");
return null != generateHuggingFaceConfigUri(modelInfo, modelId);
}

static boolean isTrtLLMRollingBatch(Properties properties) {
static boolean isTrtLlmRollingBatch(Properties properties) {
String rollingBatch = properties.getProperty("option.rolling_batch");
if ("trtllm".equals(rollingBatch)) {
return true;
Expand All @@ -83,14 +83,9 @@ static boolean isTrtLLMRollingBatch(Properties properties) {
return false;
}

static boolean isRollingBatchEnabled(Properties properties) {
String rollingBatch = properties.getProperty("option.rolling_batch");
return null != rollingBatch && !"disable".equals(rollingBatch);
}

static boolean needConvert(ModelInfo<?, ?> info) {
Properties properties = info.getProperties();
return isTrtLLMRollingBatch(info.getProperties())
return isTrtLlmRollingBatch(info.getProperties())
|| properties.containsKey("trtllm_python_backend");
}

Expand Down Expand Up @@ -141,14 +136,17 @@ static void convertTrtLLM(ModelInfo<?, ?> info) throws IOException {
* @return the Huggingface config.json file URI
*/
public static URI generateHuggingFaceConfigUri(ModelInfo<?, ?> modelInfo, String modelId) {
URI configUri = null;
Path modelDir = modelInfo.modelDir;
if (modelId != null && modelId.startsWith("s3://")) {
if (Files.isRegularFile(modelDir.resolve("config.json"))) {
return modelDir.resolve("config.json").toUri();
} else if (Files.isRegularFile(modelDir.resolve("model_index.json"))) {
return modelDir.resolve("model_index.json").toUri();
} else if (modelId != null && modelId.startsWith("s3://")) {
Path downloadDir = modelInfo.downloadDir;
if (Files.isRegularFile(downloadDir.resolve("config.json"))) {
configUri = downloadDir.resolve("config.json").toUri();
return downloadDir.resolve("config.json").toUri();
} else if (Files.isRegularFile(downloadDir.resolve("model_index.json"))) {
configUri = downloadDir.resolve("model_index.json").toUri();
return downloadDir.resolve("model_index.json").toUri();
}
} else if (modelId != null) {
modelInfo.prop.setProperty("option.model_id", modelId);
Expand All @@ -165,13 +163,9 @@ public static URI generateHuggingFaceConfigUri(ModelInfo<?, ?> modelInfo, String
}
return null;
}
configUri = getHuggingFaceHubConfigUri(modelId);
} else if (Files.isRegularFile(modelDir.resolve("config.json"))) {
configUri = modelDir.resolve("config.json").toUri();
} else if (Files.isRegularFile(modelDir.resolve("model_index.json"))) {
configUri = modelDir.resolve("model_index.json").toUri();
return getHuggingFaceHubConfigUri(modelId);
}
return configUri;
return null;
}

private static URI getHuggingFaceHubConfigUri(String modelId) {
Expand Down Expand Up @@ -291,7 +285,7 @@ static String getAWSGpuMachineType() {
} else if ("9.0".equals(computeCapability)) {
return "p5";
} else {
logger.warn("Could not identify GPU arch " + computeCapability);
logger.warn("Could not identify GPU arch {}", computeCapability);
return null;
}
}
Expand Down
6 changes: 2 additions & 4 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ private void downloadModel() throws ModelNotFoundException, IOException {
artifactName = artifact.getName();
}

private void loadServingProperties() throws ModelException {
private void loadServingProperties() {
if (prop == null) {
Path file = modelDir.resolve("serving.properties");
prop = new Properties();
Expand Down Expand Up @@ -804,9 +804,7 @@ private void configPerModelSettings() throws ModelException {
if (engineName == null) {
engineName = inferEngine();
}
if (LmiUtils.isLMIModel(this)) {
LmiUtils.configureLMIModel(this);
}
LmiUtils.configureLmiModel(this);

StringBuilder sb = new StringBuilder();
for (Map.Entry<Object, Object> entry : prop.entrySet()) {
Expand Down

0 comments on commit e349a05

Please sign in to comment.