llama : speed-up grammar sampling #4218

Open
ggerganov opened this issue Nov 25, 2023 · 40 comments
Labels: performance, refactoring

Comments

@ggerganov
Owner

There have been a few reports that grammar sampling can significantly degrade performance.
It would be nice to profile and optimize the implementation; there should be room for improvement.

Already on-going efforts:

It is probably worth looking into multi-threading the implementation as well.

@ggerganov ggerganov added performance Speed related topics refactoring Refactoring labels Nov 25, 2023
@ExtReMLapin
Contributor

ExtReMLapin commented Nov 27, 2023

#3980 and this suggestion might also help a bit #3980 (comment)

I would have expected the compiler to optimize it straight away 🤷🏻

@gottlike

Would an integration of Outlines help? Like they are doing with vLLM: dottxt-ai/outlines#163

@ggerganov
Owner Author

@ExtReMLapin This copy is used only in the speculative example. Even if it helps there, it won't have any effect on the general use case. Still, a PR is welcome

@gottlike An efficient low-level solution such as the one we currently have seems like a better approach to me.

@shroominic

I noticed that inference at some point gets exponentially slower when the grammar is deeply nested but open. By open I mean that it allows a lot of different possibilities. As an example, I am working on PydanticModel -> JsonSchema -> Grammar, and when the model outputs a list of nested subobjects, this effect appears once the list is long, and at some point generation gets stuck.

@ExtReMLapin
Contributor

@shroominic On my end it just gets slower the longer it spends printing the JSON array, even with no nested objects.

@AlienKevin
Contributor

I found a similar exponential slowdown to the one mentioned by @shroominic in my use case, which is generating code in a language similar to OCaml. Generation was very fast for the first 200 tokens but slowed to more than 400 seconds per token as I approached 300 tokens.

I plotted the grammar stack size and duration per token over time and found stack size to be the main factor in the slowdown. The number of grammar stacks can go up to 800K for my grammar. I'm not very familiar with the grammar sampling algorithm used in llama.cpp, but I suspect it's exponential in the length of the parsed string. Polynomially bounded parsing algorithms like the Earley parser might help avoid the exponential blowup.

Grammar
root ::= [ \t\n]* exp

ws ::= [ \t\n]+
w ::= [ \t]*

comment ::= "#" [^#]* "#" [ \t]+ [\n]? [ \t]*

### Expressions

exp ::= comment* sequence-exp

sequence-exp ::= tuple-exp (w ";" ws tuple-exp)*

tuple-exp ::= cons-exp (w "," ws cons-exp)*

cons-exp ::= binary-exp (w "::" w binary-exp)*

binary-exp ::= unary-exp (ws binary-op ws unary-exp)*

unary-exp ::= unary-op* function-app-exp

function-app-exp ::= primary-exp (w "(" w exp w ")" w)*

primary-exp ::= bool |
    integer |
    float |
    string |
    variable |
    "()" |
    "[]" |
    constructor |
    constructor-app |
    parenthesized-exp |
    list-exp |
    let-exp |
    if-exp |
    case-exp |
    test-exp |
    type-alias |
    fun

constructor-app ::= constructor "(" w exp w ")"
parenthesized-exp ::= "(" w exp w ")"
list-exp ::= "[" exp ("," ws exp)* "]"
let-exp ::= "let" ws pat ws "=" ws exp ws "in" ws exp
if-exp ::= "if" ws exp ws "then" ws exp ws "else" ws exp
case-exp ::= "case" ws exp (ws "|" ws pat ws "=>" ws exp)+ ws "end"
test-exp ::= "test" ws exp ws "end"
type-alias ::= "type" ws constructor ws "=" ws typ ws "in" ws exp
fun ::= "fun" ws pat ws "->" ws exp

type-variable ::= [a-z][A-Za-z0-9_]*
constructor ::= [A-Z][A-Za-z0-9_]*
variable ::= ([_a-bdg-hj-kn-qu-z][A-Za-z0-9_.]*)|(("s" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("st" ([.0-9A-Z_a-qs-z][A-Za-z0-9_.]*)?)|("str" ([.0-9A-Z_a-tv-z][A-Za-z0-9_.]*)?)|("stru" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("struc" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("struct" [A-Za-z0-9_.]+)|("c" ([.0-9A-Z_b-z][A-Za-z0-9_.]*)?)|("ca" ([.0-9A-Z_a-rt-z][A-Za-z0-9_.]*)?)|("cas" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("case" [A-Za-z0-9_.]+)|("i" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("in" [A-Za-z0-9_.]+)|("r" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("re" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("rec" [A-Za-z0-9_.]+)|("t" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("te" ([.0-9A-Z_a-rt-z][A-Za-z0-9_.]*)?)|("tes" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("test" [A-Za-z0-9_.]+)|("l" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("le" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("let" [A-Za-z0-9_.]+)|("m" ([.0-9A-Z_b-z][A-Za-z0-9_.]*)?)|("ma" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("mat" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("matc" ([.0-9A-Z_a-gi-z][A-Za-z0-9_.]*)?)|("match" [A-Za-z0-9_.]+)|("f" ([.0-9A-Z_a-tv-z][A-Za-z0-9_.]*)?)|("fu" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("fun" [A-Za-z0-9_.]+)|("e" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("en" ([.0-9A-Z_a-ce-z][A-Za-z0-9_.]*)?)|("end" [A-Za-z0-9_.]+))
bool ::= "true" | "false"
integer ::= [0-9]+
float ::= [0-9]* "." [0-9]+
string ::= "\"" [^"]* "\""

unary-op ::= "-" | "!"
binary-op-int ::= "+" | "-" | "*" | "/" | "<" | ">" | "<=" | ">=" | "==" | "!="
binary-op-float ::= "+." | "-." | "*." | "/." | "<." | ">." | "<=." | ">=." | "==." | "!=."
binary-op-string ::= "$==" | "@"
binary-op-logic ::= "&&"
binary-op ::= binary-op-int | binary-op-float | binary-op-string | binary-op-logic

### Patterns

pat ::= type-ascription-pat

type-ascription-pat ::= tuple-pat (w ":" ws typ)*

tuple-pat ::= cons-pat (w "," ws cons-pat)*

cons-pat ::= primary-pat (w "::" w primary-pat)*

primary-pat ::=
    bool |
    integer |
    float |
    string |
    variable |
    "()" |
    "[]" |
    "_" |
    constructor |
    constructor-app-pat |
    parenthesized-pat |
    list-pat

constructor-app-pat ::= constructor "(" w pat w ")"
parenthesized-pat ::= "(" w pat w ")"
list-pat ::= "[" pat (w "," ws pat)* "]"

### Types

typ ::= arrow-typ

arrow-typ ::= tuple-typ (ws "->" ws tuple-typ)*

tuple-typ ::= primary-typ (w "," ws primary-typ)*

primary-typ ::=
    "Unit" |
    "Int" |
    "Float" |
    "Bool" |
    "String" |
    type-variable |
    constructor |
    constructor-def (ws "+" ws constructor-def)+ |
    parenthesized-typ |
    list-typ

parenthesized-typ ::= "(" w typ w ")"
list-typ ::= "[" w typ w "]"
constructor-def ::= constructor | constructor "(" w typ w ")"
Prompt
### Option ###

# Represent values that may or may not exist. #
type Option =
  + Some(?)
  + None
in

# Compare if two Options are equal #
# equal: ((?, ?) -> Bool) -> (Option, Option) -> Bool #
let equal: ((?, ?) -> Bool) -> (Option, Option) -> Bool =
    fun eq, os ->
        case os
        | Some(x), Some(y) => eq(x, y)
        | None, None => True
        | _, _ => False
        end
in

### Result ###

# A Result is either Ok meaning the computation succeeded, #
# or it is an Err meaning that there was some failure. #
type Result =
  + Ok(a)
  + Err(b)
in

# Compare if two Results are equal #
# equal: ((a, a) -> Bool) -> (Result, Result) -> Bool #
let equal: ((a, a) -> Bool) -> (Result, Result) -> Bool =
    fun eq, rs ->
        case rs
        | Ok(e1), Ok(e2) => eq(e1, e2)
        | Error(e1), Error(e2) => e1 $== e2
        | _ => false
        end
in


### JSON ###
# This module helps you convert between Hazel values and JSON values. #

# A JSON value type #
type Value =
  + Object([(String, Value)])
  + Array([Value])
  + Str(String)
  + Number(Float)
  + Boolean(Bool)
  + Null 
in

# Check if two JSON values are equal #
# equal : (Value,Value) -> Bool #
let equal : (Value,Value) -> Bool =
fun a, b ->
    case (a, b)
    | Object(o1), Object(o2) => List.equal(
        fun (s1, v1), (s2, v2) ->
            s1 $== s2 && equal(v1, v2), o1, o2)
    | Array(a1), Array(a2) => List.equal(equal, a1, a2)
    | Str(s1), Str(s2) => s1 $== s2
    | Number(n1), Number(n2) => n1 ==. n2
    | Boolean(b1), Boolean(b2) => if b1 then b2 else !b2
    | Null, Null => true
    | _ => false
  end 
in

# JSON Encoder #

# Convert a string to a JSON string #
# value_of_string : String -> Value #
let value_of_string : String -> Value =
    fun s -> Str(s) 
in

# Convert an integer to a JSON integer #
# value_of_int : Int -> Value #
let value_of_int : Int -> Value =
    fun i -> Number(float_of_int(i)) 
in

# Convert a float to a JSON float #
# value_of_float : Float -> Value #
let value_of_float : Float -> Value =
    fun f -> Number(f) 
in

# Convert a boolean to a JSON boolean #
# value_of_bool : Bool -> Value #
let value_of_bool : Bool -> Value =
    fun b -> if b then Boolean(true) else Boolean(false)
in

# Convert a null to a JSON null #
# value_of_null : Value #
let value_of_null : Value = Null in

# Convert a list of JSON values to a JSON array #
# value_of_list : (a -> Value, [a]) -> Value #
let value_of_list : (a -> Value, [a]) -> Value =
  fun (func, entries) ->
    Array(
      List.rev(List.fold_left(
        fun l, e-> func(e)::l, [], entries)))
in

# Convert a dictionary of JSON values to a JSON object #
# value_of_object : [(String, Value)] -> Value #
let value_of_object : [(String, Value)] -> Value =
    fun entries -> Object(entries)
in

# JSON Decoder #
# A Decoder decodes a JSON value into a Hazel value, or return an Err on failure. #
type Decoder = Value -> Result in

# Decodes a JSON string into a string #
# string_of_value : Decoder #
let string_of_value : Decoder =
    fun v ->
        case v
        | Str(s) => Ok(s)
        | _ => Err("Cannot unpack value as a String")
        end
in

# Decodes a JSON boolean into a boolean #
# bool_of_value : Decoder #
let bool_of_value : Decoder =
  fun v ->
    case v
    | Boolean(b) => Ok(b)
    | _ => Err("Cannot unpack value as a Bool")
    end
in

# Decodes a JSON integer into an integer #
# int_of_value : Decoder #
let int_of_value : Decoder =
    fun v ->
        case v
        | Number(n) =>
            if floor(n) ==. n then
                # n is a whole number #
                Ok(floor(n))
            else
                # n is a floating point #
                Err("Cannot unpack a float value as an Int")
        | _ => Err("Cannot unpack value as an Int") 
        end
in

# Decodes a JSON float into a float #
# float_of_value : Decoder #
let float_of_value : Decoder =
fun v ->
    case v
    | Number(n) => Ok(floor(n))
    | _ => Err("Cannot unpack value as a Float")
    end
in

# Decodes a JSON null into a null #
# null_of_value : Decoder #
let null_of_value : Decoder =
  fun v ->
    case v
    | Null => Ok(None)
    | _ => Err("Cannot unpack value as a None")
    end
in

# Parsers #
# Try a bunch of different decoders. #
# This can be useful if the JSON may come in a couple different formats. #
# one_of : [Decoder] -> Decoder #
let one_of : [Decoder] -> Decoder =
    fun decoders -> fun v ->
        case decoders
        | decoder::decoders =>
            result_map_err(fun _ -> one_of(decoders)(v), decoder(v))
        | [] => Err("one_of failed to decode value")
        end
    in

# Transform a decoder. #
# map : ((a -> b), Decoder) -> Decoder #
let map : ((a -> b), Decoder) -> Decoder =
    fun (func, decoder) -> fun v ->
        case decoder(v)
        | Err(e) => Err(e)
        | Ok(o) => func(o)
in

# Create decoders that depend on previous results. #
# and_then: ((a -> Decoder), Decoder) -> Decoder #
let and_then: ((a -> Decoder), Decoder) -> Decoder =
    fun (func, decoder) ->
        fun v ->
            case decoder(v) 
            | Err(e) => Err(e)
            | Ok(o)=> func(o)(v)
            end
in

# Decode a nullable JSON value into a Hazel value. #
# nullable : Decoder -> Decoder #
let nullable : Decoder -> Decoder =
    fun decoder ->
        one_of([
            map(fun s -> Some(s), decoder),
            null_of_value
        ])
in

# Decode a JSON array into a Hazel List. #
# list : Decoder -> Decoder #
let list : Decoder -> Decoder =
    fun elem_decoder ->
    fun v ->
        case v 
        | Array(arr) => 
    case arr 
    | head::tail =>
    case elem_decoder(head) 
    | Ok(hd) => map(fun tl -> hd::tl, list(elem_decoder))(Array(tail))
    | Err(e) => Err(e)
        end
        | [] => Ok([])
        end
    | _ => Err("Cannot unpack value as a List")
    end
in

# Decode a JSON object into a Hazel dictionary. #
# For now, a dictionary is just a list of key-value pairs #
# dict : Decoder -> Decoder #
let dict : Decoder -> Decoder =
  fun value_decoder ->
    fun v ->
        case v 
        | Object(pairs) =>
            case pairs
            | (key, value)::tail =>
                case value_decoder(value) 
                | Ok(hd)=> map(fun tl -> (key, hd)::tl, dict(value_decoder))(Object(tail))
                | Err(e) => Err(e)
                end
            | [] => Ok([])
            end
        | _ => Err("Cannot unpack value as a dict")
        end
in


### List ###

# Add an element to the front of a list. #
# cons: (a, [a]) -> [a] #
let cons: (a, [a]) -> [a] = fun x, xs -> x::xs in

# Determine the length of a list. #
# length: [a] -> Int #
let length: [a] -> Int =
  fun xs ->
    case xs
    | [] => 0
    | _::xs => 1 + length(xs) end in

# Extract the first element of a list. #
# hd: [a] -> Option #
let hd: [a] -> Option =
  fun l ->
    case l
    | [] => None
    | x::xs => Some(x) end in

# Extract the rest of the list. #
# tl: [a] -> [a] #
let tl: [a] -> [a] =
  fun l ->
    case l
    | [] => []
    | x::xs => xs end in

# Determine if a list is empty. #
# is_empty: [a] -> Bool #
let is_empty: [a] -> Bool =
  fun xs ->
    case xs
    | [] => true
    | _::_ => false end in

# Return the element at the index. #
# nth: ([a], Int) -> Option #
let nth: ([a], Int) -> Option =
  fun xs, n ->
    case xs, n
    | x::_, 0 => Some(x)
    | _::xs, n => nth(xs, n - 1)
    | [], _ => None end in

# Reverse a List. #
# rev: [a] -> [a] #
let rev: [a] -> [a] =
fun l -> 
let go: ([a], [a]) -> [a] =
  fun xs, acc -> 
    case xs 
    | [] => acc 
    | x::xs => go(xs, x::acc) end in
go(l, []) in

# Check if two lists are equal #
# equal: ((a, a) -> Bool, [a], [a]) -> Bool #
let equal: ((a, a) -> Bool, [a], [a]) -> Bool =
    fun p, xs, ys ->
    case xs, ys
    | [], [] => true
    | x::xs, y::ys => p(x, y) && equal(p, xs, ys)
    | _ => false end
in

# Initialize a list with a given length using an initializer function #
# init: (Int, Int -> a) -> [a] #
let init: (Int, Int -> a) -> [a] =
    fun len, f ->
        let go: (Int, [a]) -> [a] =
        fun idx, xs ->
            if idx < len 
            then go(idx + 1, xs @ [f(idx)])   
            else xs
        in
        go(0, [])
in

# Reduce a list from the left. #
# fold_left: ((b, a) -> b, b, [a]) -> b #
let fold_left: ((b, a) -> b, b, [a]) -> b =
  fun f, acc, xs ->
    case xs 
    | [] => acc
    | hd::tl => fold_left(f, f(acc, hd), tl) end in

# Reduce a list from the right. #
# fold_right: ((a, b) -> b, [a], b) -> b #
let fold_right: ((a, b) -> b, [a], b) -> b =
  fun f, xs, acc ->
    case xs
    | [] => acc
    | hd::tl => f(hd, fold_right(f, tl, acc)) end in

# A simplified lambda calculus expression containing variables, lambdas, and applications #
type Exp =
  + Var(String)
  + Lam(String, Exp)
  + Ap(Exp, Exp)
in
# Evaluation can result in either an Exp or an Error #
# Evaluation by substitution #
# eval: Exp -> Result #
Command
./main \
    --grammar-file grammar.gbnf \
    -t 10 \
    -ngl 64 \
    -b 512 \
    -m ../models/codellama-34b.Q5_K_M.gguf \
    --color -c 3400 \
    --temp 0.7 \
    --repeat_penalty 1.1 \
    -n -1 \
    -f prompt.txt

[Figure: exp_eval_grammar_stack_size (grammar stack size and duration per token over time)]

@kalomaze
Contributor

kalomaze commented Dec 3, 2023

After looking into the code, I think there's a seemingly obvious and much simpler way to optimize grammar sampling, even without threading.

Right now, it manually checks all token candidates and removes any that would violate the grammar.

It would be more effective and simpler to sample the normal way and check whether the chosen token violates the grammar before proceeding with it. Only if it does should sampling revert to the current behavior that 'forces' the grammar.
Right now, it 'forces' the sample set to always match before picking. This is resource-intensive for large-vocabulary models and largely unnecessary, as the model will naturally follow the grammar most of the time with typical sampler settings (especially with Min P / low temp), so the new behavior would only need to run the full grammar calculations some of the time.

@ejones Any suggestions for how I would go about implementing a solution?

@kalomaze
Contributor

kalomaze commented Dec 3, 2023

I'm doing some investigation. I think the easiest way to do this without refactors to the grammar itself is by running a check to the existing grammar function with only the single candidate in sampling.cpp; if it's correct, we proceed. If it's wrong, we restart sampling, this time running:

    if (ctx_sampling->grammar != NULL) {
        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
    }

before the repetition penalty or any other modifications are made to the logits.
I plan to achieve this by making a copy of the initial logits for the "2nd pass".
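
The two-pass idea described above can be sketched roughly as follows. This is a toy illustration in Python, not the llama.cpp implementation: `sample`, `sample_with_grammar`, and the `grammar_accepts` callback are hypothetical names, and the token table stands in for the real logits after the other samplers have run.

```python
import random
from typing import Callable, Dict, Tuple


def sample(probs: Dict[str, float], rng: random.Random) -> str:
    """Draw one token from a {token: weight} table."""
    tokens = list(probs)
    weights = [probs[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


def sample_with_grammar(
    probs: Dict[str, float],
    grammar_accepts: Callable[[str], bool],  # hypothetical grammar check
    rng: random.Random,
) -> Tuple[str, bool]:
    """Return (token, used_fallback)."""
    # First pass: sample as if there were no grammar, then check one token.
    token = sample(probs, rng)
    if grammar_accepts(token):
        return token, False
    # Second pass: filter the whole candidate set, then sample again.
    allowed = {t: p for t, p in probs.items() if grammar_accepts(t)}
    if not allowed:
        raise ValueError("grammar rejects every candidate token")
    return sample(allowed, rng), True
```

When the model already follows the grammar, the first pass costs a single grammar check instead of one per vocabulary entry. In this simplified setting the final token distribution is the same as filtering first; whether that also holds once other samplers sit between the two passes is not something this sketch shows.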

@kalomaze
Contributor

kalomaze commented Dec 3, 2023

#4306

I have made a pull request which should reduce the number of checks necessary to 1 for most tokens instead of all 32,000 tokens in the vocabulary. I have not evaluated whether or not it is actually faster yet, but I'm guessing that avoiding thousands of UTF8 decoding steps for most tokens would improve performance.

@ejones
Collaborator

ejones commented Dec 6, 2023

@AlienKevin thanks for investigating! Yeah it's a simple top-down backtracking parser so it can be exponential in the worst case. It works best for grammars with little or no ambiguity or nondeterminism. A deterministic grammar should maintain a constant stack count. This isn't obvious though and we could probably do a better job signaling this to the user.

@kalomaze looks great, commented on your PR as well.

@txbm

txbm commented Feb 15, 2024

Grammar processing appears to be quite slow (again?): #4306 (comment)

@ExtReMLapin
Contributor

No issue on my end

@txbm

txbm commented Feb 15, 2024

I've noticed it varies widely with respect to prompt complexity. My JSON schema -> grammar contains three levels of object-arrays and if I ask for a shorter output it completes reasonably quickly with the conforming schema and runs at a consistently high level of CPU utilization.

But if I ask for an output that is about ten times longer, for the exact same schema, I notice the resource utilization (CPU mainly) becomes highly variable and rarely sustains max utilization. The overall inference time gets long enough that it's not worth waiting for the task to complete (30+ minutes) whereas in contrast the exact same prompt will run for about 7 minutes consistently with the grammar/schema removed. If I need to just post a full repro I'm happy to link a Gist.

@github-actions github-actions bot added the stale label Mar 19, 2024
Contributor

github-actions bot commented Apr 3, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

@github-actions github-actions bot closed this as completed Apr 3, 2024
@ggerganov ggerganov removed the stale label Apr 4, 2024
@ggerganov ggerganov reopened this Apr 4, 2024
@txbm

txbm commented Apr 4, 2024

Just to close the loop on my previous comment: I continued experimenting with this feature on a wide variety of cases and ultimately concluded that the performance variance is too large for production use, even on fast GPUs such as A100s.

I would, however, very much like to see this feature perform consistently enough for production as it is otherwise very useful. I am happy to help with testing or reproduction of test cases if anyone decides to work on this.

@HanClinto
Collaborator

HanClinto commented Apr 4, 2024

I've been digging into this lately, and I've been using the integration tests in #6472 to do some crude performance profiling.

I've definitely seen the sort of dramatic stack expansion that @AlienKevin is talking about. I think there are many causes, but one that I've been digging into is how easy it is for alternate grammars to dramatically inflate the stack. For instance, imagine a grammar that says:

root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? [0-9])*

This is a rather extreme example, but hopefully it illustrates the point.

If you have an input string that looks like "1a2a3a4a5", then at the first character, there is only 1 stack.
But it doesn't know whether the first 'a' matches the first "a"? in the grammar or one of the later ones, so it needs to track all 5 possibilities. We now have 7 stacks, which then grows to 15, 35, 75, etc.

Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 7
Parsing character 2 ('2'), stack size 15
Parsing character 3 ('a'), stack size 35
Parsing character 4 ('3'), stack size 75
Parsing character 5 ('a'), stack size 175
Parsing character 6 ('4'), stack size 375
Parsing character 7 ('a'), stack size 875
Parsing character 8 ('5'), stack size 1875

If we change our input string to "1aa2aa3aa4aa5", then it gets even worse, because it permutes a bit:

Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 7
Parsing character 2 ('a'), stack size 15
Parsing character 3 ('2'), stack size 20
Parsing character 4 ('a'), stack size 70
Parsing character 5 ('a'), stack size 150
Parsing character 6 ('3'), stack size 200
Parsing character 7 ('a'), stack size 700
Parsing character 8 ('a'), stack size 1500
Parsing character 9 ('4'), stack size 2000
Parsing character 10 ('a'), stack size 7000
Parsing character 11 ('a'), stack size 15000
Parsing character 12 ('5'), stack size 20000

On my laptop (MacBook M1, 32 GB RAM), this noticeably lags the machine, and it hitches while executing even this (relatively) short grammar and validation string. I've done tests with much larger grammars that push MUCH larger stack sizes, and the ambiguity can really drag things down -- even to the point of memory exhaustion.

Currently I'm tackling some left-recursion issues that can cause segfaults ( #6492 ), but after I get done with that (or give up), I'm going to tackle these performance issues related to ambiguity. I'm not entirely sure, but I think that there should be a viable algorithm to prune near-redundant stacks if their outcomes would be equivalent. I.e., if we've got four potential "a"'s to match, and only one "a", then it doesn't matter which one we choose -- the others can be discarded once we get to the next token, so long as they all converge onto the same spot in the same rule.
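
To illustrate the pruning idea, here is a small Python model of the example grammar `root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? [0-9])*`. It is a deliberately simplified sketch: a parse thread is reduced to a single integer (the next unmatched slot in the repeated group), which is an assumption made for this example and not how llama.cpp represents stacks, so the counts do not reproduce the stack sizes printed above. The names `step` and `thread_counts` are hypothetical.

```python
from typing import Iterable, List

N_OPTIONAL = 5  # number of optional "a" slots in the repeated group


def step(positions: Iterable[int], ch: str) -> List[int]:
    """Advance every parse thread by one character.

    A position p in 0..N_OPTIONAL means "the next unmatched slot is p";
    slots 0..4 are the optional "a"s and slot 5 is the mandatory digit.
    """
    out: List[int] = []
    for p in positions:
        if ch == "a":
            # The "a" may be matched by any remaining optional slot q >= p.
            out.extend(q + 1 for q in range(p, N_OPTIONAL))
        elif ch.isdigit():
            # Skip the remaining optionals, match the digit, restart the group.
            out.append(0)
        # Any other character kills the thread.
    return out


def thread_counts(text: str, dedupe: bool) -> List[int]:
    """Thread count after each character, with or without merging duplicates."""
    positions = [0]  # state after the leading [0-9]
    counts = [len(positions)]
    for ch in text[1:]:
        positions = step(positions, ch)
        if dedupe:
            # Threads at the same position behave identically from here on.
            positions = sorted(set(positions))
        counts.append(len(positions))
    return counts
```

In this model, `thread_counts("1a2a3a4a5", dedupe=False)` grows by a factor of five at every "a", while `dedupe=True` never tracks more than five threads, because threads that converge on the same position are merged.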

This isn't the only performance-related issue that grammars are seeing, but I believe that this massive exponential growth in the stack size is one of the biggest opportunities for optimization.

Anyways, that's where I'm at on things -- will keep y'all posted.

@HanClinto
Collaborator

HanClinto commented Apr 4, 2024

I'm not very familiar with the current setup of our CI performance profilers -- if I were to make improvements to the grammar engine, would those speed improvements show up in our current bank of benchmarks?

@ggerganov
Owner Author

> if I were to make improvements to the grammar engine, would those speed improvements show up in our current bank of benchmarks?

We don't have benchmarks for this yet. You will have to do some custom profiling to determine how the performance changes. With time, we can attempt to add some sort of speed information about the grammar to the CI or at least some local tools.

@ochafik
Collaborator

ochafik commented Apr 8, 2024

> how easy it is for alternate grammars to dramatically inflate the stack. For instance, imagine a grammar that says:
> root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? [0-9])*

@HanClinto One possibly big source of such explosive repetitions is JSON grammars w/ minItems/maxItems (or w/ JSON string regexp patterns such as {"type": "string", "pattern": "a{3,10}"}). The easy workaround rn is to rewrite the rule as:

root ::= [0-9] (("a" ("a" ("a" ("a" ("a")?)?)?)?)? [0-9])*

I'm working on fixing the JSON grammar conversion to do this (e.g. 375f85d), hope to send a PR soon. I'll probably update the GBNF doc w/ performance caveats in the same PR.

@HanClinto
Collaborator

@ochafik That indeed is a massive improvement! Testing your grammar against 1a2a3a4a5 gives:

Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 3
Parsing character 2 ('2'), stack size 2
Parsing character 3 ('a'), stack size 3
Parsing character 4 ('3'), stack size 2
Parsing character 5 ('a'), stack size 3
Parsing character 6 ('4'), stack size 2
Parsing character 7 ('a'), stack size 3
Parsing character 8 ('5'), stack size 2

And testing against the worse case of 1aa2aa3aa4aa5 gives:

Parsing character 0 ('1'), stack size 1
Parsing character 1 ('a'), stack size 3
Parsing character 2 ('a'), stack size 2
Parsing character 3 ('2'), stack size 2
Parsing character 4 ('a'), stack size 3
Parsing character 5 ('a'), stack size 2
Parsing character 6 ('3'), stack size 2
Parsing character 7 ('a'), stack size 3
Parsing character 8 ('a'), stack size 2
Parsing character 9 ('4'), stack size 2
Parsing character 10 ('a'), stack size 3
Parsing character 11 ('a'), stack size 2
Parsing character 12 ('5'), stack size 2

Indeed, that's a massive savings. Nice speedup!

It would still be nice to find ways to speed up the grammar tree navigation even with more poorly written grammars, but improving the quality of the grammars in this way is a huge help.

@ochafik
Collaborator

ochafik commented Apr 8, 2024

> ways to speed up the grammar tree navigation even with more poorly written grammars

@HanClinto I'd be inclined to detect some easily rewritable grammar cases on the fly and explode when the grammar becomes too combinatorial (w/ a link to a "performance tips" section of the GBNF wiki page), either with a cap on stack size or some builtin timeout maybe?

Maybe also some new features like numbered repetition operators "a"{,5} (desugared as the grammar above) and maybe other regexp derived syntax features (reluctant/eager modifiers?) could make it easier to write efficient grammars.

@o1lo01ol1o

o1lo01ol1o commented Apr 8, 2024

Fwiw, I've wanted bounded repetitions a few times with this grammar; recursive whitespace sometimes lets the model spin forever.

It would also be great to see some stats on the grammar as well, either after running or after desugaring so that we can optimize grammars.

@HanClinto
Collaborator

> @HanClinto I'd be inclined to detect some easily rewritable grammar cases on the fly and explode when the grammar becomes too combinatorial (w/ a link to a "performance tips" section of the GBNF wiki page), either with a cap on stack size or some builtin timeout maybe?

This sounds really good. Detecting and gracefully exiting is the first step. Worst case is that things explode in an infinite loop (as is the case with left-recursion, as in #6492), and implementing a max stack size of 1024 or something should at least give a reasonable approach.
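
A cap like this could be as simple as a guard that runs each time the set of stacks is updated. The following is only a sketch in Python with hypothetical names (`GrammarTooAmbiguousError`, `enforce_stack_budget`); the limit of 1024 is the figure floated above, not a tuned value.

```python
from typing import Sequence


class GrammarTooAmbiguousError(RuntimeError):
    """Raised when the number of live parse stacks exceeds the budget."""


def enforce_stack_budget(stacks: Sequence[object], max_stacks: int = 1024) -> None:
    """Fail fast instead of letting the stack count grow without bound."""
    if len(stacks) > max_stacks:
        raise GrammarTooAmbiguousError(
            f"{len(stacks)} parse stacks exceed the limit of {max_stacks}; "
            "the grammar is probably too ambiguous"
        )
```

Raising a dedicated error gives the caller a place to attach a pointer to performance tips rather than hanging or running out of memory.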

> Maybe also some new features like numbered repetition operators "a"{,5} (desugared as the grammar above) and maybe other regexp derived syntax features (reluctant/eager modifiers?) could make it easier to write efficient grammars.

Yeah, I think that could be pretty reasonable. I'm also wondering about parse_sequence() in grammar-parser.cpp, and this section of rewriting +*? operators:

// apply transformation to previous symbol (last_sym_start to end) according to
// rewrite rules:
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |

Makes me wonder if there is a better way to do this (more akin to what you wrote above) but I'm still relatively new to grammar engines.
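
For readers unfamiliar with these rewrite rules, the three cases quoted above can be written out as a tiny Python function. This is only a sketch of the transformation as stated in that comment, with hypothetical names (`desugar`, `fresh`); it is not the code in grammar-parser.cpp.

```python
def desugar(symbol: str, op: str, fresh: str) -> str:
    """Return the helper rule that replaces `symbol` followed by `op`.

    `fresh` is the name of the newly introduced rule (S' in the comment).
    An empty alternate means "match nothing".
    """
    if op == "*":
        alternates = [f"{symbol} {fresh}", ""]
    elif op == "+":
        alternates = [f"{symbol} {fresh}", symbol]
    elif op == "?":
        alternates = [symbol, ""]
    else:
        raise ValueError(f"unsupported operator: {op!r}")
    return (f"{fresh} ::= " + " | ".join(alternates)).rstrip()


for op in "*+?":
    print(desugar("S", op, "S'"))
```

Each `?` becomes its own rule with an empty alternate, which is why a run of adjacent optionals gives the parser several independent choice points for the same input character.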

I'm learning a lot about grammar parsing as part of this exercise -- I never read the dragon book, but I'm wondering if I should order myself a copy as part of this work. :)

> Fwiw, I've wanted bounded repetitions a few times with this grammar; recursive whitespace sometimes lets the model spin forever.

@o1lo01ol1o Yeah -- poor handling of whitespace seems to be an issue.

Similar to what @ochafik is saying, I wonder if a good first step would be an optimization pass in which, after expanding grammars, adjacent optional tokens are always collapsed into something more efficient, i.e., ws? ws? being condensed into something like ws{,2}?
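
Such a pass might look roughly like the Python sketch below. It assumes a rule body has already been split into a flat list of token strings, and it emits the `{,n}` repetition syntax proposed earlier in this thread, which is treated here as hypothetical. The function name `collapse_optionals` is made up for illustration.

```python
from itertools import groupby
from typing import List


def collapse_optionals(tokens: List[str]) -> List[str]:
    """Merge runs of identical optional tokens, e.g. ws? ws? -> ws{,2}."""
    out: List[str] = []
    for tok, run in groupby(tokens):
        n = len(list(run))
        if tok.endswith("?") and n > 1:
            # n adjacent copies of X? accept the same strings as X{,n}.
            out.append(f"{tok[:-1]}{{,{n}}}")
        else:
            out.extend([tok] * n)
    return out


print(collapse_optionals(["ws?", "ws?", "value"]))
```

Only exact, adjacent repeats are merged; anything else is passed through unchanged.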

@andysalerno
Contributor

As a user of the json_schema feature in server, I would very much love faster grammar support :D

I'm building out an agent flow, where I call the LLM an extra time before each user message to extract any function/tool call, detect the intent, etc, which is done using a json schema like { "function_name": "...", "function_invocation": "...", "user_intent": "..."} etc.

And this extra inference step takes up to 5sec on my 3090, even though it's hardly generating any tokens (in fact, I limit it to 20), presumably because of the grammar. After the function calling stage, the actual agent response from the LLM is very speedy, usually sub-1-second depending on length. So it's taking 5x longer to generate only a few tokens for function calling, compared to actually writing out a long response message.

I previously used TabbyAPI for this, and it handled the grammar extremely fast - sub-200ms usually, compared to 5sec in llama.cpp server.

@ochafik
Collaborator

ochafik commented May 21, 2024

@andysalerno Would you be able to share a real-world grammar to test one of the fixes that are in flight? (e.g. #7424 )

@andysalerno
Contributor

sure, I am using the json_schema property on a request to server at /v1/chat/completions with this schema:

{
    "title": "AnswerFormat",
    "type": "object",
    "properties": {
      "last_user_message_intent": {
        "type": "string"
      },
      "function_name": {
        "type": "string"
      },
      "invocation": {
        "type": "string"
      }
    },
    "required": [
      "last_user_message_intent",
      "function_name",
      "invocation"
    ]
}

I can respond shortly with some measurements showing how long this takes compared to a non-grammar generation with similar length.
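
For anyone trying to reproduce this, a request of that shape could be built as in the Python sketch below. The endpoint path and the `json_schema` property come from the comment above; the base URL, the message content, and the helper name `build_request` are assumptions made for this example.

```python
import json
import urllib.request

SCHEMA = {
    "title": "AnswerFormat",
    "type": "object",
    "properties": {
        "last_user_message_intent": {"type": "string"},
        "function_name": {"type": "string"},
        "invocation": {"type": "string"},
    },
    "required": ["last_user_message_intent", "function_name", "invocation"],
}


def build_request(base_url: str, user_message: str, schema: dict) -> urllib.request.Request:
    """Prepare (but do not send) a chat completion request with a JSON schema."""
    payload = {
        "messages": [{"role": "user", "content": user_message}],
        "json_schema": schema,
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Assumed address of a locally running server; adjust as needed.
request = build_request("http://localhost:8080", "What's the weather in Paris?", SCHEMA)
```

Passing `request` to `urllib.request.urlopen` would send it to the server.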

@andysalerno
Contributor

here is the verbose timings output from the request using the grammar:

    "timings": {
        "prompt_n": 592,
        "prompt_ms": 708.655,
        "prompt_per_token_ms": 1.1970523648648648,
        "prompt_per_second": 835.3853426561585,
        "predicted_n": 37,
        "predicted_ms": 4103.24,
        "predicted_per_token_ms": 110.89837837837837,
        "predicted_per_second": 9.017264405689163
    }

Note how it took about 4.1 s to predict 37 tokens, with a predicted_per_token_ms of 110.9 and a predicted_per_second of 9.0.

and here's the timing for the very next request which does not use a grammar

    "timings": {
        "prompt_n": 140,
        "prompt_ms": 280.92,
        "prompt_per_token_ms": 2.0065714285714287,
        "prompt_per_second": 498.3625231382599,
        "predicted_n": 11,
        "predicted_ms": 292.259,
        "predicted_per_token_ms": 26.569000000000003,
        "predicted_per_second": 37.637848620572846
    }

Note it only took ~300ms for 11 tokens, with a predicted_per_token_ms of 26.6 and a predicted_per_second of 37.6.
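
To put the two runs side by side, the per-token cost can be recomputed from the raw fields. A minimal sketch using only the predicted_n and predicted_ms values quoted above (the variable and function names are just for illustration):

```python
# Values copied from the two "timings" objects above.
with_grammar = {"predicted_n": 37, "predicted_ms": 4103.24}
without_grammar = {"predicted_n": 11, "predicted_ms": 292.259}


def ms_per_token(timings):
    """Average generation time per predicted token, in milliseconds."""
    return timings["predicted_ms"] / timings["predicted_n"]


slowdown = ms_per_token(with_grammar) / ms_per_token(without_grammar)
print(
    f"{ms_per_token(with_grammar):.1f} ms/token with grammar vs "
    f"{ms_per_token(without_grammar):.1f} ms/token without: "
    f"{slowdown:.1f}x slower per token"
)
```

Comparing per-token time rather than total time matters here because the two requests generated different numbers of tokens (37 vs 11).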

@ochafik
Copy link
Collaborator

ochafik commented May 21, 2024

@andysalerno Thanks! Added your grammar to these benchmarks. It seems likely to benefit from both #7424 (grammar-resampling) & #6811 (grammar-fast).

Show timings
git clone https://github.com/ochafik/llama.cpp llama.cpp-ochafik
cd llama.cpp-ochafik

for branch in master grammar-fast grammar-resampling grammar-fast-resampling ; do
    echo "#"
    echo "# $branch"
    echo "#"
    git checkout $branch
    make clean
    make -j LLAMA_CURL=1 main
    ./main \
        -mu https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q4_K_M.gguf \
        -j '{"title": "AnswerFormat", "type": "object", "properties": {"last_user_message_intent": {"type": "string" }, "function_name": {"type": "string" }, "invocation": {"type": "string" }}, "required": [ "last_user_message_intent", "function_name", "invocation"]}' \
        -p "Describe a function call of a tool in JSON format after a reminder of the last user message intent." \
        --seed 12345 --no-display-prompt 2>&1 | \
        grep "llama_print_timings"
done
#
# master
#
llama_print_timings:        load time =     173.33 ms
llama_print_timings:      sample time =     484.33 ms /    35 runs   (   13.84 ms per token,    72.26 tokens per second)
llama_print_timings: prompt eval time =      60.38 ms /    21 tokens (    2.88 ms per token,   347.80 tokens per second)
llama_print_timings:        eval time =     639.83 ms /    34 runs   (   18.82 ms per token,    53.14 tokens per second)
llama_print_timings:       total time =    1215.91 ms /    55 tokens
#
# grammar-fast
#
llama_print_timings:        load time =     168.06 ms
llama_print_timings:      sample time =      62.18 ms /    35 runs   (    1.78 ms per token,   562.91 tokens per second)
llama_print_timings: prompt eval time =      59.95 ms /    21 tokens (    2.85 ms per token,   350.28 tokens per second)
llama_print_timings:        eval time =     637.86 ms /    34 runs   (   18.76 ms per token,    53.30 tokens per second)
llama_print_timings:       total time =     767.84 ms /    55 tokens
#
# grammar-resampling
#
llama_print_timings:        load time =     169.90 ms
llama_print_timings:      sample time =     175.08 ms /    49 runs   (    3.57 ms per token,   279.87 tokens per second)
llama_print_timings: prompt eval time =      60.24 ms /    21 tokens (    2.87 ms per token,   348.62 tokens per second)
llama_print_timings:        eval time =     653.13 ms /    35 runs   (   18.66 ms per token,    53.59 tokens per second)
llama_print_timings:       total time =     905.68 ms /    56 tokens
#
# grammar-fast-resampling
#
llama_print_timings:        load time =     166.53 ms
llama_print_timings:      sample time =      32.80 ms /    49 runs   (    0.67 ms per token,  1493.72 tokens per second)
llama_print_timings: prompt eval time =      60.22 ms /    21 tokens (    2.87 ms per token,   348.74 tokens per second)
llama_print_timings:        eval time =     654.92 ms /    35 runs   (   18.71 ms per token,    53.44 tokens per second)
llama_print_timings:       total time =     756.98 ms /    56 tokens

@ExtReMLapin
Contributor

Example of a slow grammar:

root ::= dateforced | string
dateforced ::=  "\""  "Date lol"  "\"" 
string ::= EntityTypeNonDate 
EntityTypeNonDate ::= "\""  ( [^D\x00-\x40\U0000005B-\UFFFFFFFF] | "D" [^a\x00-\x60\U0000007B-\UFFFFFFFF] | "Da" [^t\x00-\x60\U0000007B-\UFFFFFFFF] | "Dat" [^e\x00-\x60\U0000007B-\UFFFFFFFF]) ASCIIEntityNameContinue{0,15}  "\""
ASCIICharLower ::= [a-z]
ASCIICharUpper ::= [A-Z]
ASCIIEntityName ::= ASCIIWordFirst (ASCIIWordNext){0,3}
ASCIIEntityNameContinue ::= (ASCIIWordNext){0,3}
ASCIIWordFirst ::= ASCIICharUpper ASCIICharLower{2,20}
ASCIIWordNext ::= ("-"|" ")? ASCIICharUpper? ASCIICharLower{2,20}

./llama-cli -m mistral-7b-instruct-v0.2.Q5_K_M.gguf -ngl 999999 --seed 0 --temp 0 -p "[INST]Who are you ? answer with quotes[/INST]" -n 512 --grammar-file ./grammar_fallback.gbnf

@github-actions github-actions bot removed the stale label Aug 10, 2024
@github-actions github-actions bot added the stale label Sep 9, 2024
@ggerganov ggerganov removed the stale label Sep 16, 2024
clarismiranda added a commit to clarismiranda/llama.cpp that referenced this issue Oct 11, 2024
@github-actions github-actions bot added the stale label Oct 17, 2024
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this issue Oct 24, 2024
…anov#9833

Grammar memo

Co-Authored-By: Clarissa Miranda <80654285+clarissamiranda@users.noreply.github.com>
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this issue Oct 24, 2024
ggerganov#9833"

This reverts commit 4cbf5c392af62252a69e17143e8a81d771ca6f8a.
Contributor

github-actions bot commented Nov 1, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

@github-actions github-actions bot closed this as completed Nov 1, 2024
@CamJN

CamJN commented Nov 1, 2024

Is it possible to reopen this? This issue needs fixing.

@HanClinto HanClinto reopened this Nov 1, 2024
@HanClinto
Collaborator

HanClinto commented Nov 1, 2024

Is it possible to reopen this? This issue needs fixing.

To help me understand the particulars, how big of a concern / obstacle is this for you and your workflows? I haven't worked on this problem in a while, but I would like to get back to it at some point, and I keep mulling the issue over in my mind.

Right now, the biggest obstacle that I see is that the number of stack permutations for ambiguous grammars grows so large in some cases that it becomes very difficult to manage.
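
To make the blow-up concrete, here is a toy model in Python. It is not llama.cpp's parser and makes no claim about how llama.cpp tracks its stacks: it assumes a hypothetical rule made of n optional ws? slots in a row and counts how many parse paths are alive after each space, with and without merging identical states.

```python
def advance(states, n_slots):
    """Consume one space.

    Each state is the index of the next unused `ws?` slot. The space can be
    taken by any remaining slot; the slots before it are skipped as empty.
    """
    nxt = []
    for pos in states:
        for slot in range(pos, n_slots):
            nxt.append(slot + 1)
    return nxt


def simulate(n_slots, n_spaces, dedup):
    """Return the number of live states after each consumed space."""
    states = [0]
    sizes = []
    for _ in range(n_spaces):
        states = advance(states, n_slots)
        if dedup:
            # Paths that reach the same position are interchangeable.
            states = sorted(set(states))
        sizes.append(len(states))
    return sizes


print("without merging:", simulate(20, 10, dedup=False))
print("with merging:   ", simulate(20, 10, dedup=True))
```

In this toy, the unmerged count follows the binomial coefficients (every way of assigning the spaces to slots is a separate path), while the merged count shrinks by one per step. Real grammars with nesting are harder to merge, which is the difficulty described above.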

Some options that we could consider:

  • CFG graph compilation: This is where I've spent a lot of my thought, but I don't have much prior experience in this area, so I'm having to learn as I go. @hoehrmann talks about this in a series of comments on #7572, but the basic idea (as I understand it) is to compile the context-free grammar into a static graph ahead of time. I don't know how this would look with very complex examples, but hopefully this would add a delay only to the warmup of our inference call as the grammar is compiled, but then the per-token cost should be minimal (and perhaps even constant-time). I don't know much about this region of computer science, but from what I can tell, I think that this is similar to how regex compilation works? If anyone more familiar with this subject would be willing to weigh in, I would love to learn more about this.

  • Alternate parsers / beam-search: We could switch away from a stack-based parser to a method that requires backtracking. In practice, I think this might look similar to the idea proposed about combining the speculative example with grammars in llama : combined beam search + grammar sampling strategy #2923

  • Random culling: One thing I've done in some Monte Carlo searches is to skip a truly exhaustive search of all permutations. Instead, whenever the search tree grows above a certain threshold (which can be specified arbitrarily, or based on available memory, elapsed time, or whatever), the system randomly culls the search forest and trusts that the remaining options will be "good enough". This might be the fastest way to get past the roadblocks that people are hitting with this feature and to improve its speed (at the cost of sometimes missing valid branches of generation).

  • Add alternate grammar format(s): Instead of building our own grammar compilation engine, why not use an off-the-shelf grammar engine that already supports compilation (such as regex)? One thing that I like about the refactoring in llama : refactor sampling v2 #9294 is how clean it makes the interface for sampling modifiers and sampling constraints, and it feels like we should be able to build on this to create an alternate grammar format. I think I've seen other inference engines do this, where instead of writing their own grammar engine they hook into an existing regex library and use their compiler to build their CFG graph and use that for efficient (sometimes seems to be linear-time?) grammar-constrained generation. If people still need the flexibility of GBNF then we wouldn't need to remove that grammar engine (and especially we want to retain it for backwards compatibility), but ideally the secondary grammar format should be able to live alongside the original, in the same way that various samplers coexist peacefully in the current post-refactoring setup.
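
Of the options above, random culling is the simplest to sketch. The following is a minimal Python illustration, assuming the candidate parse stacks are held in a plain list; the function name cull, the max_stacks threshold, and the integer stand-ins for stacks are all hypothetical and not part of llama.cpp.

```python
import random


def cull(stacks, max_stacks, rng):
    """Keep at most max_stacks candidates, chosen uniformly at random.

    Below the threshold nothing is dropped, so small grammars are unaffected.
    """
    if len(stacks) <= max_stacks:
        return list(stacks)
    return rng.sample(list(stacks), max_stacks)


rng = random.Random(0)          # fixed seed so runs are repeatable
candidates = list(range(1000))  # stand-ins for parse stacks
kept = cull(candidates, 64, rng)
print(len(candidates), "->", len(kept))
```

Every stack that is dropped is a continuation the grammar would have allowed, so the threshold trades completeness for a bounded per-token cost, which is the trade-off named in the bullet.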

I'm very outside of my comfort zone of expertise, and I don't know how viable / attractive any of these paths are, but I'd love to hear others' thoughts on it.

@github-actions github-actions bot removed the stale label Nov 2, 2024
@mmoskal

mmoskal commented Nov 9, 2024

I just did a very draft PR #10224 that integrates an alternative grammar library. You can give it a spin and see if it helps!
