From 0dab949c821232508068cf4a7eef27f5b4db6805 Mon Sep 17 00:00:00 2001 From: Riandy Riandy Date: Mon, 23 Sep 2024 16:20:03 -0700 Subject: [PATCH] Adding support to demo prompt classification with Llama Guard (#5553) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/5553 Adding support to load Llama Guard model and run prompt classification task Reviewed By: cmodi-meta, kirklandsign Differential Revision: D63148252 fbshipit-source-id: 482559e694da05bdec75b9a2dbd76163c686e47d (cherry picked from commit 61cb5b0a57acae788f43faf5a0399b1d6cacb004) --- .../executorchllamademo/MainActivity.java | 10 +++++ .../executorchllamademo/ModelType.java | 1 + .../executorchllamademo/PromptFormat.java | 42 +++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index 4d81ec8ae5..524b4fbc8a 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -704,6 +704,16 @@ public void run() { startPos, MainActivity.this, false); + } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) { + String llamaGuardPromptForClassification = + PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt); + ETLogging.getInstance() + .log("Running inference.. prompt=" + llamaGuardPromptForClassification); + mModule.generate( + llamaGuardPromptForClassification, + llamaGuardPromptForClassification.length() + 64, + MainActivity.this, + false); } else { ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); mModule.generate( diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java index 91e84be059..a241ca3d52 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java @@ -12,4 +12,5 @@ public enum ModelType { LLAMA_3, LLAMA_3_1, LLAVA_1_5, + LLAMA_GUARD_3, } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 36e738c3d0..14cf38e669 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -33,6 +33,7 @@ public static String getUserPromptTemplate(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_GUARD_3: return "<|start_header_id|>user<|end_header_id|>\n" + USER_PLACEHOLDER + "<|eot_id|>" @@ -60,6 +61,7 @@ public static String getStopToken(ModelType modelType) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: + case LLAMA_GUARD_3: return "<|eot_id|>"; case LLAVA_1_5: return ""; @@ -72,4 +74,44 @@ public static String getLlavaPresetPrompt() { return "A chat between a curious human and an artificial intelligence assistant. The assistant" + " gives helpful, detailed, and polite answers to the human's questions. USER: "; } + + public static String getFormattedLlamaGuardPrompt(String userPrompt) { + return getUserPromptTemplate(ModelType.LLAMA_GUARD_3) + .replace( + USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt)); + } + + public static String getLlamaGuardPresetPrompt() { + return "Task: Check if there is unsafe content in 'User' messages in conversations" + + " according our safety policy with the below categories.\n" + + "\n" + + "\n" + + "S1: Violent Crimes. \n" + + "S2: Non-Violent Crimes. \n" + + "S3: Sex Crimes. \n" + + "S4: Child Exploitation. \n" + + "S5: Defamation. \n" + + "S6: Specialized Advice. \n" + + "S7: Privacy. \n" + + "S8: Intellectual Property. \n" + + "S9: Indiscriminate Weapons. \n" + + "S10: Hate. \n" + + "S11: Self-Harm. \n" + + "S12: Sexual Content. \n" + + "S13: Elections. \n" + + "S14: Code Interpreter Abuse.\n" + + "\n" + + "\n" + + "\n" + + "\n" + + "User: " + + USER_PLACEHOLDER + + "\n" + + "\n" + + "Provide your safety assessment for ONLY THE LAST User message in the above" + + " conversation:\n" + + " - First line must read 'safe' or 'unsafe'.\n" + + " - If unsafe, a second line must include a comma-separated list of violated" + + " categories."; + } }