From e2a0db017aa99cf51042c4a36f8c3cab66f2ebf0 Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Fri, 22 Mar 2024 10:07:38 -0700 Subject: [PATCH] [lmi] use hf token to get model config for gated/private models (#1658) --- .../main/java/ai/djl/serving/util/ConfigManager.java | 5 ++++- wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java | 12 +++++++++++- .../test/java/ai/djl/serving/wlm/ModelInfoTest.java | 1 + wlm/src/test/resources/local-hf-model/config.json | 3 +++ 4 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 wlm/src/test/resources/local-hf-model/config.json diff --git a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java index f171b55ed..4f4f2db46 100644 --- a/serving/src/main/java/ai/djl/serving/util/ConfigManager.java +++ b/serving/src/main/java/ai/djl/serving/util/ConfigManager.java @@ -447,7 +447,10 @@ public String dumpConfigurations() { .append("\nEnvironment variables:"); for (Map.Entry entry : System.getenv().entrySet()) { String key = entry.getKey(); - if (key.startsWith("SERVING") + // Do not log HF_TOKEN value + if ("HF_TOKEN".equals(key)) { + sb.append("\n\t").append(key).append(": ***"); + } else if (key.startsWith("SERVING") || key.startsWith("PYTHON") || key.startsWith("DJL_") || key.startsWith("HF_") diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java index a225ee8e3..c8f92b221 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java @@ -31,6 +31,7 @@ import java.io.InputStreamReader; import java.net.HttpURLConnection; import java.net.URI; +import java.net.URLConnection; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; @@ -149,8 +150,12 @@ public static URI generateHuggingFaceConfigUri(ModelInfo modelInfo, String return null; } + String hfToken = Utils.getenv("HF_TOKEN"); configUri = URI.create("https://huggingface.co/" + modelId + "/raw/main/config.json"); HttpURLConnection configUrl = (HttpURLConnection) configUri.toURL().openConnection(); + if (hfToken != null) { + configUrl.setRequestProperty("Authorization", "Bearer " + hfToken); + } // stable diffusion models have a different file name with the config... sometimes if (HttpURLConnection.HTTP_OK != configUrl.getResponseCode()) { configUri = @@ -173,7 +178,12 @@ private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo if (modelConfigUri == null) { return null; } - try (InputStream is = modelConfigUri.toURL().openStream()) { + URLConnection configConnection = modelConfigUri.toURL().openConnection(); + if (Utils.getenv("HF_TOKEN") != null) { + configConnection.setRequestProperty( + "Authorization", "Bearer " + Utils.getenv("HF_TOKEN")); + } + try (InputStream is = configConnection.getInputStream()) { return JsonUtils.GSON.fromJson(Utils.toString(is), HuggingFaceModelConfig.class); } } catch (IOException | JsonSyntaxException e) { diff --git a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java index c23020b48..527028367 100644 --- a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java +++ b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java @@ -259,6 +259,7 @@ public void testInferLMIEngine() throws IOException, ModelException { put("openai-community/gpt2", "vllm"); put("tiiuae/falcon-7b", "lmi-dist"); put("mistralai/Mistral-7B-v0.1", "vllm"); + put("src/test/resources/local-hf-model", "vllm"); } }; Path modelStore = Paths.get("build/models"); diff --git a/wlm/src/test/resources/local-hf-model/config.json b/wlm/src/test/resources/local-hf-model/config.json new file mode 100644 index 000000000..575704404 --- /dev/null +++ b/wlm/src/test/resources/local-hf-model/config.json @@ -0,0 +1,3 @@ +{ + "model_type": "gpt2" +}