
Infer lmi engine #623

Merged: 12 commits, Apr 14, 2023

Conversation

siddvenk (Contributor)

Description

This PR adds support for inferring the specific Python engine to use. This logic should only be called when users don't specify an engine in serving.properties and we can reasonably assume the model is intended to be run with a Python backend.

Use Cases covered:

  • User only specifies HF_MODEL_ID environment variable (no serving properties, model artifacts, or user code)
  • User provides serving properties with a model id (either an HF Hub id or an S3 url), but does not specify an engine
  • model_id is not provided in any form, but model artifacts are present in model_dir

This does not support the use case where users provide their own code that is expected to be invoked via the Hugging Face inference toolkit. There's no special error handling for that now; it will just fail when the PyProcess tries to load the handler and invoke it.

The logic for inferring the backend is largely copied over from the logic in the PySDK.
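As a rough sketch of the decision flow described above (the method, property, and return names here are illustrative assumptions, not the PR's actual code):

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Properties;

// Illustrative sketch of the engine-inference flow described above;
// names and structure are assumptions, not the PR's actual code.
final class EngineInferenceSketch {

    /** Returns an inferred engine name, or null if nothing should be inferred. */
    static String inferEngine(Properties prop, Path modelDir, String hfModelIdEnv) {
        if (prop.getProperty("engine") != null) {
            return null; // the user chose an engine explicitly; nothing to infer
        }
        String modelId = prop.getProperty("option.model_id");
        if (modelId == null) {
            modelId = hfModelIdEnv; // covers the HF_MODEL_ID-only use case
        }
        // A model id (hub id or s3 url) or local model artifacts imply a
        // Python backend; the real logic then picks a specific engine.
        if (modelId != null || Files.isRegularFile(modelDir.resolve("config.json"))) {
            return "Python";
        }
        return null;
    }
}
```

The real implementation inspects the model config to choose between engines; this sketch only shows the gating described in the description.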

@siddvenk siddvenk requested review from zachgk, frankfliu and a team as code owners April 13, 2023 00:54
return "DeepSpeed";
}

if (!isTensorParallelSupported(numAttentionHeads, tensorParallelDegree)) {
Contributor:
Maybe DS or FT have some mechanism on their end to decide how to do model sharding. I would suggest not checking this.

Contributor Author (siddvenk):
At least for DS, they will throw an exception if this check fails. Really, the only practical example of this we have seen is gpt2-xl.

In the future it's possible that DS and FT change that behavior and can actually accommodate such a model. At that point this method would become incorrect.

I can remove this, since it's going to be validated by the engine anyway. But the benefit of doing it this way is that we don't recommend, say, gpt2-xl to run with DeepSpeed with TP when we know it won't work.
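For context, the check being debated is presumably a divisibility test along these lines (a sketch under that assumption; only the method name comes from the diff):

```java
// Sketch of the tensor-parallel compatibility check discussed above: the
// attention heads must shard evenly across the tensor-parallel degree.
// The method name matches the diff; the body is an assumption.
final class TensorParallelCheck {
    static boolean isTensorParallelSupported(int numAttentionHeads, int tensorParallelDegree) {
        return tensorParallelDegree > 0 && numAttentionHeads % tensorParallelDegree == 0;
    }
}
```

gpt2-xl has 25 attention heads, so any even tensor-parallel degree fails such a check, which matches the example cited in the thread.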

|| Files.isRegularFile(modelDir.resolve(prefix + ".py"))
|| Utils.getEnvOrSystemProperty("HF_MODEL_ID") != null
|| Files.isRegularFile(modelDir.resolve("config.json"))
|| prop.containsKey("option.s3url")
Contributor:
Other engines can support model_id and s3url as well.
If a user defines option.model_id, we can assume they can add engine as well.

Contributor Author (siddvenk):
Got it; I removed those checks here.

BufferedReader reader =
new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
return JsonUtils.GSON.fromJson(reader, JsonElement.class).getAsJsonObject();
} catch (IOException e) {
Contributor:
We should catch JsonSyntaxException as well.

Contributor Author (siddvenk):
Added.
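Gson's JsonSyntaxException is unchecked, so it escapes a catch that names only IOException. The multi-catch shape of the fix might look like this (a sketch using a plain Gson instance in place of DJL's JsonUtils, with a null fallback standing in for the real error handling):

```java
import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

final class ConfigReaderSketch {
    private static final Gson GSON = new Gson();

    /** Parses a config.json stream, returning null on I/O or parse failure. */
    static JsonObject readConfig(InputStream is) {
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            return GSON.fromJson(reader, JsonElement.class).getAsJsonObject();
        } catch (IOException | JsonSyntaxException e) {
            return null; // sketch: the real code would log and fall back
        }
    }
}
```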

try (InputStream is = modelConfigUri.toURL().openStream();
BufferedReader reader =
new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
return JsonUtils.GSON.fromJson(reader, JsonElement.class).getAsJsonObject();
Contributor:
Better to create a class to hold the config (we only need to define the fields we care about).

Contributor Author (siddvenk):
Done.

Comment on lines 256 to 267
Path deepspeedLocation = Paths.get("/usr/local/bin/deepspeed");
boolean deepspeedExisted = true;
if (!Files.exists(deepspeedLocation)) {
Files.createDirectories(deepspeedLocation);
deepspeedExisted = false;
}
Path fastertransformerLocation = Paths.get("/usr/local/backends/fastertransformer");
boolean fastertransformerExisted = true;
if (!Files.exists(fastertransformerLocation)) {
Files.createDirectories(fastertransformerLocation);
fastertransformerExisted = false;
}
Contributor Author (siddvenk):
I added this for the unit tests, but I'm not a big fan of it. Could we test this with some integration tests instead?
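One alternative to creating fixed `/usr/local` paths on the host during unit tests is to make the backend locations injectable and point them at a temp directory. A sketch of that idea (class and method names are illustrative, not from the PR):

```java
import java.nio.file.Files;
import java.nio.file.Path;

// Sketch: inject the backend root so tests can use a temp directory
// instead of touching fixed host paths. Names are illustrative.
final class BackendLocations {
    private final Path deepspeed;
    private final Path fastertransformer;

    BackendLocations(Path root) {
        this.deepspeed = root.resolve("bin/deepspeed");
        this.fastertransformer = root.resolve("backends/fastertransformer");
    }

    boolean hasDeepSpeed() {
        return Files.exists(deepspeed);
    }

    boolean hasFasterTransformer() {
        return Files.exists(fastertransformer);
    }
}
```

Production code would construct this with `Paths.get("/usr/local")`, while tests pass `Files.createTempDirectory(...)`, avoiding the existed/not-existed bookkeeping in the snippet above.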

// This represents the config of Hugging Face NLP models as well as the
// config of diffusers models. The config differs between the two, but for
// now we can leverage a single class since we don't need much information from the config.
static class HuggingFaceModelConfig {
Contributor:
Suggested change:
-    static class HuggingFaceModelConfig {
+    static final class HuggingFaceModelConfig {

throw e;
Gson gson =
JsonUtils.builder()
.setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
Contributor:
Why do we need this, since we already use @SerializedName?

Contributor Author (siddvenk):
Left over from my testing; this should be removed. Good catch.
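The reviewer's point is that @SerializedName already maps each snake_case JSON key onto its field, so a global FieldNamingPolicy is redundant. A minimal illustration (the field names here are assumptions, not the PR's actual config class):

```java
import com.google.gson.annotations.SerializedName;

// Minimal illustration: @SerializedName maps the snake_case JSON key to the
// camelCase field, so no FieldNamingPolicy on the Gson builder is needed.
// Field names are assumptions, not the PR's actual config class.
final class HuggingFaceModelConfigSketch {
    @SerializedName("model_type")
    String modelType;

    @SerializedName("n_head")
    int numAttentionHeads;
}
```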

} else if (modelId != null) {
prop.put("option.modelId", modelId);
configUri = URI.create("https://huggingface.co/" + modelId + "/raw/main/config.json");
HttpURLConnection configUrl = (HttpURLConnection) configUri.toURL().openConnection();
Contributor:
We can consider using the OPTIONS method instead of GET. I think it should work.

Contributor Author (siddvenk):
I'll explore that and add it as a follow-up if I get it to work.
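A lightweight probe could look like the following. The URL shape comes from the diff above; the choice of HEAD (rather than the OPTIONS method the reviewer suggested) and the method names are assumptions:

```java
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URI;

final class ConfigProbeSketch {

    /** Builds the Hub config URL used in the diff above. */
    static URI buildConfigUri(String modelId) {
        return URI.create("https://huggingface.co/" + modelId + "/raw/main/config.json");
    }

    /** Checks existence without downloading the body. Sketch: uses HEAD;
     *  OPTIONS (as suggested) would need server support for that method. */
    static boolean configExists(String modelId) throws IOException {
        HttpURLConnection conn =
                (HttpURLConnection) buildConfigUri(modelId).toURL().openConnection();
        conn.setRequestMethod("HEAD");
        try {
            return conn.getResponseCode() == HttpURLConnection.HTTP_OK;
        } finally {
            conn.disconnect();
        }
    }
}
```

Either way, the real code still needs the GET to read the config body when it exists; the probe only avoids the download on the negative path.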

@siddvenk siddvenk merged commit d042a46 into deepjavalibrary:master Apr 14, 2023
@siddvenk siddvenk deleted the infer-lmi-engine branch June 13, 2023 19:01