
Infer lmi engine #623

Merged
merged 12 commits on Apr 14, 2023
also support just tar.gz model format without explicit hf hub id
siddvenk committed Apr 12, 2023
commit accc615c6ada62c6ab9661e5fe9c903b1515ccc8
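In short, LMI engine inference now also works when the model ships as an extracted tar.gz containing its own config.json, not only via an HF_MODEL_ID pointing at the Hugging Face Hub. A minimal sketch of the resolution order (names mirror the ModelInfo.java diff below; the /opt/ml/model path and the use of System.getenv instead of Utils.getEnvOrSystemProperty are assumptions for self-containment):

import java.nio.file.Path;
import java.nio.file.Paths;

public class ConfigResolutionSketch {
    public static void main(String[] args) {
        Path modelDir = Paths.get("/opt/ml/model"); // assumed extraction dir for the tar.gz case
        String hfModelId = System.getenv("HF_MODEL_ID");
        String configUri =
                hfModelId == null
                        // tar.gz case: the extracted archive ships its own config.json
                        ? modelDir.resolve("config.json").toString()
                        // Hub case, e.g. https://huggingface.co/gpt2/raw/main/config.json
                        : "https://huggingface.co/" + hfModelId + "/raw/main/config.json";
        System.out.println(configUri);
    }
}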
61 changes: 37 additions & 24 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
@@ -41,6 +41,7 @@
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -594,6 +595,9 @@ private String inferEngine() throws ModelNotFoundException {
return "PaddlePaddle";
} else if (Files.isRegularFile(modelDir.resolve(prefix + ".json"))) {
return "XGBoost";
} else if (Utils.getEnvOrSystemProperty("HF_MODEL_ID") != null
|| Files.isRegularFile(modelDir.resolve("config.json"))) {
return inferLMIEngine();
} else {
try {
if (Utils.getCurrentEpoch(modelDir, prefix) >= 0) {
@@ -603,54 +607,64 @@
} catch (IOException e) {
logger.warn("Failed search parameter files in folder: " + modelDir, e);
}
String huggingFaceHubModelId = Utils.getEnvOrSystemProperty("HF_MODEL_ID");
if (huggingFaceHubModelId == null) {
huggingFaceHubModelId = prop.get("option.modelId").toString();
}
if (huggingFaceHubModelId != null && !huggingFaceHubModelId.startsWith("s3://")) {
return inferLMIEngine();
}
}
throw new ModelNotFoundException("Failed to detect engine of the model: " + modelDir);
}

private String inferLMIEngine(){
private String inferLMIEngine() {
String huggingFaceHubModelId = Utils.getEnvOrSystemProperty("HF_MODEL_ID");
String huggingFaceTask = Utils.getEnvOrSystemProperty("HF_TASK");
String modelConfigUrl =
"https://huggingface.co/" + huggingFaceHubModelId + "/raw/main/config.json";

String modelConfigUri;
if (huggingFaceHubModelId == null) {
modelConfigUri = modelDir.resolve("config.json").toString();
} else {
modelConfigUri =
"https://huggingface.co/"
+ huggingFaceHubModelId
+ "/raw/main/config.json";
}
JsonObject modelConfig;
try (InputStream is = new URL(modelConfigUrl).openStream();
try (InputStream is = new URL(modelConfigUri).openStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(is))) {
modelConfig = JsonUtils.GSON.fromJson(reader, JsonElement.class).getAsJsonObject();
} catch (IOException e) {
logger.error("Could not read model config for {} from huggingface hub", huggingFaceHubModelId, e);
logger.error(
"Could not read model config for {} from huggingface hub",
huggingFaceHubModelId,
e);
return "Python";
}

String modelType = modelConfig.get("model_type").getAsString();
long numAttentionHeads;
long numAttentionHeads = Long.MAX_VALUE;
// Any of these keys may hold the number of attention heads in config.json
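// e.g. BERT/OPT-style configs use "num_attention_heads", T5 uses "num_heads",
// and GPT-2-family configs use "n_head" (assumed examples of common HF configs)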
if (modelConfig.has("num_attention_heads")) {
numAttentionHeads = modelConfig.get("num_attention_heads").getAsLong();
} else if (modelConfig.has("num_heads")) {
numAttentionHeads = modelConfig.get("num_heads").getAsLong();
} else if (modelConfig.has("n_head")) {
numAttentionHeads = modelConfig.get("n_head").getAsLong();
} else {
numAttentionHeads = 0;
}
int tensorParallelDegree;

int tensorParallelDegree = CudaUtils.getGpuCount();
if (Utils.getEnvOrSystemProperty("TENSOR_PARALLEL_DEGREE") != null) {
tensorParallelDegree = Integer.parseInt(Utils.getEnvOrSystemProperty("TENSOR_PARALLEL_DEGREE"));
tensorParallelDegree =
Integer.parseInt(Utils.getEnvOrSystemProperty("TENSOR_PARALLEL_DEGREE"));
} else if (prop.get("option.tensor_parallel_degree") != null) {
tensorParallelDegree =
Integer.parseInt(prop.get("option.tensor_parallel_degree").toString());
} else {
tensorParallelDegree = CudaUtils.getGpuCount();
}
logger.info("moddel config: {}", modelConfig);
logger.info("tensor parallel degree: {}", tensorParallelDegree);
logger.info("num attention heads: {}", numAttentionHeads);

prop.put("option.tensor_parallel_degree", tensorParallelDegree);
prop.put("option.task", huggingFaceTask);
prop.put("option.model_id", huggingFaceHubModelId);
if (tensorParallelDegree > 0) {
prop.put("option.tensor_parallel_degree", tensorParallelDegree);
}
String huggingFaceTask = Utils.getEnvOrSystemProperty("HF_TASK");
if (huggingFaceTask != null) {
prop.put("option.task", huggingFaceTask);
}

if (!isTensorParallelSupported(numAttentionHeads, tensorParallelDegree)) {
Contributor:

Maybe DS or FT would have some mechanism on their end to decide how to do model sharding. I would suggest not checking this.

Contributor (Author):

At least for DS, an exception is thrown if this check fails. But really, the only practical example of this we have seen is gpt2-xl.

In the future it's possible that DS and FT change that behavior and can actually accommodate such a model. At that point this method would become incorrect.

I can remove this, since it's going to be validated by the engine anyway. But the benefit of doing it this way is that we don't recommend, say, gpt2-xl to run with DeepSpeed with TP when we know it won't work.
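For concreteness, a self-contained sketch of the check under discussion (head counts assumed from the models' config.json files: gpt2-xl declares "n_head": 25, gpt2 declares "n_head": 12):

public class TensorParallelCheckExample {
    // Same logic as isTensorParallelSupported below: the head count must
    // divide evenly across the shards for the degree to be usable.
    static boolean isTensorParallelSupported(long numAttentionHeads, int tensorParallelDegree) {
        return tensorParallelDegree > 0 && numAttentionHeads % tensorParallelDegree == 0;
    }

    public static void main(String[] args) {
        // gpt2-xl: 25 heads, TP degree 4 fails (25 % 4 == 1) -> fall back to "Python"
        System.out.println(isTensorParallelSupported(25, 4)); // false
        // gpt2: 12 heads, TP degree 4 divides evenly -> DeepSpeed stays viable
        System.out.println(isTensorParallelSupported(12, 4)); // true
    }
}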

return "Python";
@@ -672,7 +686,6 @@ private boolean isDeepSpeedRecommended(String modelType) {
return DEEPSPEED_MODELS.contains(modelType);
}


private boolean isTensorParallelSupported(long numAttentionHeads, int tensorParallelDegree) {
return tensorParallelDegree > 0 && numAttentionHeads % tensorParallelDegree == 0;
}
12 changes: 6 additions & 6 deletions wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
@@ -252,33 +252,33 @@ public void testInferLMIEngine() throws IOException, ModelException {
Files.createDirectories(modelDir);

System.setProperty("HF_MODEL_ID", "gpt2");
System.setProperty("HF_TASK", "text-generation");
System.setProperty("TENSOR_PARALLEL_DEGREE", "4");
ModelInfo<Input, Output> model = new ModelInfo<>("build/models/test_model");
model.initialize();
assertEquals(model.getEngineName(), "DeepSpeed");

System.setProperty("HF_MODEL_ID", "google/flan-t5-xl");
System.setProperty("HF_TASK", "text2text-generation");
model = new ModelInfo<>("build/models/test_model");
model.initialize();
assertEquals(model.getEngineName(), "FasterTransformer");

System.setProperty("HF_MODEL_ID", "gpt2-xl");
System.setProperty("HF_TASK", "text-generation");
model = new ModelInfo<>("build/models/test_model");
model.initialize();
assertEquals(model.getEngineName(), "Python");

System.setProperty("HF_MODEL_ID", "Salesforce/codegen-6B-mono");
System.setProperty("HF_TASK", "text-generation");
model = new ModelInfo<>("build/models/test_model");
model.initialize();
assertEquals(model.getEngineName(), "Python");

System.setProperty("HF_MODEL_ID", "invalid-model-id");
model = new ModelInfo<>("build/models/test_model");
model.initialize();
assertEquals(model.getEngineName(), "Python");

System.clearProperty("HF_MODEL_ID");
System.clearProperty("HF_TASK");
System.clearProperty("TENSOR_PARALLEL_DEGREE");

}
}