Skip to content

Commit

Permalink
Adding support to demo prompt classification with Llama Guard (#5553)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #5553

Adding support to load Llama Guard model and run prompt classification task

Reviewed By: cmodi-meta, kirklandsign

Differential Revision: D63148252

fbshipit-source-id: 482559e694da05bdec75b9a2dbd76163c686e47d
(cherry picked from commit 61cb5b0)
  • Loading branch information
Riandy authored and pytorchbot committed Sep 24, 2024
1 parent 9757eda commit 0dab949
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,16 @@ public void run() {
startPos,
MainActivity.this,
false);
} else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) {
String llamaGuardPromptForClassification =
PromptFormat.getFormattedLlamaGuardPrompt(rawPrompt);
ETLogging.getInstance()
.log("Running inference.. prompt=" + llamaGuardPromptForClassification);
mModule.generate(
llamaGuardPromptForClassification,
llamaGuardPromptForClassification.length() + 64,
MainActivity.this,
false);
} else {
ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
mModule.generate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ public enum ModelType {
LLAMA_3,
LLAMA_3_1,
LLAVA_1_5,
LLAMA_GUARD_3,
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public static String getUserPromptTemplate(ModelType modelType) {
switch (modelType) {
case LLAMA_3:
case LLAMA_3_1:
case LLAMA_GUARD_3:
return "<|start_header_id|>user<|end_header_id|>\n"
+ USER_PLACEHOLDER
+ "<|eot_id|>"
Expand Down Expand Up @@ -60,6 +61,7 @@ public static String getStopToken(ModelType modelType) {
switch (modelType) {
case LLAMA_3:
case LLAMA_3_1:
case LLAMA_GUARD_3:
return "<|eot_id|>";
case LLAVA_1_5:
return "</s>";
Expand All @@ -72,4 +74,44 @@ public static String getLlavaPresetPrompt() {
return "A chat between a curious human and an artificial intelligence assistant. The assistant"
+ " gives helpful, detailed, and polite answers to the human's questions. USER: ";
}

public static String getFormattedLlamaGuardPrompt(String userPrompt) {
return getUserPromptTemplate(ModelType.LLAMA_GUARD_3)
.replace(
USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt));
}

public static String getLlamaGuardPresetPrompt() {
return "Task: Check if there is unsafe content in 'User' messages in conversations"
+ " according our safety policy with the below categories.\n"
+ "\n"
+ "<BEGIN UNSAFE CONTENT CATEGORIES>\n"
+ "S1: Violent Crimes. \n"
+ "S2: Non-Violent Crimes. \n"
+ "S3: Sex Crimes. \n"
+ "S4: Child Exploitation. \n"
+ "S5: Defamation. \n"
+ "S6: Specialized Advice. \n"
+ "S7: Privacy. \n"
+ "S8: Intellectual Property. \n"
+ "S9: Indiscriminate Weapons. \n"
+ "S10: Hate. \n"
+ "S11: Self-Harm. \n"
+ "S12: Sexual Content. \n"
+ "S13: Elections. \n"
+ "S14: Code Interpreter Abuse.\n"
+ "<END UNSAFE CONTENT CATEGORIES>\n"
+ "\n"
+ "<BEGIN CONVERSATION>\n"
+ "\n"
+ "User: "
+ USER_PLACEHOLDER
+ "\n"
+ "<END CONVERSATION>\n"
+ "Provide your safety assessment for ONLY THE LAST User message in the above"
+ " conversation:\n"
+ " - First line must read 'safe' or 'unsafe'.\n"
+ " - If unsafe, a second line must include a comma-separated list of violated"
+ " categories.";
}
}

0 comments on commit 0dab949

Please sign in to comment.