llama : speed-up grammar sampling #4218
Comments
#3980 and this suggestion might also help a bit #3980 (comment) I would have expected the compiler to optimize it straight away 🤷🏻 |
Would an integration of Outlines help? Like they are doing with vLLM: dottxt-ai/outlines#163 |
@ExtReMLapin This copy is used only in the @gottlike An efficient low-level solution as the one we currently have seems like a better approach to me. |
I noticed that inference at some point gets exponentially slower when a grammar is deeply nested but open. By open I mean that it allows a lot of different possibilities. For example, I am working on PydanticModel -> JsonSchema -> Grammar, and when the model outputs a list of nested subobjects, the effect appears once the list gets long, and at some point generation gets stuck. |
@shroominic on my end it just gets slower the longer it spends printing the JSON array, no nested objects. |
I found a similar exponential slowdown to the one mentioned by @shroominic for my use case, which is to generate code in a language similar to OCaml. Generation was very fast for the first 200 tokens, but the time per token rose to more than 400 seconds as I approached 300 tokens. I plotted the grammar stack size and duration per token over time and found stack size to be the main factor in the slowdown. The number of grammar stacks can go up to 800K for my grammar. I'm not very familiar with the grammar sampling algorithm used in llama.cpp, but I suspect it's exponential in the length of the parsed string. Polynomially bounded parsing algorithms like the Earley parser might help avoid the exponential blowup.
Grammar
Prompt
### Option ###
# Represent values that may or may not exist. #
type Option =
+ Some(?)
+ None
in
# Compare if two Options are equal #
# equal: ((?, ?) -> Bool) -> (Option, Option) -> Bool #
let equal: ((?, ?) -> Bool) -> (Option, Option) -> Bool =
fun eq, os ->
case os
| Some(x), Some(y) => eq(x, y)
| None, None => True
| _, _ => False
end
in
### Result ###
# A Result is either Ok meaning the computation succeeded, #
# or it is an Err meaning that there was some failure. #
type Result =
+ Ok(a)
+ Err(b)
in
# Compare if two Results are equal #
# equal: ((a, a) -> Bool) -> (Result, Result) -> Bool #
let equal: ((a, a) -> Bool) -> (Result, Result) -> Bool =
fun eq, rs ->
case rs
| Ok(e1), Ok(e2) => eq(e1, e2)
| Error(e1), Error(e2) => e1 $== e2
| _ => false
end
in
### JSON ###
# This module helps you convert between Hazel values and JSON values. #
# A JSON value type #
type Value =
+ Object([(String, Value)])
+ Array([Value])
+ Str(String)
+ Number(Float)
+ Boolean(Bool)
+ Null
in
# Check if two JSON values are equal #
# equal : (Value,Value) -> Bool #
let equal : (Value,Value) -> Bool =
fun a, b ->
case (a, b)
| Object(o1), Object(o2) => List.equal(
fun (s1, v1), (s2, v2) ->
s1 $== s2 && equal(v1, v2), o1, o2)
| Array(a1), Array(a2) => List.equal(equal, a1, a2)
| Str(s1), Str(s2) => s1 $== s2
| Number(n1), Number(n2) => n1 ==. n2
| Boolean(b1), Boolean(b2) => if b1 then b2 else !b2
| Null, Null => true
| _ => false
end
in
# JSON Encoder #
# Convert a string to a JSON string #
# value_of_string : String -> Value #
let value_of_string : String -> Value =
fun s -> Str(s)
in
# Convert an integer to a JSON integer #
# value_of_int : Int -> Value #
let value_of_int : Int -> Value =
fun i -> Number(float_of_int(i))
in
# Convert a float to a JSON float #
# value_of_float : Float -> Value #
let value_of_float : Float -> Value =
fun f -> Number(f)
in
# Convert a boolean to a JSON boolean #
# value_of_bool : Bool -> Value #
let value_of_bool : Bool -> Value =
fun b -> if b then Boolean(true) else Boolean(false)
in
# Convert a null to a JSON null #
# value_of_null : Value #
let value_of_null : Value = Null in
# Convert a list of JSON values to a JSON array #
# value_of_list : (a -> Value, [a]) -> Value #
let value_of_list : (a -> Value, [a]) -> Value =
fun (func, entries) ->
Array(
List.rev(List.fold_left(
fun l, e-> func(e)::l, [], entries)))
in
# Convert a dictionary of JSON values to a JSON object #
# value_of_object : [(String, Value)] -> Value #
let value_of_object : [(String, Value)] -> Value =
fun entries -> Object(entries)
in
# JSON Decoder #
# A Decoder decodes a JSON value into a Hazel value, or return an Err on failure. #
type Decoder = Value -> Result in
# Decodes a JSON string into a string #
# string_of_value : Decoder #
let string_of_value : Decoder =
fun v ->
case v
| Str(s) => Ok(s)
| _ => Err("Cannot unpack value as a String")
end
in
# Decodes a JSON boolean into a boolean #
# bool_of_value : Decoder #
let bool_of_value : Decoder =
fun v ->
case v
| Boolean(b) => Ok(b)
| _ => Err("Cannot unpack value as a Bool")
end
in
# Decodes a JSON integer into an integer #
# int_of_value : Decoder #
let int_of_value : Decoder =
fun v ->
case v
| Number(n) =>
if floor(n) ==. n then
# n is a whole number #
Ok(floor(n))
else
# n is a floating point #
Err("Cannot unpack a float value as an Int")
| _ => Err("Cannot unpack value as an Int")
end
in
# Decodes a JSON float into a float #
# float_of_value : Decoder #
let float_of_value : Decoder =
fun v ->
case v
| Number(n) => Ok(floor(n))
| _ => Err("Cannot unpack value as a Float")
end
in
# Decodes a JSON null into a null #
# null_of_value : Decoder #
let null_of_value : Decoder =
fun v ->
case v
| Null => Ok(None)
| _ => Err("Cannot unpack value as a None")
end
in
# Parsers #
# Try a bunch of different decoders. #
# This can be useful if the JSON may come in a couple different formats. #
# one_of : [Decoder] -> Decoder #
let one_of : [Decoder] -> Decoder =
fun decoders -> fun v ->
case decoders
| decoder::decoders =>
result_map_err(fun _ -> one_of(decoders)(v), decoder(v))
| [] => Err("one_of failed to decode value")
end
in
# Transform a decoder. #
# map : ((a -> b), Decoder) -> Decoder #
let map : ((a -> b), Decoder) -> Decoder =
fun (func, decoder) -> fun v ->
case decoder(v)
| Err(e) => Err(e)
| Ok(o) => func(o)
in
# Create decoders that depend on previous results. #
# and_then: ((a -> Decoder), Decoder) -> Decoder #
let and_then: ((a -> Decoder), Decoder) -> Decoder =
fun (func, decoder) ->
fun v ->
case decoder(v)
| Err(e) => Err(e)
| Ok(o)=> func(o)(v)
end
in
# Decode a nullable JSON value into a Hazel value. #
# nullable : Decoder -> Decoder #
let nullable : Decoder -> Decoder =
fun decoder ->
one_of([
map(fun s -> Some(s), decoder),
null_of_value
])
in
# Decode a JSON array into a Hazel List. #
# list : Decoder -> Decoder #
let list : Decoder -> Decoder =
fun elem_decoder ->
fun v ->
case v
| Array(arr) =>
case arr
| head::tail =>
case elem_decoder(head)
| Ok(hd) => map(fun tl -> hd::tl, list(elem_decoder))(Array(tail))
| Err(e) => Err(e)
end
| [] => Ok([])
end
| _ => Err("Cannot unpack value as a List")
end
in
# Decode a JSON object into a Hazel dictionary. #
# For now, a dictionary is just a list of key-value pairs #
# dict : Decoder -> Decoder #
let dict : Decoder -> Decoder =
fun value_decoder ->
fun v ->
case v
| Object(pairs) =>
case pairs
| (key, value)::tail =>
case value_decoder(value)
| Ok(hd)=> map(fun tl -> (key, hd)::tl, dict(value_decoder))(Object(tail))
| Err(e) => Err(e)
end
| [] => Ok([])
end
| _ => Err("Cannot unpack value as a dict")
end
in
### List ###
# Add an element to the front of a list. #
# cons: (a, [a]) -> [a] #
let cons: (a, [a]) -> [a] = fun x, xs -> x::xs in
# Determine the length of a list. #
# length: [a] -> Int #
let length: [a] -> Int =
fun xs ->
case xs
| [] => 0
| _::xs => 1 + length(xs) end in
# Extract the first element of a list. #
# hd: [a] -> Option #
let hd: [a] -> Option =
fun l ->
case l
| [] => None
| x::xs => Some(x) end in
# Extract the rest of the list. #
# tl: [a] -> [a] #
let tl: [a] -> [a] =
fun l ->
case l
| [] => []
| x::xs => xs end in
# Determine if a list is empty. #
# is_empty: [a] -> Bool #
let is_empty: [a] -> Bool =
fun xs ->
case xs
| [] => true
| _::_ => false end in
# Return the element at the index. #
# nth: ([a], Int) -> Option #
let nth: ([a], Int) -> Option =
fun xs, n ->
case xs, n
| x::_, 0 => Some(x)
| _::xs, n => nth(xs, n - 1)
| [], _ => None end in
# Reverse a List. #
# rev: [a] -> [a] #
let rev: [a] -> [a] =
fun l ->
let go: ([a], [a]) -> [a] =
fun xs, acc ->
case xs
| [] => acc
| x::xs => go(xs, x::acc) end in
go(l, []) in
# Check if two lists are equal #
# equal: ((a, a) -> Bool, [a], [a]) -> Bool #
let equal: ((a, a) -> Bool, [a], [a]) -> Bool =
fun p, xs, ys ->
case xs, ys
| [], [] => true
| x::xs, y::ys => p(x, y) && equal(p, xs, ys)
| _ => false end
in
# Initialize a list with a given length using an initializer function #
# init: (Int, Int -> a) -> [a] #
let init: (Int, Int -> a) -> [a] =
fun len, f ->
let go: (Int, [a]) -> [a] =
fun idx, xs ->
if idx < len
then go(idx + 1, xs @ [f(idx)])
else xs
in
go(0, [])
in
# Reduce a list from the left. #
# fold_left: ((b, a) -> b, b, [a]) -> b #
let fold_left: ((b, a) -> b, b, [a]) -> b =
fun f, acc, xs ->
case xs
| [] => acc
| hd::tl => fold_left(f, f(acc, hd), tl) end in
# Reduce a list from the right. #
# fold_right: ((a, b) -> b, [a], b) -> b #
let fold_right: ((a, b) -> b, [a], b) -> b =
fun f, xs, acc ->
case xs
| [] => acc
| hd::tl => f(hd, fold_right(f, tl, acc)) end in
# A simplified lambda calculus expression containing variables, lambdas, and applications #
type Exp =
+ Var(String)
+ Lam(String, Exp)
+ Ap(Exp, Exp)
in
# Evaluation can result in either an Exp or an Error #
# Evaluation by substitution #
# eval: Exp -> Result #
Command
./main \
  --grammar-file grammar.gbnf \
  -t 10 \
  -ngl 64 \
  -b 512 \
  -m ../models/codellama-34b.Q5_K_M.gguf \
  --color -c 3400 \
  --temp 0.7 \
  --repeat_penalty 1.1 \
  -n -1 \
  -f prompt.txt
|
After looking into the code, I think there's a seemingly obvious and much simpler way to optimize grammar sampling, even without threading. Right now, it manually checks all token candidates and removes any that would violate the grammar. It would be simpler and more effective to sample the normal way, check whether the chosen token violates the grammar before proceeding with it, and only if it does, revert to the current behavior that 'forces' the grammar. @ejones Any suggestions for how I would go about implementing a solution? |
I'm doing some investigation. I think the easiest way to do this without refactors to the grammar itself is by running a check to the existing grammar function with only the single candidate in sampling.cpp; if it's correct, we proceed. If it's wrong, we restart sampling, this time running:
Before the rep pen or any other modifications are made to the logits. |
I have made a pull request which should reduce the number of checks necessary to 1 for most tokens instead of all 32,000 tokens in the vocabulary. I have not evaluated whether or not it is actually faster yet, but I'm guessing that avoiding thousands of UTF8 decoding steps for most tokens would improve performance. |
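The "sample first, check afterwards" idea from the comments above can be sketched in a few lines of Python. This is only an illustration under stated assumptions, not the code from the pull request: `sample`, `grammar_sample`, and the `accepts` callback are hypothetical names, and tokens are plain strings rather than llama.cpp token ids.

```python
import random
from typing import Callable, Dict


def sample(probs: Dict[str, float], rng) -> str:
    """Draw one token according to its (unnormalized) weight."""
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


def grammar_sample(
    probs: Dict[str, float],
    accepts: Callable[[str], bool],  # hypothetical grammar check for one token
    rng=random,
) -> str:
    # Fast path: sample normally and run the grammar check on that one token.
    candidate = sample(probs, rng)
    if accepts(candidate):
        return candidate

    # Slow path: fall back to checking every candidate, as the current
    # behavior does, and resample among the tokens the grammar allows.
    allowed = {tok: p for tok, p in probs.items() if accepts(tok)}
    if not allowed:
        raise ValueError("grammar rejects every candidate token")
    return sample(allowed, rng)
```

When the model already tends to follow the grammar, most calls stay on the fast path and cost a single grammar check. In this toy setting the combined result follows the same distribution as filtering the whole vocabulary first, because a rejected draw is simply replaced by a draw from the filtered set.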
@AlienKevin thanks for investigating! Yeah it's a simple top-down backtracking parser so it can be exponential in the worst case. It works best for grammars with little or no ambiguity or nondeterminism. A deterministic grammar should maintain a constant stack count. This isn't obvious though and we could probably do a better job signaling this to the user. @kalomaze looks great, commented on your PR as well. |
Grammar processing appears to be quite slow (again?): #4306 (comment) |
No issue on my end |
I've noticed it varies widely with respect to prompt complexity. My JSON schema -> grammar contains three levels of object-arrays and if I ask for a shorter output it completes reasonably quickly with the conforming schema and runs at a consistently high level of CPU utilization. But if I ask for an output that is about ten times longer, for the exact same schema, I notice the resource utilization (CPU mainly) becomes highly variable and rarely sustains max utilization. The overall inference time gets long enough that it's not worth waiting for the task to complete (30+ minutes) whereas in contrast the exact same prompt will run for about 7 minutes consistently with the grammar/schema removed. If I need to just post a full repro I'm happy to link a Gist. |
This issue was closed because it has been inactive for 14 days since being marked as stale. |
Just to close the loop on my previous comment-- I continued experimenting with this feature on a wide variety of cases and ultimately concluded that the performance variance is too large for production use, even on fast GPUs such as the A100s. I would, however, very much like to see this feature perform consistently enough for production as it is otherwise very useful. I am happy to help with testing or reproduction of test cases if anyone decides to work on this. |
I've been digging into this lately, and I've been using the integration tests in #6472 to do some crude performance profiling. I've definitely seen the sort of dramatic stack expansion that @AlienKevin is talking about. I think there are many causes, but one that I've been digging into is how easy it is for alternate grammars to dramatically inflate the stack. For instance, imagine a grammar that says:
This is a rather extreme example, but hopefully it illustrates the point. If you have an input string that looks like "1a2a3a4a5", then at the first character, there is only 1 stack.
If we change our input string to "1aa2aa3aa4aa5", then it gets even worse, because it permutates a bit:
On my laptop (MacBook M1, 32 GB RAM), this noticeably lags the machine, and it hitches to execute even this (relatively) short grammar and validation string. I've done tests with much larger grammars that push MUCH larger stack sizes, and the ambiguity can really drag things down -- even to the point of memory exhaustion. Currently I'm tackling some left-recursion issues that can cause segfaults ( #6492 ), but after I get done with that (or give up), then I'm going to tackle these performance issues related to ambiguity. I'm not entirely sure, but I think that there should be a viable algorithm to prune near-redundant stacks if their outcomes would be equivalent. I.E., if we've got four potential "a"'s to match, and only one "a", then it doesn't matter which one we choose -- the others can be discarded once we get to the next token, so long as they all converge onto the same spot in the same rule. This isn't the only performance-related issue that grammars are seeing, but I believe that this massive exponential growth in the stack size is one of the biggest opportunities for optimization. Anyways, that's where I'm at on things -- will keep y'all posted. |
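The stack growth and the pruning idea described above can be reproduced in a toy model. The sketch below is not llama.cpp's parser: it assumes a hypothetical grammar shape (a digit, then repeated groups of five optional "a"s followed by a digit), represents each "stack" as just a position inside that group, and compares keeping every state against merging states that sit at the same position. `GROUP`, `expand`, and `count_states` are made-up names.

```python
from typing import List, Tuple

DIGITS = "0123456789"

# Hypothetical repeated group: five optional "a"s, then one required digit.
# Each element is (accepted characters, is_optional).
GROUP: List[Tuple[str, bool]] = [("a", True)] * 5 + [(DIGITS, False)]


def expand(pos: int) -> List[int]:
    """All positions that could consume the next character, starting at pos.

    An optional element may be matched or skipped, so it yields its own
    position plus everything reachable after skipping it.
    """
    _, optional = GROUP[pos]
    if not optional:
        return [pos]
    return [pos] + expand((pos + 1) % len(GROUP))


def count_states(text: str, dedupe: bool) -> List[int]:
    """Number of live parser states after each character of text."""
    states = [len(GROUP) - 1]  # start by expecting the leading digit
    history = []
    for ch in text:
        nxt: List[int] = []
        for pos in states:
            chars, _ = GROUP[pos]
            if ch in chars:
                # The group repeats, so the position wraps around.
                nxt.extend(expand((pos + 1) % len(GROUP)))
        # Pruning: states at the same position have identical futures.
        states = sorted(set(nxt)) if dedupe else nxt
        history.append(len(states))
    return history


if __name__ == "__main__":
    print(count_states("1a2a3a4a5", dedupe=False))
    print(count_states("1a2a3a4a5", dedupe=True))
```

In this model the unpruned run prints `[6, 15, 30, 75, 150, 375, 750, 1875, 3750]`, while the pruned run stays at `[6, 5, 6, 5, 6, 5, 6, 5, 6]`. The numbers are specific to the toy model and are not measurements of llama.cpp, but they show why merging states that converge on the same spot in the same rule keeps the count bounded.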
I'm not very familiar with the current setup of our CI performance profilers -- if I were to make improvements to the grammar engine, would those speed improvements show up in our current bank of benchmarks? |
We don't have benchmarks for this yet. You will have to do some custom profiling to determine how the performance changes. With time, we can attempt to add some sort of speed information about the grammar to the CI or at least some local tools. |
@HanClinto One possibly big source of such explosive repetitions is JSON grammars w/ minItems/maxItems (or w/ JSON string regexp patterns such as
I'm working on fixing the JSON grammar conversion to do this (e.g. 375f85d), hope to send a PR soon. I'll probably update the GBNF doc w/ performance caveats in the same PR. |
@ochafik That indeed is a massive improvement! Testing your grammar against 1a2a3a4a5 gives:
And testing against the worse case of 1aa2aa3aa4aa5 gives:
Indeed, that's a massive savings. Nice speedup! It would still be nice to find ways to speed up the grammar tree navigation even with more poorly written grammars, but improving the quality of the grammars in this way is a huge help. |
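The rewritten grammar is not shown in the thread as captured here, so the following is only a sketch of one common rewrite for bounded repetitions: replacing a flat run of optional items with nested optional groups, so that each position offers a single choice (one more item, or stop). The helper names `naive_repetition` and `nested_repetition` are hypothetical, and the output is a GBNF-style string fragment.

```python
def naive_repetition(item: str, min_items: int, max_items: int) -> str:
    """Flat form: required items followed by independent optional items."""
    parts = [item] * min_items + [f"{item}?"] * (max_items - min_items)
    return " ".join(parts)


def nested_repetition(item: str, min_items: int, max_items: int) -> str:
    """Nested form: each optional item is only reachable after the previous one."""
    optional = ""
    for _ in range(max_items - min_items):
        # Wrap what we have so far inside one more optional group.
        optional = f"({item} {optional})?" if optional else f"({item})?"
    required = " ".join([item] * min_items)
    return " ".join(part for part in (required, optional) if part)


if __name__ == "__main__":
    print(naive_repetition('"a"', 0, 5))
    print(nested_repetition('"a"', 0, 5))
```

Both forms accept the same strings. The flat form lets a single "a" be attributed to any of the optional slots, which is the kind of ambiguity discussed above, whereas the nested form always attributes it to the outermost remaining slot. Callers are assumed to parenthesize `item` themselves if it is more than a single symbol.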
@HanClinto I'd be inclined to detect some easily rewritable grammar cases on the fly and explode when the grammar becomes too combinatorial (w/ a link to a "performance tips" section of the GBNF wiki page), either with a cap on stack size or some builtin timeout maybe? Maybe also some new features like numbered repetition operators |
Fwiw, I've wanted bounded repetitions a few times with this grammar; recursive whitespace sometimes lets the model spin forever. It would also be great to see some stats on the grammar, either after running or after desugaring, so that we can optimize grammars. |
This sounds really good. Detecting and gracefully exiting is the first step. Worst case is that things explode in an infinite loop (as is the case with left-recursion, as in #6492), and implementing a max stack size of 1024 or something should at least give a reasonable approach.
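A cap like the one floated above could look roughly like this. It is a minimal sketch with hypothetical names (`GrammarTooAmbiguous`, `enforce_stack_cap`); the limit of 1024 is the figure mentioned in this comment, not a value taken from llama.cpp.

```python
from typing import List, Sequence

MAX_STACKS = 1024  # figure suggested in the discussion, not a llama.cpp constant


class GrammarTooAmbiguous(RuntimeError):
    """Raised when a grammar produces more parser stacks than the cap allows."""


def enforce_stack_cap(stacks: Sequence[List[int]], limit: int = MAX_STACKS) -> None:
    """Fail early with a readable message instead of exhausting memory."""
    if len(stacks) > limit:
        raise GrammarTooAmbiguous(
            f"grammar expanded to {len(stacks)} stacks (limit {limit}); "
            "consider rewriting ambiguous or repeated optional rules"
        )
```

A check of this kind would run after each accepted character, so a combinatorial grammar is reported to the user rather than slowing down or crashing the process.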
Yeah, I think that could be pretty reasonable. I'm also wondering about
Makes me wonder if there is a better way to do this (more akin to what you wrote above) but I'm still relatively new to grammar engines. I'm learning a lot about grammar parsing as part of this exercise -- I never read the dragon book, but I'm wondering if I should order myself a copy as part of this work. :)
Similar to what @ochafik is saying, I wonder if a good first step would be to add an optimization step where, after expanding grammars, adjacent optional tokens are always collapsed into something more efficient. I.E., |
As a user of the I'm building out an agent flow, where I call the LLM an extra time before each user message to extract any function/tool call, detect the intent, etc, which is done using a json schema like And this extra inference step takes up to 5sec on my 3090, even though it's hardly generating any tokens (in fact, I limit it to 20), presumably because of the grammar. After the function calling stage, the actual agent response from the LLM is very speedy, usually sub-1-second depending on length. So it's taking 5x longer to generate only a few tokens for function calling, compared to actually writing out a long response message. I previously used TabbyAPI for this, and it handled the grammar extremely fast - sub-200ms usually, compared to 5sec in llama.cpp |
@andysalerno Would you be able to share a real-world grammar to test one of the fixes that are in flight? (e.g. #7424 ) |
sure, I am using the
{
"title": "AnswerFormat",
"type": "object",
"properties": {
"last_user_message_intent": {
"type": "string"
},
"function_name": {
"type": "string"
},
"invocation": {
"type": "string"
}
},
"required": [
"last_user_message_intent",
"function_name",
"invocation"
]
}
I can respond shortly with some measurements showing how long this takes compared to a non-grammar generation with similar length. |
here is the verbose
"timings": {
"prompt_n": 592,
"prompt_ms": 708.655,
"prompt_per_token_ms": 1.1970523648648648,
"prompt_per_second": 835.3853426561585,
"predicted_n": 37,
"predicted_ms": 4103.24,
"predicted_per_token_ms": 110.89837837837837,
"predicted_per_second": 9.017264405689163
}
note how it took 4sec to predict 37 tokens, and the and here's the
"timings": {
"prompt_n": 140,
"prompt_ms": 280.92,
"prompt_per_token_ms": 2.0065714285714287,
"prompt_per_second": 498.3625231382599,
"predicted_n": 11,
"predicted_ms": 292.259,
"predicted_per_token_ms": 26.569000000000003,
"predicted_per_second": 37.637848620572846
}
Note it only took ~300ms for 11 tokens, with a |
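To put the two timing blocks side by side, the per-token cost can be recomputed from `predicted_ms` and `predicted_n`. This small sketch assumes the first block is the grammar-constrained run and the second the unconstrained one, as the surrounding comment suggests; the variable names are made up for illustration.

```python
# Values copied from the two "timings" blocks above.
with_grammar = {"predicted_n": 37, "predicted_ms": 4103.24}
without_grammar = {"predicted_n": 11, "predicted_ms": 292.259}


def ms_per_token(timings: dict) -> float:
    """Average generation time per predicted token, in milliseconds."""
    return timings["predicted_ms"] / timings["predicted_n"]


slowdown = ms_per_token(with_grammar) / ms_per_token(without_grammar)

if __name__ == "__main__":
    print(f"with grammar:    {ms_per_token(with_grammar):.1f} ms/token")
    print(f"without grammar: {ms_per_token(without_grammar):.1f} ms/token")
    print(f"slowdown:        {slowdown:.2f}x")
```

Under that assumption the constrained run costs about 110.9 ms per token against about 26.6 ms, a slowdown of roughly 4.2x per token. The two runs also differ in prompt length, so this is a rough comparison rather than a controlled benchmark.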
@andysalerno Thanks! Added your grammar to these benchmarks; it seems it will benefit from both #7424 (
Show timings
# Clone the fork that carries the experimental grammar branches
git clone https://github.com/ochafik/llama.cpp llama.cpp-ochafik
cd llama.cpp-ochafik
# Build and time each branch with the same schema, prompt, and seed
for branch in master grammar-fast grammar-resampling grammar-fast-resampling ; do
  echo "#"
  echo "# $branch"
  echo "#"
  git checkout "$branch"
  make clean
  make -j LLAMA_CURL=1 main
  ./main \
    -mu https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf \
    -j '{"title": "AnswerFormat", "type": "object", "properties": {"last_user_message_intent": {"type": "string" }, "function_name": {"type": "string" }, "invocation": {"type": "string" }}, "required": [ "last_user_message_intent", "function_name", "invocation"]}' \
    -p "Describe a function call of a tool in JSON format after a reminder of the last user message intent." \
    --seed 12345 --no-display-prompt 2>&1 | \
    grep "llama_print_timings"
done
|
Example of slow grammar :
|
…anov#9833 Grammar memo Co-Authored-By: Clarissa Miranda <80654285+clarissamiranda@users.noreply.github.com>
ggerganov#9833" This reverts commit 4cbf5c392af62252a69e17143e8a81d771ca6f8a.
This issue was closed because it has been inactive for 14 days since being marked as stale. |
Is it possible to reopen this? This issue needs fixing. |
To help me understand the particulars, how big of a concern / obstacle is this for you and your workflows? I haven't worked on this problem in a while, but I would like to get back to it at some point, and I keep mulling the issue over in my mind. Right now, the biggest obstacle that I see is that the number of stack permutations in the case of ambiguous grammars explodes so high in some cases, that it's very difficult to manage. Some options that we could consider:
I'm very outside of my comfort zone of expertise, and I don't know how viable / attractive any of these paths are, but I'd love to hear others' thoughts on it. |
I just did a very draft PR #10224 that integrates an alternative grammar library. You can give it a spin and see if it helps! |
There have been a few reports where the grammar sampling can significantly degrade the performance.
It would be nice to profile and optimize the implementation - there should be room for improvements.
Already on-going efforts:
reserve space in decode_utf8 #4210
llama_token_to_piece when sampling grammars #4213
Probably worth looking into multi-threading the implementation as well.