Commit da115ff

[lmi] always configure lmi model instead of just when engine is missing
siddvenk committed Apr 9, 2024
1 parent 88f1d71 commit da115ff
Showing 4 changed files with 46 additions and 54 deletions.
22 changes: 6 additions & 16 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiConfigRecommender.java
@@ -65,8 +65,9 @@ static void configure(
         String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
         setDynamicBatch(lmiProperties, modelConfig, modelInfo, features);
         setRollingBatch(lmiProperties, modelConfig, features);
-        setEngine(lmiProperties, modelConfig, features);
+        setMPIMode(lmiProperties, modelConfig, features);
         setTensorParallelDegree(lmiProperties);
+        setRollingBatchSize(lmiProperties);
     }

@@ -92,25 +93,18 @@ private static void setRollingBatch(
         lmiProperties.setProperty("option.rolling_batch", rollingBatch);
     }
 
-    private static void setEngine(
+    private static void setMPIMode(
             Properties lmiProperties,
            LmiUtils.HuggingFaceModelConfig modelConfig,
            String features) {
-        if (lmiProperties.containsKey("engine")) {
-            return;
-        }
-        String engine = "Python";
         String rollingBatch = lmiProperties.getProperty("option.rolling_batch");
         if ("lmi-dist".equals(rollingBatch) || "trtllm".equals(rollingBatch)) {
-            engine = "MPI";
             lmiProperties.setProperty("option.mpi_mode", "true");
         }
         // TODO TrtLLM python backend: Change it once TrtLLM supports T5 with inflight batching.
         if (isT5TrtLLM(modelConfig, features)) {
-            engine = "MPI";
             lmiProperties.setProperty("option.mpi_mode", "true");
         }
-        lmiProperties.setProperty("engine", engine);
     }
 
     private static void setTensorParallelDegree(Properties lmiProperties) {
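
As a self-contained illustration of the behavior above (hypothetical class and method names; the isT5TrtLLM special case is elided), note that the recommender now only sets option.mpi_mode and never writes the engine property:

import java.util.Properties;

public final class MpiModeSketch {

    // Mirrors setMPIMode above: MPI-based rolling-batch backends get option.mpi_mode=true.
    static void setMpiMode(Properties p) {
        String rollingBatch = p.getProperty("option.rolling_batch");
        if ("lmi-dist".equals(rollingBatch) || "trtllm".equals(rollingBatch)) {
            p.setProperty("option.mpi_mode", "true");
        }
    }

    public static void main(String[] args) {
        Properties p = new Properties();
        p.setProperty("option.rolling_batch", "lmi-dist");
        setMpiMode(p);
        System.out.println(p.getProperty("option.mpi_mode")); // true
        System.out.println(p.getProperty("engine")); // null: the engine key is no longer set here
    }
}
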
@@ -144,7 +138,7 @@ private static void setDynamicBatch(
         }
     }
 
-    static void setRollingBatchSize(Properties lmiProperties) {
+    private static void setRollingBatchSize(Properties lmiProperties) {
         if (lmiProperties.containsKey("option.max_rolling_batch_size")) {
             return;
         }
@@ -153,12 +147,8 @@ static void setRollingBatchSize(Properties lmiProperties) {
         if ("vllm".equals(rollingBatch) || "lmi-dist".equals(rollingBatch)) {
             rollingBatchSize = 256;
         }
-        if ("trtllm".equals(rollingBatch)
-                || ("auto".equals(rollingBatch)
-                        && isTrtLLMEnabled(Utils.getEnvOrSystemProperty("SERVING_FEATURES")))) {
-            if (lmiProperties.containsKey("option.max_num_tokens")) {
-                rollingBatchSize = 256;
-            }
+        if ("trtllm".equals(rollingBatch) && lmiProperties.containsKey("option.max_num_tokens")) {
+            rollingBatchSize = 256;
         }
         lmiProperties.setProperty(
                 "option.max_rolling_batch_size", String.valueOf(rollingBatchSize));
57 changes: 30 additions & 27 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
@@ -52,14 +52,9 @@ public final class LmiUtils {
 
     private LmiUtils() {}
 
-    static String inferLmiEngine(ModelInfo<?, ?> modelInfo) throws ModelException {
+    static void configureLMIModel(ModelInfo<?, ?> modelInfo) throws ModelException {
         Properties prop = modelInfo.getProperties();
         HuggingFaceModelConfig modelConfig = getHuggingFaceModelConfig(modelInfo);
-        if (modelConfig == null) {
-            String engineName = isTrtLLMRollingBatch(prop) ? "MPI" : "Python";
-            logger.info("No config.json found, use {} engine.", engineName);
-            return engineName;
-        }
         LmiConfigRecommender.configure(modelInfo, prop, modelConfig);
         logger.info(
                 "Detected engine: {}, rolling_batch: {}, tensor_parallel_degree {}, for modelType:"
@@ -68,7 +63,11 @@ static String inferLmiEngine(ModelInfo<?, ?> modelInfo) throws ModelException {
                 prop.getProperty("option.rolling_batch"),
                 prop.getProperty("option.tensor_parallel_degree"),
                 modelConfig.getModelType());
-        return prop.getProperty("engine");
     }
 
+    static boolean isLMIModel(ModelInfo<?, ?> modelInfo) {
+        String modelId = modelInfo.getProperties().getProperty("option.model_id");
+        return null != generateHuggingFaceConfigUri(modelInfo, modelId);
+    }
+
     static boolean isTrtLLMRollingBatch(Properties properties) {
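
In other words, the new helper defines "LMI model" as any model whose Huggingface config (config.json or model_index.json) can be resolved, whether from the local model directory, the download directory, or the Huggingface hub (see generateHuggingFaceConfigUri below). That resolvability check, rather than a missing engine property, is now what gates LMI auto-configuration.
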
@@ -140,17 +139,11 @@ static void convertTrtLLM(ModelInfo<?, ?> info) throws IOException {
      * @param modelInfo the model object
      * @param modelId the model id
      * @return the Huggingface config.json file URI
-     * @throws ModelException if the model not found
-     * @throws IOException if failed read from huggingface hub
      */
-    public static URI generateHuggingFaceConfigUri(ModelInfo<?, ?> modelInfo, String modelId)
-            throws ModelException, IOException {
+    public static URI generateHuggingFaceConfigUri(ModelInfo<?, ?> modelInfo, String modelId) {
         URI configUri = null;
         Path modelDir = modelInfo.modelDir;
         if (modelId != null && modelId.startsWith("s3://")) {
-            // This is definitely suboptimal, but for the majority of cases we need to download this
-            // s3 model eventually, so it is not the worst thing to download it now.
-            modelInfo.downloadS3();
             Path downloadDir = modelInfo.downloadDir;
             if (Files.isRegularFile(downloadDir.resolve("config.json"))) {
                 configUri = downloadDir.resolve("config.json").toUri();
}
return null;
}

String hfToken = Utils.getenv("HF_TOKEN");
configUri = URI.create("https://huggingface.co/" + modelId + "/raw/main/config.json");
HttpURLConnection configUrl = (HttpURLConnection) configUri.toURL().openConnection();
if (hfToken != null) {
configUrl.setRequestProperty("Authorization", "Bearer " + hfToken);
}
// stable diffusion models have a different file name with the config... sometimes
if (HttpURLConnection.HTTP_OK != configUrl.getResponseCode()) {
configUri =
URI.create(
"https://huggingface.co/" + modelId + "/raw/main/model_index.json");
}
configUri = getHuggingFaceHubConfigUri(modelId);
} else if (Files.isRegularFile(modelDir.resolve("config.json"))) {
configUri = modelDir.resolve("config.json").toUri();
} else if (Files.isRegularFile(modelDir.resolve("model_index.json"))) {
@@ -193,6 +174,28 @@ public static URI generateHuggingFaceConfigUri(ModelInfo<?, ?> modelInfo, String
         }
         return configUri;
     }
 
+    private static URI getHuggingFaceHubConfigUri(String modelId) {
+        String[] possibleConfigFiles = {"config.json", "model_index.json"};
+        String hubToken = Utils.getEnvOrSystemProperty("HF_TOKEN");
+        for (String configFile : possibleConfigFiles) {
+            try {
+                URI configUri =
+                        URI.create("https://huggingface.co/" + modelId + "/raw/main/" + configFile);
+                HttpURLConnection configUrl =
+                        (HttpURLConnection) configUri.toURL().openConnection();
+                if (hubToken != null) {
+                    configUrl.setRequestProperty("Authorization", "Bearer " + hubToken);
+                }
+                if (HttpURLConnection.HTTP_OK == configUrl.getResponseCode()) {
+                    return configUri;
+                }
+            } catch (IOException e) {
+                logger.warn("Hub config file {} does not exist for model {}.", configFile, modelId);
+            }
+        }
+        return null;
+    }
+
     private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo<?, ?> modelInfo)
             throws ModelException {
         String modelId = modelInfo.prop.getProperty("option.model_id");
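
A minimal standalone sketch of the same probe logic (hypothetical class name; the HF_TOKEN authorization header is elided, and gpt2 is just an arbitrary public model id):

import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URI;

public final class HubConfigProbe {

    // Try config.json first, then model_index.json (the stable-diffusion layout).
    static URI probe(String modelId) {
        for (String file : new String[] {"config.json", "model_index.json"}) {
            try {
                URI uri = URI.create("https://huggingface.co/" + modelId + "/raw/main/" + file);
                HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection();
                if (conn.getResponseCode() == HttpURLConnection.HTTP_OK) {
                    return uri;
                }
            } catch (IOException e) {
                // A connection failure just means "try the next candidate file".
            }
        }
        return null;
    }

    public static void main(String[] args) {
        System.out.println(probe("gpt2")); // config.json URI for a public model
        System.out.println(probe("no-such-org/no-such-model")); // null
    }
}

Note that a non-200 status is not an error here; the helper simply falls through to the next candidate file, and an IOException is likewise treated as "not found".
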
17 changes: 7 additions & 10 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
@@ -510,6 +510,7 @@ public void initialize() throws IOException, ModelException {
         loadServingProperties();
         downloadS3();
         eventManager.onModelDownloaded(this, downloadDir);
+        configPerModelSettings();
         downloadDraftModel();
 
         long duration = (System.nanoTime() - begin) / 1000;
@@ -678,10 +679,8 @@ private String inferEngine() throws ModelException {
             }
         }
 
-        if (isTorchServeModel()) {
+        if (isPythonModel(prefix)) {
             return "Python";
-        } else if (isPythonModel(prefix)) {
-            return LmiUtils.inferLmiEngine(this);
         } else if (Files.isRegularFile(modelDir.resolve(prefix + ".pt"))
                 || Files.isRegularFile(modelDir.resolve("model.pt"))) {
             return "PyTorch";
@@ -734,7 +733,8 @@ private boolean isPythonModel(String prefix) {
         return Files.isRegularFile(modelDir.resolve("model.py"))
                 || Files.isRegularFile(modelDir.resolve(prefix + ".py"))
                 || prop.getProperty("option.model_id") != null
-                || Files.isRegularFile(modelDir.resolve("config.json"));
+                || Files.isRegularFile(modelDir.resolve("config.json"))
+                || isTorchServeModel();
     }
 
     private void downloadModel() throws ModelNotFoundException, IOException {
@@ -776,8 +776,6 @@ private void loadServingProperties() throws ModelException {
                 prop.putIfAbsent("option." + key, value);
             }
         }
-        configPerModelSettings();
-        eventManager.onModelConfigured(this);
     }
 }

@@ -806,10 +804,8 @@ private void configPerModelSettings() throws ModelException {
         if (engineName == null) {
             engineName = inferEngine();
         }
-        // TODO: capture this in the LmiConfigRecommender.configure method once we refactor that to
-        // run always, not just when engine is missing
-        if (LmiUtils.isRollingBatchEnabled(this.getProperties())) {
-            LmiConfigRecommender.setRollingBatchSize(this.getProperties());
+        if (LmiUtils.isLMIModel(this)) {
+            LmiUtils.configureLMIModel(this);
         }
 
         StringBuilder sb = new StringBuilder();
@@ -845,6 +841,7 @@ private void configPerModelSettings() throws ModelException {
                 prop.get("option.mpi_mode"),
                 prop.get("option.entryPoint"),
                 sb);
+        eventManager.onModelConfigured(this);
     }
 
     void checkAvailableMemory(Device device) throws IOException {
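
Net effect on ordering: initialize() now runs loadServingProperties(), downloadS3(), and then configPerModelSettings() before downloadDraftModel(). Because S3 artifacts are already on disk by the time configuration runs, generateHuggingFaceConfigUri no longer needs to trigger the download itself (hence its dropped throws clauses in LmiUtils.java above), and the onModelConfigured event fires at the end of configPerModelSettings() instead of from loadServingProperties().
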
4 changes: 3 additions & 1 deletion wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
@@ -296,7 +296,9 @@ public void testInferLMIEngine() throws IOException, ModelException {
             writer.write("option.model_id=invalid-model-id");
         }
         model = new ModelInfo<>("build/models/lmi_test_model");
-        Assert.assertThrows(model::initialize);
+        model.initialize();
+        assertEquals(model.getEngineName(), "Python");
+        assertEquals(model.getProperties().getProperty("option.rolling_batch"), null);
 
         // TODO: no good way to test trtllm now since it requires converting the model
     }
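
An invalid model id no longer makes initialization throw: the hub probe logs a warning and returns null, so the model is simply not treated as an LMI model. Engine inference then falls back to Python (option.model_id is set, so isPythonModel() is true), and no rolling-batch recommendation is written, which is what the updated assertions verify.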
