
Add trt-llm engine build step during model initialization #1235

Merged: 9 commits merged into deepjavalibrary:master on Nov 1, 2023

Conversation

@rohithkrn (Contributor) commented Oct 30, 2023

Description

Add TRT-LLM engine build step during model initialization

TODO: Tests

@rohithkrn rohithkrn requested review from zachgk, frankfliu and a team as code owners October 30, 2023 00:21
@lanking520 (Contributor) left a comment:

Your logic doesn't cover the case where the model is a SageMaker uncompressed model: the model is saved to /opt/ml/models (read-only). The model can be a Triton repo or a standard HF model. In this case you should scan the model_dir and find out whether there are model files there.
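For illustration, that scan could look roughly as follows. This is a minimal sketch, not the actual ModelInfo code; the marker files used to recognize a Triton repo or an HF model (config.pbtxt, config.json, weight files) are assumptions.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.Stream;

// Hypothetical helper: detect whether model_dir already contains model files.
final class ModelDirScanner {

    static boolean containsModelFiles(Path modelDir) throws IOException {
        try (Stream<Path> files = Files.walk(modelDir, 2)) {
            return files.anyMatch(p -> {
                String name = p.getFileName().toString();
                return "config.pbtxt".equals(name)       // Triton model repo marker
                        || "config.json".equals(name)    // HF model config
                        || name.endsWith(".safetensors") // HF weights
                        || name.endsWith(".bin");        // PyTorch weights
            });
        }
    }
}
```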

@@ -165,6 +165,8 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
} else if ("nc".equals(manager.getDevice().getDeviceType())
&& pyEnv.getTensorParallelDegree() > 0) {
entryPoint = "djl_python.transformers_neuronx";
} else if ("TRT-LLM".equals(Utils.getenv("LMI_BACKEND"))) {
Contributor: It's better to check option.rolling_batch=trtllm.

Contributor Author (@rohithkrn): That doesn't cover the case where customers do not want to use rolling batch. Also, I think there's an effort to make rolling_batch a boolean property.
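For clarity, the two detection strategies under discussion amount to roughly this. The helper class and the props map are illustrative, not code from this PR.

```java
import java.util.Map;

// Sketch of the two TRT-LLM detection approaches discussed in this thread.
final class TrtLlmDetection {

    // Option A (what this PR does): key off the container-level env var.
    static boolean viaEnvVar() {
        return "TRT-LLM".equals(System.getenv("LMI_BACKEND"));
    }

    // Option B (the suggestion above): key off option.rolling_batch.
    // As noted, this misses deployments that disable rolling batch.
    static boolean viaRollingBatch(Map<String, String> props) {
        return "trtllm".equals(props.get("option.rolling_batch"));
    }
}
```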

serving/docker/partition/trt_llm_partition.py (review thread resolved)
wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java (review thread resolved, outdated)
@rohithkrn (Contributor Author) commented:

> Your logic doesn't cover the case where the model is a SageMaker uncompressed model: the model is saved to /opt/ml/models (read-only). The model can be a Triton repo or a standard HF model. In this case you should scan the model_dir and find out whether there are model files there.

This change relies on model_id set in serving.properties or an env var. For an uncompressed model, what would model_id look like?
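For context, a typical compressed-model deployment sets model_id in serving.properties along these lines (illustrative values; the S3 path is made up):

```
engine=Python
option.model_id=s3://my-bucket/my-model/
option.tensor_parallel_degree=4
```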

@@ -463,6 +471,10 @@ public void initialize() throws IOException, ModelException {
downloadModel();
loadServingProperties();
downloadS3();
isTrtLlmBackend = "TRT-LLM".equals(Utils.getenv("LMI_BACKEND"));
if (isTrtLlmBackend) {
initTrtLlmModel();
Contributor: Please move these changes to the Python engine, probably the PyModel.load() function. ModelInfo is a higher abstraction that shouldn't be modified by these changes. It would also be easier to test, because you could test it by loading and predicting with the Python engine through the standard DJL predictor API.
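A rough sketch of such a test using the standard DJL Criteria/Predictor API. The model path and request payload are illustrative, not from this PR.

```java
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;

public class TrtLlmLoadTest {
    public static void main(String[] args) throws Exception {
        Criteria<Input, Output> criteria =
                Criteria.builder()
                        .setTypes(Input.class, Output.class)
                        .optModelPath(Paths.get("/opt/ml/model")) // illustrative path
                        .optEngine("Python")                      // exercises PyModel.load()
                        .build();
        try (ZooModel<Input, Output> model = criteria.loadModel();
                Predictor<Input, Output> predictor = model.newPredictor()) {
            Input input = new Input();
            input.add("{\"inputs\": \"Hello\"}"); // illustrative payload
            Output output = predictor.predict(input);
            System.out.println(output.getData().getAsString());
        }
    }
}
```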

Contributor Author (@rohithkrn): Agreed. I will refactor this in a separate PR, as I don't have the bandwidth to refactor and test it currently.

@rohithkrn (Contributor Author) commented Oct 31, 2023

CI seems to be flaky: tests pass locally on Mac with the PR branch, and the master branch also fails locally on Ubuntu.

@lanking520 (Contributor) commented Nov 1, 2023

I fixed the PMD rules; please rebase.

@@ -150,6 +151,21 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
pyEnv.setFailOnInitialize(false);
}

// Handle TRT-LLM
if ("TRT-LLM".equals(Utils.getenv("LMI_BACKEND")) || Boolean.parseBoolean(getProperty("option.trt_llm"))) {
Contributor Author (@rohithkrn): The env var is always set in the container, and we don't want to introduce another option. I will update this.

String modelId = trtLlmRepoDir.toAbsolutePath().toString();
setProperty("model_id", modelId);
pyEnv.addParameter("model_id", modelId);
entryPoint = "djl_python.tensorrt_llm";
Contributor Author (@rohithkrn) commented Nov 1, 2023: This overrides a user-set entryPoint.

Contributor: Yes. If that isn't desired, you can move it back, or set it only if entryPoint is null.
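That null guard would amount to something like this (a sketch; the helper name is made up):

```java
// Sketch: apply the TRT-LLM entry point only when the user hasn't set one.
final class EntryPointResolver {
    static String resolve(String userEntryPoint) {
        return userEntryPoint != null ? userEntryPoint : "djl_python.tensorrt_llm";
    }
}
```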

@rohithkrn merged commit a58a735 into deepjavalibrary:master on Nov 1, 2023. 5 of 8 checks passed.