Skip to content

Commit

Permalink
[lmi] use hf token to get model config for gated/private models (#1658)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk authored Mar 22, 2024
1 parent 4e3f004 commit e2a0db0
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
5 changes: 4 additions & 1 deletion serving/src/main/java/ai/djl/serving/util/ConfigManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,10 @@ public String dumpConfigurations() {
.append("\nEnvironment variables:");
for (Map.Entry<String, String> entry : System.getenv().entrySet()) {
String key = entry.getKey();
if (key.startsWith("SERVING")
// Do not log HF_TOKEN value
if ("HF_TOKEN".equals(key)) {
sb.append("\n\t").append(key).append(": ***");
} else if (key.startsWith("SERVING")
|| key.startsWith("PYTHON")
|| key.startsWith("DJL_")
|| key.startsWith("HF_")
Expand Down
12 changes: 11 additions & 1 deletion wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URLConnection;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -149,8 +150,12 @@ public static URI generateHuggingFaceConfigUri(ModelInfo<?, ?> modelInfo, String
return null;
}

String hfToken = Utils.getenv("HF_TOKEN");
configUri = URI.create("https://huggingface.co/" + modelId + "/raw/main/config.json");
HttpURLConnection configUrl = (HttpURLConnection) configUri.toURL().openConnection();
if (hfToken != null) {
configUrl.setRequestProperty("Authorization", "Bearer " + hfToken);
}
// stable diffusion models have a different file name with the config... sometimes
if (HttpURLConnection.HTTP_OK != configUrl.getResponseCode()) {
configUri =
Expand All @@ -173,7 +178,12 @@ private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo<?, ?>
if (modelConfigUri == null) {
return null;
}
try (InputStream is = modelConfigUri.toURL().openStream()) {
URLConnection configConnection = modelConfigUri.toURL().openConnection();
if (Utils.getenv("HF_TOKEN") != null) {
configConnection.setRequestProperty(
"Authorization", "Bearer " + Utils.getenv("HF_TOKEN"));
}
try (InputStream is = configConnection.getInputStream()) {
return JsonUtils.GSON.fromJson(Utils.toString(is), HuggingFaceModelConfig.class);
}
} catch (IOException | JsonSyntaxException e) {
Expand Down
1 change: 1 addition & 0 deletions wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ public void testInferLMIEngine() throws IOException, ModelException {
put("openai-community/gpt2", "vllm");
put("tiiuae/falcon-7b", "lmi-dist");
put("mistralai/Mistral-7B-v0.1", "vllm");
put("src/test/resources/local-hf-model", "vllm");
}
};
Path modelStore = Paths.get("build/models");
Expand Down
3 changes: 3 additions & 0 deletions wlm/src/test/resources/local-hf-model/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"model_type": "gpt2"
}

0 comments on commit e2a0db0

Please sign in to comment.