Skip to content

Commit

Permalink
add <|eom_id|> to stopTokens
Browse files Browse the repository at this point in the history
  • Loading branch information
cmodi-meta committed Dec 6, 2024
1 parent 935a923 commit 74526a2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ constructor(
private var sequenceLengthKey: String = "seq_len"

override fun onResult(p0: String?) {
if (p0.equals(PromptFormatLocal.getStopToken(modelName))) {
if (PromptFormatLocal.getStopTokens(modelName).any { it == p0 }) {
onResultComplete = true
return
}
Expand Down Expand Up @@ -62,8 +62,9 @@ constructor(
PromptFormatLocal.getTotalFormattedPrompt(params.messages(), modelName)

// Developer can pass in their sequence length but if not then it will default to a
// particular dynamic value. This is to ensure enough value is provided to give a reasonably complete response.
// 0.75 is the approximate words per token. And 64 is buffer for tokens for generate response.
// particular dynamic value. This is to ensure enough value is provided to give a reasonably
// complete response. 0.75 is the approximate words per token. And 64 is buffer for tokens
// for generate response.
val seqLength =
params._additionalQueryParams().values(sequenceLengthKey).lastOrNull()?.toInt()
?: ((formattedPrompt.length * 0.75) + 64).toInt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ object PromptFormatLocal {
}
}

fun getStopToken(modelName: String?): String {
fun getStopTokens(modelName: String?): List<String> {
return when (modelName) {
"LLAMA_3",
"LLAMA_3_1",
"LLAMA_3_2",
"LLAMA_GUARD_3" -> "<|eot_id|>"
else -> ""
"LLAMA_GUARD_3" -> listOf("<|eot_id|>", "<|eom_id|>")
else -> listOf("")
}
}

Expand Down

0 comments on commit 74526a2

Please sign in to comment.