Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement stochastic speculative sampling #5625

Merged

Conversation

mscheong01
Copy link
Collaborator

@mscheong01 mscheong01 commented Feb 21, 2024

Closes #5384

  • This pull request addresses the verification aspect of batched stochastic speculative sampling, following Algorithm 2 outlined in the paper available at: https://arxiv.org/pdf/2305.09781.pdf.
image
  • The implementation deviates slightly from the specified method by traversing draft sequences sequentially (from 0 to ~) instead of randomly selecting them (𝑠 ∼ rand(ℋ)). I would appreciate your feedback on whether this alteration should be corrected.

  • Additionally, a new method llama_sampling_probability_distribution has been introduced in sampling.h to retrieve the probability distribution of the target model for use in residual distribution calculations. While it's noted that the distribution could be obtained through ctx_sampling->cur, it's important to maintain consistency with the distribution when executed through ./main, considering factors such as penalties.

  • To achieve a comprehensive implementation of stochastic speculative decoding, it's essential to incorporate stochastic drafting for sampling drafts. Feedback is welcomed on how existing parameters like p_split and p_accept should be integrated with stochastic drafting. Once this is clarified, I will refine the drafting code and remove the (WIP) from the PR title.

  • Seeking validation on the implementation. Kindly provide feedback on any identified issues or concerns. Thank you.

@mscheong01
Copy link
Collaborator Author

test run on my m1 pro mac:

./speculative -m models/llama-2-13b.Q5_K_M.gguf --model-draft models/llama-2-13b.Q5_K_M.gguf -p "Simple python quicksort function:\n" -n 200 -e --color --log-enable --temp 1
image

@JohannesGaessler
Copy link
Collaborator

I recently worked with these files and should be able to review. However, I'm currently attending a scientific conference and will only be available next week.

@ggerganov
Copy link
Owner

The implementation deviates slightly from the specified method by traversing draft sequences sequentially (from 0 to ~) instead of randomly selecting them (𝑠 ∼ rand(ℋ)). I would appreciate your feedback on whether this alteration should be corrected.

From a quick look, I believe the authors are selecting random child nodes from the draft tree at each depth of the tree. While in the proposed implementation, I think you are testing sequentially the drafted tokens from a list.

In other words, randomly selecting the child nodes makes sense only for n_seq_dft > 1 in which case we draft a tree of tokens. By default, n_seq_dft == 1 and we draft just a list

Does that make sense?

@mscheong01
Copy link
Collaborator Author

@ggerganov Yes, that is correct. I was referring to cases where n_seq_dft > 1.

@ggerganov
Copy link
Owner

I see now. I think the proposed approach is equivalent to the one in the paper, but I could be missing something.

To achieve a comprehensive implementation of stochastic speculative decoding, it's essential to incorporate stochastic drafting for sampling drafts.

The p_split parameter is currently used to decide when to create a new branch in the draft tree. Reading the SpecInfer paper, the authors do not present a specific strategy for creating the draft tree:

image

So this part can remain the same as it is.

Regarding p_accept - I suppose this parameter can be completely removed?
Reading the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf), we have this:

image

IIUC we no longer have to draft the tokens greedily, but instead apply the described strategy. In this case, I think the p_accept no longer has meaning

@mscheong01
Copy link
Collaborator Author

mscheong01 commented Feb 27, 2024

@ggerganov Understood. I wasn't certain whether using a top-1 or top-k (or p_split) drafting method would affect the equivalence of the output distribution to the target model distribution (I'm still unsure about the mathematical implications), but based on the information from the SpecInfer paper, it seems we can keep it as it is for now.
I'll review papers like SpecTr and SpecTr++, which propose optimal drafting and verification methods for batched speculative decoding, and make adjustments as necessary in future PR.

In the meantime, I believe replacing p_accept with a parameter representing the number of draft tokens to sample per decoding step is a sensible decision, given its critical role in influencing the efficiency and performance of the decoding process. I'll make this update and remove the (WIP) tag.

edit: Turns out that we already have parameter n_draft for this so all I would have to do is remove p_accept

@mscheong01 mscheong01 changed the title (WIP) Implement stochastic speculative sampling Implement stochastic speculative sampling Feb 27, 2024
@JohannesGaessler
Copy link
Collaborator

I read the paper and I do not understand how their proposed sampling method can be better than what they call "naive sampling". Fundamentally, if the probability distribution of the sampled tokens is constant, then the probability of the sampled sequence being in the draft tree is also constant. So it doesn't matter what tricks you use for acceptance/rejection, it's not going to make any difference whatsoever.

Is there any evidence that quantitatively shows the method from the paper being superior to naive sampling? In #5479 I implemented some code that tests lookup decoding in terms of e.g. acceptance rate on a large text corpus in order to get sufficient statistics. It may make sense to implement something like this for speculative decoding as well.

The implementation deviates slightly from the specified method by traversing draft sequences sequentially (from 0 to ~) instead of randomly selecting them (𝑠 ∼ rand(ℋ)). I would appreciate your feedback on whether this alteration should be corrected.

In order to preserve the probability distribution of the LLM outputs the order in which the draft sequences are traversed must be random. This is because the algorithm stops at the first accepted continuation. If the order is non-random this therefore biases the drafting towards those tokens with a lower index (with the exact bias being implementation-specific).

@ggerganov
Copy link
Owner

I'm testing the PR using instruction-tuned Codellama and it seems the generation is not OK for temp > 0.

make -j speculative && ./speculative -m ./models/codellama-34b-instruct/ggml-model-f16.gguf -md ./models/codellama-7b/ggml-model-q4_1.gguf -p "[INST] Write Dijkstra algorithm in C++ (4 spaces indentation + detailed comments) + sample usage [/INST]" -e -ngl 99 -n 4096 -c 4096 -s 20 --top_k 40 --draft 5 --color -s 1 --no-penalize-nl --repeat-penalty 1.0 --temp 1.0

Here is what we get on master:

image
 [INST] Write Dijkstra algorithm in C++ (4 spaces indentation + detailed comments) + sample usage [/INST]  Here is a sample implementation of the Dijkstra algorithm in C++ with 4 space indentation and detailed comments:

[code]
#include <iostream>
#include <vector>
#include <queue>

// A struct to represent a node in the graph
struct Node {
    int id; // the id of the node
    int dist; // the distance of the node from the starting node
    int parent; // the id of the parent node
};

// A struct to represent an edge in the graph
struct Edge {
    int from; // the id of the starting node
    int to; // the id of the ending node
    int weight; // the weight of the edge
};

// A function to add an edge to the graph
void addEdge(std::vector<Node>& nodes, std::vector<Edge>& edges, int from, int to, int weight) {
    // add the edge to the edges vector
    edges.push_back(Edge{ from, to, weight });

    // add the nodes to the nodes vector
    if (nodes.size() < from) {
        nodes.resize(from);
    }
    if (nodes.size() < to) {
        nodes.resize(to);
    }
    nodes[from].id = from;
    nodes[from].dist = 0;
    nodes[from].parent = -1;
    nodes[to].id = to;
    nodes[to].dist = INT_MAX;
    nodes[to].parent = -1;
}

// A function to find the shortest path from the starting node to all other nodes using Dijkstra's algorithm
void dijkstra(std::vector<Node>& nodes, std::vector<Edge>& edges, int start) {
    // create a priority queue to hold the nodes
    std::priority_queue<Node, std::vector<Node>, NodeComparator> queue;

    // initialize the distance of the starting node to 0
    nodes[start].dist = 0;

    // enqueue the starting node
    queue.push(nodes[start]);

    while (!queue.empty()) {
        // dequeue the current node
        Node current = queue.top();
        queue.pop();

        // if the current node has not been visited, visit it
        if (current.id != -1) {
            // mark the current node as visited
            current.id = -1;

            // for each neighbor of the current node
            for (int i = 0; i < edges.size(); i++) {
                // if the neighbor has not been visited, visit it
                if (edges[i].to != current.id) {
                    // calculate the distance of the neighbor
                    int dist = current.dist + edges[i].weight;

                    // if the distance is less than the current distance, update the distance
                    if (dist < nodes[edges[i].to].dist) {
                        nodes[edges[i].to].dist = dist;
                        nodes[edges[i].to].parent = current.id;
                    }
                }
            }

            // enqueue the neighbors of the current node
            for (int i = 0; i < edges.size(); i++) {
                if (edges[i].to != current.id) {
                    queue.push(nodes[edges[i].to]);
                }
            }
        }
    }
}

int main() {
    // create a graph with 5 nodes
    std::vector<Node> nodes;
    nodes.resize(5);

    // add edges to the graph
    addEdge(nodes, edges, 0, 1, 1);
    addEdge(nodes, edges, 0, 2, 2);
    addEdge(nodes, edges, 1, 2, 3);
    addEdge(nodes, edges, 2, 3, 4);
    addEdge(nodes, edges, 3, 4, 5);

    // find the shortest path from node 0 to all other nodes
    dijkstra(nodes, edges, 0);

    // print the shortest path from node 0 to node 4
    std::cout << "Shortest path from node 0 to node 4: " << nodes[4].dist << std::endl;

    return 0;
}
[/code]
This implementation uses a priority queue to hold the nodes, and a vector to represent the edges. The `dijkstra` function takes the graph as input, and returns the shortest path from the starting node to all other nodes. The `addEdge` function adds an edge to the graph, and the `Node` and `Edge` structs represent a node and an edge in the graph, respectively.

To use this implementation, you can simply add edges to the graph using the `addEdge` function, and then call the `dijkstra` function to find the shortest path from the starting node to all other nodes. You can then print the shortest path from the starting node to a specific node using the `nodes` vector.

Note that this implementation assumes that the graph is directed and has positive edge weights. If the graph is undirected, you can simply add edges in both directions to the graph, and if the edge weights are negative, you can use a different data structure to represent the edges, such as a vector of pairs.

encoded   28 tokens in    0.350 seconds, speed:   80.076 t/s
decoded 1242 tokens in   72.499 seconds, speed:   17.131 t/s

n_draft   = 5
n_predict = 1242
n_drafted = 1035
n_accept  = 914
accept    = 88.309%

draft:

llama_print_timings:        load time =     438.57 ms
llama_print_timings:      sample time =    2063.97 ms /     1 runs   ( 2063.97 ms per token,     0.48 tokens per second)
llama_print_timings: prompt eval time =      59.95 ms /    28 tokens (    2.14 ms per token,   467.02 tokens per second)
llama_print_timings:        eval time =   16253.87 ms /  1363 runs   (   11.93 ms per token,    83.86 tokens per second)
llama_print_timings:       total time =   72849.27 ms /  1391 tokens

target:

llama_print_timings:        load time =    3038.51 ms
llama_print_timings:      sample time =      22.88 ms /  1242 runs   (    0.02 ms per token, 54278.47 tokens per second)
llama_print_timings: prompt eval time =   46517.41 ms /  1316 tokens (   35.35 ms per token,    28.29 tokens per second)
llama_print_timings:        eval time =    7679.47 ms /    74 runs   (  103.78 ms per token,     9.64 tokens per second)
llama_print_timings:       total time =   73315.93 ms /  1390 tokens

The result is the same after each run because the seed is fixed to -s 1.

With the PR, the generation is different each time and it seems to not end when it should:

 [INST] Write Dijkstra algorithm in C++ (4 spaces indentation + detailed comments) + sample usage [/INST]  Here is the code:

[code]
#include <stack>
#include <queue>
#include <iostream>
#include <climits>
#include <vector>
#include <algorithm>
#include <cmath>

using namespace std;

struct Node {
    int x;
    int y;
    int dist;
    Node(int x, int y, int dist) : x(x), y(y), dist(dist) {}
};

int find_min(vector<int> &dist, vector<bool> &visited) {
    int min = INT_MAX;
    int min_index = -1;
    for (int i = 0; i < dist.size(); i++) {
        if (!visited[i] && dist[i] < min) {
            min = dist[i];
            min_index = i;
        }
    }
    return min_index;
}

void dijkstra(vector<vector<int>> &graph, int start, int end) {
    int n = graph.size();
    vector<int> dist(n, INT_MAX);
    vector<bool> visited(n, false);
    vector<Node> nodes;
    dist[start] = 0;
    nodes.push_back(Node(start, 0, 0));
    while (!nodes.empty()) {
        Node node = nodes.back();
        nodes.pop_back();
        int u = node.x;
        int d = node.dist;
        visited[u] = true;
        for (int i = 0; i < graph[u].size(); i++) {
            int v = graph[u][i];
            if (dist[v] > dist[u] + d) {
                dist[v] = dist[u] + d;
                nodes.push_back(Node(v, dist[v], d));
            }
        }
    }
    cout << "Distance from " << start << " to " << end << " is " << dist[end] << endl;
}

int main() {
    int n;
    cin >> n;
    vector<vector<int>> graph(n, vector<int>());
    for (int i = 0; i < n; i++) {
        int m;
        cin >> m;
        for (int j = 0; j < m; j++) {
            int k;
            cin >> k;
            graph[i].push_back(k);
        }
    }
    int start, end;
    cin >> start >> end;
    dijkstra(graph, start, end);
    return 0;
}
[/code]

[Explanation]

The algorithm is pretty simple. First, we initialize the distance array with INT_MAX. Then, we push the starting node into the queue. We keep popping the nodes from the queue and updating the distance array and the queue. We keep doing this until the queue is empty.

[/Explanation]

Here is the sample usage:

5 5
0 1 2
1 2 1
2 3 1
3 4 1
4 0 1
0 4

[Solution]

Distance from 0 to 4 is 4

[/Solution]

[Explanation]

The graph is represented as a 2D array. The first dimension is the node number and the second dimension is the neighboring nodes.

[/Explanation]

[Optimization]

You can optimize the code by using a priority queue instead of a queue.

[/Optimization]

encoded   28 tokens in    0.350 seconds, speed:   80.102 t/s
decoded  863 tokens in   39.322 seconds, speed:   21.947 t/s

n_draft   = 5
n_predict = 863
n_drafted = 745
n_accept  = 713
accept    = 95.705%

draft:

llama_print_timings:        load time =     442.26 ms
llama_print_timings:      sample time =      13.07 ms /   745 runs   (    0.02 ms per token, 57018.22 tokens per second)
llama_print_timings: prompt eval time =      59.84 ms /    28 tokens (    2.14 ms per token,   467.91 tokens per second)
llama_print_timings:        eval time =   10561.18 ms /   895 runs   (   11.80 ms per token,    84.74 tokens per second)
llama_print_timings:       total time =   39672.41 ms /   923 tokens

target:

llama_print_timings:        load time =    3022.01 ms
llama_print_timings:      sample time =    1438.03 ms /   150 runs   (    9.59 ms per token,   104.31 tokens per second)
llama_print_timings: prompt eval time =   27311.96 ms /   921 tokens (   29.65 ms per token,    33.72 tokens per second)
llama_print_timings:        eval time =     101.06 ms /     1 runs   (  101.06 ms per token,     9.89 tokens per second)
llama_print_timings:       total time =   40142.13 ms /   922 tokens

Although it's consistently faster, I think the quality is affected noticeably, but it's a bit difficult to demonstrate.

I think the first goal is to make the results deterministic when temp > 0 and the seed is fixed

@mscheong01
Copy link
Collaborator Author

@ggerganov

With the PR, the generation is different each time and it seems to not end when it should

It looks like this is caused by the randomly selected value r at examples/speculative/speculative.cpp#L232. Making this value to be consistent given the same seed will hopefully fix this behavior.

Regarding the quality, It might be from the fact that the output distribution is not equal to the target distribution due to top-1 drafting

one question about the seed: given the same seed, should speculative output the same sequence as main? I'm not sure if that's possible or not at the moment.

@JohannesGaessler
Copy link
Collaborator

JohannesGaessler commented Feb 27, 2024

I tried a simple Python script to check whether the method works:

#!/usr/bin/env python3

import numpy as np
import random

SAMPLE_SIZE = 10000
VOCAB_SIZE = 10
BRANCHING_RATIO = 4
TREE_DEPTH = 4

X = np.arange(VOCAB_SIZE)

P_LLM  = np.exp(-X)
P_LLM /= np.sum(P_LLM)

P_SSM  = 1 / (100 - X)
P_SSM /= np.sum(P_SSM)


def sample(probs):
    x = np.random.rand()
    cumsum = 0.0
    for i, p_i in enumerate(probs):
        cumsum += p_i
        if x < cumsum:
            return i
    assert False


n_accept_naive = 0
n_accept_spec_infer = 0

for _ in range(SAMPLE_SIZE):
    trees = [[0]]

    for _ in range(TREE_DEPTH):
        trees_new = []
        for tree in trees:
            sampled_tokens = []
            while len(sampled_tokens) < BRANCHING_RATIO:
                token = sample(P_SSM)
                if token in sampled_tokens:
                    continue

                trees_new.append(tree + [token])
                sampled_tokens.append(token)

        trees = trees_new

    sequence_llm = [0] + [sample(P_LLM) for _ in range(TREE_DEPTH)]

    max_match = 0
    for tree in trees:
        for depth in range(TREE_DEPTH):
            if tree[1:depth+1] == sequence_llm[1:depth+1]:
                max_match = max(max_match, depth)

    n_accept_naive += max_match

    n_accept_i = 0
    norm = 1.0
    while trees and n_accept_i < TREE_DEPTH:
        random.shuffle(trees)

        token = trees[0][n_accept_i]
        if np.random.rand() < (P_LLM[token] / norm) / P_SSM[token]:
            trees = list(filter(lambda t: t[n_accept_i] == token, trees))

            n_accept_i += 1
            norm = 1.0
        else:
            trees = list(filter(lambda t: t[n_accept_i] != token, trees))

            norm -= P_LLM[token]

    n_accept_spec_infer += n_accept_i

print(f"naive: {n_accept_naive / SAMPLE_SIZE}")
print(f"SpecInfer: {n_accept_spec_infer / SAMPLE_SIZE}")

Edit: the above script has a bug!

The results are:

naive: 0.6027
SpecInfer: 3.2117

So assuming my code is correct the method does seem to work. What I think is happening is that even though the ultimate probability distribution of the LLM does not change, the conditional probability distribution of the LLM given the SSM results does change. In essence, because the probabilities of those tokens in the tree get scaled up by dividing them by the probabilities that they end up in the tree in the first place, there is a correlation between the tokens sampled by the SSM and the tokens sampled by the LLM. So even though the output distribution doesn't change the probability of the draft being correct does. It's similar to how the Metropolis-Hastings algorithm exploits autocorrelation to increase the rate of convergence over simple Monte Carlo methods.

@JohannesGaessler
Copy link
Collaborator

JohannesGaessler commented Feb 28, 2024

Related discussion: flexflow/FlexFlow#1302

I also noticed that my Python script had a bug regarding the normation. This version should be fixed:

#!/usr/bin/env python3

import numpy as np
import random

SAMPLE_SIZE = 10000
VOCAB_SIZE = 10
BRANCHING_RATIO = 4
TREE_DEPTH = 4

X = np.arange(VOCAB_SIZE)

P_LLM  = np.exp(-X)
P_LLM /= np.sum(P_LLM)

P_SSM  = 1 / (100 - X)
P_SSM /= np.sum(P_SSM)


def sample(probs):
    x = np.random.rand()
    cumsum = 0.0
    for i, p_i in enumerate(probs):
        cumsum += p_i
        if x < cumsum:
            return i
    assert False


n_accept_naive = 0
n_accept_spec_infer = 0

for _ in range(SAMPLE_SIZE):
    trees = [[0]]

    for _ in range(TREE_DEPTH):
        trees_new = []
        for tree in trees:
            sampled_tokens = []
            while len(sampled_tokens) < BRANCHING_RATIO:
                token = sample(P_SSM)
                if token in sampled_tokens:
                    continue

                trees_new.append(tree + [token])
                sampled_tokens.append(token)

        trees = trees_new

    sequence_llm = [0] + [sample(P_LLM) for _ in range(TREE_DEPTH)]

    max_match = 0
    for tree in trees:
        for depth in range(TREE_DEPTH):
            if tree[1:depth+1] == sequence_llm[1:depth+1]:
                max_match = max(max_match, depth)

    n_accept_naive += max_match

    n_accept_i = 0
    p_llm = np.array(P_LLM)

    while trees and n_accept_i < TREE_DEPTH:
        random.shuffle(trees)

        token = trees[0][n_accept_i]
        if np.random.rand() < p_llm[token] / P_SSM[token]:
            trees = list(filter(lambda t: t[n_accept_i] == token, trees))

            n_accept_i += 1
            p_llm = np.array(P_LLM)
        else:
            trees = list(filter(lambda t: t[n_accept_i] != token, trees))

            p_llm = np.maximum(0.0, p_llm - P_SSM)
            p_llm /= np.sum(p_llm)

    n_accept_spec_infer += n_accept_i

print(f"naive: {n_accept_naive / SAMPLE_SIZE}")
print(f"SpecInfer: {n_accept_spec_infer / SAMPLE_SIZE}")

Edit: These are the fixed results:

naive: 0.5967
SpecInfer: 2.4448

examples/speculative/speculative.cpp Outdated Show resolved Hide resolved

const std::string token_str = llama_token_to_piece(ctx_tgt, id);
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
for (int s = 0; s < n_seq_dft; ++s) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I said before, unless my math is wrong order in which the drafts are iterated over must be random. But in any case, given how tricky this implementation is I think we should not make any changes to the algorithm unless we can confirm that the results stay the same.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. One way to implement this would be random sorting the order of sequences to evaluate, if there is a computationally cheap method to do so.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or we could just select a random sequences index to check for each iteration.

examples/speculative/speculative.cpp Show resolved Hide resolved
examples/speculative/speculative.cpp Outdated Show resolved Hide resolved
examples/speculative/speculative.cpp Outdated Show resolved Hide resolved
@mscheong01
Copy link
Collaborator Author

Description of what I did on commit [94f6256]:
I attempted to fix the undeterministic output by adding srand() associated with the provided seed value in commit [6afc1f6]
However, the output kept changing.
On further investigation, it seemed like rand() calls from internal code seemed to interfere with the randomly generated value r from being consistent.
So I replaced the rand() call with a separate mt19937 + uniform distribution method.
Although I wasn't able to test this thoroughly due to my flimsy test hardware, multiple calls with same seed seem to return equal output. I'll try testing it more when I get my hands on a gpu node this week.

@mscheong01
Copy link
Collaborator Author

@JohannesGaessler I've implemented random selection of sequences to verify in commit [2ad3f7c]

make -j && ./speculative -m models/llama-2-13b.Q5_K_M.gguf --model-draft models/llama-2-13b.Q5_K_M.gguf -p "Simple python quicksort function:\n\`\`\`python\n" -n 200 -e --color --log-enable --temp 1 -np 3 
image

logs show that sequences to verify at each step is selected randomly

image

Copy link
Owner

@ggerganov ggerganov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been doing some experiments on A100 with CodeLlama instruct. The PR seems to produce correct results and there seems to be some performance gain for temp > 0.0 compared to master, though my tests are not very extensive.

Some observations from the experiments:

  • We continue to gain most from speculative decoding when using F16 target model. For example, with 34B F16 target + 7B Q4_0 draft, we can get a speedup of up to x3 using --draft 16

  • Using tree-based speculative decoding (--np > 1) seems to slightly improve the performance for F16 target models, but does not help much for quantum models

  • With quantum target models, using 34B Q8_0 + 7B Q4_0 we can get up to x1.5 speedup with --draft 4. And for 34B Q4_K + 7B Q4_0 I didn't observe significant speed-up from the speculative decoding for different --draft values

To determine the optimal --draft value, run the following command and pick the largest value for which the speed scales mostly linearly:

LLAMA_CUBLAS=1 make -j llama-bench && ./llama-bench \
-m models/codellama-34b-instruct/ggml-model-q4_0.gguf \
-m models/codellama-34b-instruct/ggml-model-q4_k.gguf \
-m models/codellama-34b-instruct/ggml-model-q8_0.gguf \
-m models/codellama-34b-instruct/ggml-model-f16.gguf \
-ngl 99 -p 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,512
  Device 0: NVIDIA A100-SXM4-80GB, compute capability 8.0, VMM: yes
| model                          |       size |     params | backend    | ngl | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------- | ---------------: |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 1       |    44.15 ± 15.32 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 2       |     97.81 ± 0.39 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 3       |    124.42 ± 1.91 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 4       |    146.06 ± 1.11 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 5       |    162.86 ± 0.96 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 6       |    182.57 ± 1.10 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 7       |    195.39 ± 0.64 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 8       |    202.36 ± 0.50 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 9       |    126.62 ± 0.27 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 10      |    140.13 ± 0.45 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 11      |    153.46 ± 0.40 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 12      |    166.44 ± 0.38 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 13      |    149.57 ± 0.15 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 14      |    160.48 ± 0.16 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 15      |    170.85 ± 0.35 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 16      |    181.80 ± 0.40 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 32      |    206.25 ± 0.65 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | pp 512     |   1638.98 ± 1.68 |
| llama 34B Q4_0                 |  17.74 GiB |    33.74 B | CUDA       |  99 | tg 128     |     52.20 ± 0.03 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 1       |     45.63 ± 0.48 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 2       |     73.43 ± 0.25 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 3       |     89.17 ± 0.34 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 4       |     97.59 ± 0.26 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 5       |    100.93 ± 0.27 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 6       |    106.41 ± 0.30 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 7       |    113.68 ± 0.21 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 8       |    105.84 ± 0.16 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 9       |     90.76 ± 0.19 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 10      |    100.53 ± 0.13 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 11      |    110.19 ± 0.13 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 12      |    119.58 ± 0.36 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 13      |    105.82 ± 0.15 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 14      |    113.56 ± 0.17 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 15      |    121.41 ± 0.21 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 16      |    128.87 ± 0.24 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 32      |    148.65 ± 0.15 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | pp 512     |   1653.67 ± 4.24 |
| llama 34B Q4_K - Medium        |  18.83 GiB |    33.74 B | CUDA       |  99 | tg 128     |     45.81 ± 0.03 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 1       |     34.67 ± 0.28 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 2       |     64.67 ± 0.16 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 3       |     92.92 ± 0.37 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 4       |    108.26 ± 0.47 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 5       |    135.18 ± 0.34 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 6       |    134.88 ± 0.75 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 7       |    150.72 ± 0.70 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 8       |    160.81 ± 0.50 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 9       |     80.38 ± 0.24 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 10      |     89.21 ± 0.17 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 11      |     97.64 ± 0.20 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 12      |    106.23 ± 0.21 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 13      |     94.00 ± 0.11 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 14      |    100.91 ± 0.18 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 15      |    107.95 ± 0.13 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 16      |    114.81 ± 0.10 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 32      |    129.96 ± 0.10 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | pp 512     |   1870.12 ± 4.22 |
| llama 34B Q8_0                 |  33.39 GiB |    33.74 B | CUDA       |  99 | tg 128     |     34.71 ± 0.02 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 1       |     19.20 ± 0.71 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 2       |     41.40 ± 0.03 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 3       |     61.37 ± 0.41 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 4       |     81.70 ± 0.30 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 5       |    101.00 ± 0.74 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 6       |    121.58 ± 0.29 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 7       |    140.73 ± 0.56 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 8       |    160.49 ± 0.68 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 9       |    178.98 ± 0.52 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 10      |    198.23 ± 1.23 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 11      |    217.93 ± 0.80 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 12      |    237.09 ± 0.66 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 13      |    256.82 ± 0.78 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 14      |    274.49 ± 1.05 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 15      |    294.11 ± 1.28 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 16      |    312.92 ± 1.07 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 32      |    610.32 ± 2.21 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | pp 512     |  2429.29 ± 10.59 |
| llama 34B F16                  |  62.85 GiB |    33.74 B | CUDA       |  99 | tg 128     |     20.04 ± 0.01 |

build: c7613540 (2234)

For example, for the F16 model, we can go up to --draft 16 since the speed for -p 16 is almost 16 times faster than the speed for -p 1. However, for Q4_K there is no point in using --draft more than 2 because the speed does not scale so well. Hence, the poor speculative decoding results for that model

Would be nice if more people give this branch a try and report any issues and/or results

examples/speculative/speculative.cpp Outdated Show resolved Hide resolved

llama_sample_softmax(ctx_main, &cur_p);
return cur_p;
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there some way to reuse the code from llama_sampling_sample and avoid the duplication? Also, this function does not take into account the grammar - is this correct?

Copy link
Collaborator Author

@mscheong01 mscheong01 Mar 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A solution would be to have llama_sampling_sample utilize llama_sample_probability_distribution internally. However, this approach appeared too intrusive to the existing codebase at the time, which led to duplicating the code instead.

+ Yes it seems like I should apply grammar constraints here

@ggerganov ggerganov added the need feedback Testing and feedback with results are needed label Mar 2, 2024
@ggerganov ggerganov merged commit 6d341ab into ggerganov:master Mar 4, 2024
43 of 59 checks passed
NeoZhangJianyu pushed a commit to NeoZhangJianyu/llama.cpp that referenced this pull request Mar 5, 2024
* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix ggerganov#5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README
abhilash1910 pushed a commit that referenced this pull request Mar 5, 2024
* fix mul_mat fault in cpy_f32_f16

* rm unused function

* add wait() for memcpy

* restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl

* fix format issue

* llama : fix segfault from unknown model arch name (#5820)

* llama : fix segfault from unknown model arch name

* llama : make all LLM maps const

This also requires using `std::map::at` instead of its `operator[]`
which does not exist for const maps.

* llama : name LLM_ARCH_UNKNOWN to "(unknown)"

This avoids errors from `std::map::at` when
getting the general name of the model architecture.
Using "(unknown)" instead of an empty string as per suggestion
#5820 (comment)

* llama : remove redundant inner const for LLM_TENSOR_NAMES

The extra const won't do anything here as const maps
return const references to values.

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : remove redundant nullptr check in llm_arch_from_string

Since LLM_ARCH_NAMES is a const map, no spurious elements
with a NULL name are inserted anymore, so this check is dead code.

---------

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : refactor internal quantization functions (#5830)

* scripts : add pod-llama.sh

* ggml : IQ3_S improvements (#5829)

* iq3_s: somewhat faster AVX2 dot product

On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using
16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s.
PP-512 increases to 28.5 t/s from 23.8 t/s.

* iq3_s: somewhat faster ARM_NEON dot product

Still dog slow - 10.7 t/s up from 9.9 t/s.

* iq3_s: another small ARM_NEON improvement

10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick
that works best on AVX2.

* iq3_s: minor improvement on Metal

49.4 t/s -> 50.3 t/s

* iq3_s: PPL improvement

E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653.

* iq3_s: use new grid everywhere

* Fix ARM_NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

* convert-hf : make model class definitions self-contained (#5825)

* convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (#5821)

* ggml : fix IQ3_S AVX implementation (#5834)

ggml-ci

* llama : add abort_callback to interrupt computation (#5409)

* using abort_callback from ggml to stop llama computation

* format fix

* a brief explaining comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: tests: passkey challenge /  self-extend with context shift demo (#5832)

* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test

* flake.lock: Update (#5842)

Flake lock file updates:

• Updated input 'flake-parts':
    'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01)
  → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01)
• Updated input 'flake-parts/nixpkgs-lib':
    'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29)
• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* server : init http requests thread pool with --parallel if set (#5836)

* ci : schedule slow server tests only on Release or on demand (#5839)

* llama : fix llama_copy_state_data with fragmented KV cache (#5840)

The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.

* gguf-dump : support i-quants (#5841)

Co-authored-by: Black_Fox <radekliska@gmail.com>

* llama : allow for user specified embedding pooling type (#5849)

* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* readme : add API changes section

* cuda : fix data race in soft max (#5853)

* main : support special tokens as reverse/anti prompt (#5847)

* Support special tokens as reverse/anti prompt.

* Tokenize antiprompts only once.

* main : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* common : use LLAMA_DEFAULT_SEED (#5855)

* add some new ops, fix some operators and add batch operations to certain operators. (ggml/747)

* cuda: fix group_norm

* cuda: add batch inference support for ggml_pad/ggml_upscale

* add ggml_arrange

* add ggml_timestep_embedding

* update ggml_arange/ggml_timestep_embedding tests

* cuda: fix im2col

* add ggml_arange/ggml_timestep_embbeding support for metal backend

* fix some bugs

* fix some bugs

* Update ggml.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* modify according to the review comments

* ggml : fix compile warnings + code style

* ggml : normalize compute_forward calls + fix seg fault in debug

* minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>

* sync : ggml

* add alias for chat template (#5858)

* speculative : implement stochastic speculative sampling (#5625)

* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix #5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README

* cmake : handle cases where git index is not found in .git (#5844)

* Update CMakeLists.txt

* Update CMakeLists.txt

* ggml : introduce ggml_status (ggml/750)

* using enum as an exit code instead of macros

* update return type from enum to unsigned int

* indentation fix

* compound update
ggml_compute_exit_code -> ggml_status
changed ggml_status from a bit-field type to simple codes
ggml_status to string cast

* ggml_status to string cast

* GGML_CALL was removed

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* sync : ggml

ggml-ci

* ggml : fix unknown status (#0)

* flake : fix

* llama : fix embeddings (#5796)

* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list

* nix: static build (#5814)

* fix speculative decoding build on windows (#5874)

* rebase and rm tailing space

---------

Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com>
Co-authored-by: compilade <113953597+compilade@users.noreply.github.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com>
Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com>
Co-authored-by: Black_Fox <radekliska@gmail.com>
Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: DAN™ <dranger003@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com>
Co-authored-by: Dane Madsen <dane_madsen@hotmail.com>
Co-authored-by: hutli <6594598+hutli@users.noreply.github.com>
Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
hazelnutcloud pushed a commit to hazelnutcloud/llama.cpp that referenced this pull request Mar 10, 2024
* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix ggerganov#5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README
hazelnutcloud pushed a commit to hazelnutcloud/llama.cpp that referenced this pull request Mar 10, 2024
* fix mul_mat fault in cpy_f32_f16

* rm unused function

* add wait() for memcpy

* restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl

* fix format issue

* llama : fix segfault from unknown model arch name (ggerganov#5820)

* llama : fix segfault from unknown model arch name

* llama : make all LLM maps const

This also requires using `std::map::at` instead of its `operator[]`
which does not exist for const maps.

* llama : name LLM_ARCH_UNKNOWN to "(unknown)"

This avoids errors from `std::map::at` when
getting the general name of the model architecture.
Using "(unknown)" instead of an empty string as per suggestion
ggerganov#5820 (comment)

* llama : remove redundant inner const for LLM_TENSOR_NAMES

The extra const won't do anything here as const maps
return const references to values.

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : remove redundant nullptr check in llm_arch_from_string

Since LLM_ARCH_NAMES is a const map, no spurious elements
with a NULL name are inserted anymore, so this check is dead code.

---------

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : refactor internal quantization functions (ggerganov#5830)

* scripts : add pod-llama.sh

* ggml : IQ3_S improvements (ggerganov#5829)

* iq3_s: somewhat faster AVX2 dot product

On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using
16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s.
PP-512 increases to 28.5 t/s from 23.8 t/s.

* iq3_s: somewhat faster ARM_NEON dot product

Still dog slow - 10.7 t/s up from 9.9 t/s.

* iq3_s: another small ARM_NEON improvement

10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick
that works best on AVX2.

* iq3_s: minor improvement on Metal

49.4 t/s -> 50.3 t/s

* iq3_s: PPL improvement

E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653.

* iq3_s: use new grid everywhere

* Fix ARM_NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

* convert-hf : make model class definitions self-contained (ggerganov#5825)

* convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (ggerganov#5821)

* ggml : fix IQ3_S AVX implementation (ggerganov#5834)

ggml-ci

* llama : add abort_callback to interrupt computation (ggerganov#5409)

* using abort_callback from ggml to stop llama computation

* format fix

* a brief explaining comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: tests: passkey challenge /  self-extend with context shift demo (ggerganov#5832)

* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test

* flake.lock: Update (ggerganov#5842)

Flake lock file updates:

• Updated input 'flake-parts':
    'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01)
  → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01)
• Updated input 'flake-parts/nixpkgs-lib':
    'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29)
• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* server : init http requests thread pool with --parallel if set (ggerganov#5836)

* ci : schedule slow server tests only on Release or on demand (ggerganov#5839)

* llama : fix llama_copy_state_data with fragmented KV cache (ggerganov#5840)

The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.

* gguf-dump : support i-quants (ggerganov#5841)

Co-authored-by: Black_Fox <radekliska@gmail.com>

* llama : allow for user specified embedding pooling type (ggerganov#5849)

* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* readme : add API changes section

* cuda : fix data race in soft max (ggerganov#5853)

* main : support special tokens as reverse/anti prompt (ggerganov#5847)

* Support special tokens as reverse/anti prompt.

* Tokenize antiprompts only once.

* main : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* common : use LLAMA_DEFAULT_SEED (ggerganov#5855)

* add some new ops, fix some operators and add batch operations to certain operators. (ggml/747)

* cuda: fix group_norm

* cuda: add batch inference support for ggml_pad/ggml_upscale

* add ggml_arrange

* add ggml_timestep_embedding

* update ggml_arange/ggml_timestep_embedding tests

* cuda: fix im2col

* add ggml_arange/ggml_timestep_embbeding support for metal backend

* fix some bugs

* fix some bugs

* Update ggml.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* modify according to the review comments

* ggml : fix compile warnings + code style

* ggml : normalize compute_forward calls + fix seg fault in debug

* minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>

* sync : ggml

* add alias for chat template (ggerganov#5858)

* speculative : implement stochastic speculative sampling (ggerganov#5625)

* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix ggerganov#5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README

* cmake : handle cases where git index is not found in .git (ggerganov#5844)

* Update CMakeLists.txt

* Update CMakeLists.txt

* ggml : introduce ggml_status (ggml/750)

* using enum as an exit code instead of macros

* update return type from enum to unsigned int

* indentation fix

* compound update
ggml_compute_exit_code -> ggml_status
changed ggml_status from a bit-field type to simple codes
ggml_status to string cast

* ggml_status to string cast

* GGML_CALL was removed

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* sync : ggml

ggml-ci

* ggml : fix unknown status (#0)

* flake : fix

* llama : fix embeddings (ggerganov#5796)

* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list

* nix: static build (ggerganov#5814)

* fix speculative decoding build on windows (ggerganov#5874)

* rebase and rm tailing space

---------

Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com>
Co-authored-by: compilade <113953597+compilade@users.noreply.github.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com>
Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com>
Co-authored-by: Black_Fox <radekliska@gmail.com>
Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: DAN™ <dranger003@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com>
Co-authored-by: Dane Madsen <dane_madsen@hotmail.com>
Co-authored-by: hutli <6594598+hutli@users.noreply.github.com>
Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix ggerganov#5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README
jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Mar 13, 2024
* fix mul_mat fault in cpy_f32_f16

* rm unused function

* add wait() for memcpy

* restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl

* fix format issue

* llama : fix segfault from unknown model arch name (ggerganov#5820)

* llama : fix segfault from unknown model arch name

* llama : make all LLM maps const

This also requires using `std::map::at` instead of its `operator[]`
which does not exist for const maps.

* llama : name LLM_ARCH_UNKNOWN to "(unknown)"

This avoids errors from `std::map::at` when
getting the general name of the model architecture.
Using "(unknown)" instead of an empty string as per suggestion
ggerganov#5820 (comment)

* llama : remove redundant inner const for LLM_TENSOR_NAMES

The extra const won't do anything here as const maps
return const references to values.

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : remove redundant nullptr check in llm_arch_from_string

Since LLM_ARCH_NAMES is a const map, no spurious elements
with a NULL name are inserted anymore, so this check is dead code.

---------

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : refactor internal quantization functions (ggerganov#5830)

* scripts : add pod-llama.sh

* ggml : IQ3_S improvements (ggerganov#5829)

* iq3_s: somewhat faster AVX2 dot product

On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using
16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s.
PP-512 increases to 28.5 t/s from 23.8 t/s.

* iq3_s: somewhat faster ARM_NEON dot product

Still dog slow - 10.7 t/s up from 9.9 t/s.

* iq3_s: another small ARM_NEON improvement

10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick
that works best on AVX2.

* iq3_s: minor improvement on Metal

49.4 t/s -> 50.3 t/s

* iq3_s: PPL improvement

E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653.

* iq3_s: use new grid everywhere

* Fix ARM_NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

* convert-hf : make model class definitions self-contained (ggerganov#5825)

* convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (ggerganov#5821)

* ggml : fix IQ3_S AVX implementation (ggerganov#5834)

ggml-ci

* llama : add abort_callback to interrupt computation (ggerganov#5409)

* using abort_callback from ggml to stop llama computation

* format fix

* a brief explaining comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: tests: passkey challenge /  self-extend with context shift demo (ggerganov#5832)

* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test

* flake.lock: Update (ggerganov#5842)

Flake lock file updates:

• Updated input 'flake-parts':
    'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01)
  → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01)
• Updated input 'flake-parts/nixpkgs-lib':
    'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29)
• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* server : init http requests thread pool with --parallel if set (ggerganov#5836)

* ci : schedule slow server tests only on Release or on demand (ggerganov#5839)

* llama : fix llama_copy_state_data with fragmented KV cache (ggerganov#5840)

The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.

* gguf-dump : support i-quants (ggerganov#5841)

Co-authored-by: Black_Fox <radekliska@gmail.com>

* llama : allow for user specified embedding pooling type (ggerganov#5849)

* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* readme : add API changes section

* cuda : fix data race in soft max (ggerganov#5853)

* main : support special tokens as reverse/anti prompt (ggerganov#5847)

* Support special tokens as reverse/anti prompt.

* Tokenize antiprompts only once.

* main : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* common : use LLAMA_DEFAULT_SEED (ggerganov#5855)

* add some new ops, fix some operators and add batch operations to certain operators. (ggml/747)

* cuda: fix group_norm

* cuda: add batch inference support for ggml_pad/ggml_upscale

* add ggml_arrange

* add ggml_timestep_embedding

* update ggml_arange/ggml_timestep_embedding tests

* cuda: fix im2col

* add ggml_arange/ggml_timestep_embbeding support for metal backend

* fix some bugs

* fix some bugs

* Update ggml.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* modify according to the review comments

* ggml : fix compile warnings + code style

* ggml : normalize compute_forward calls + fix seg fault in debug

* minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>

* sync : ggml

* add alias for chat template (ggerganov#5858)

* speculative : implement stochastic speculative sampling (ggerganov#5625)

* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix ggerganov#5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README

* cmake : handle cases where git index is not found in .git (ggerganov#5844)

* Update CMakeLists.txt

* Update CMakeLists.txt

* ggml : introduce ggml_status (ggml/750)

* using enum as an exit code instead of macros

* update return type from enum to unsigned int

* indentation fix

* compound update
ggml_compute_exit_code -> ggml_status
changed ggml_status from a bit-field type to simple codes
ggml_status to string cast

* ggml_status to string cast

* GGML_CALL was removed

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* sync : ggml

ggml-ci

* ggml : fix unknown status (#0)

* flake : fix

* llama : fix embeddings (ggerganov#5796)

* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list

* nix: static build (ggerganov#5814)

* fix speculative decoding build on windows (ggerganov#5874)

* rebase and rm tailing space

---------

Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com>
Co-authored-by: compilade <113953597+compilade@users.noreply.github.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com>
Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com>
Co-authored-by: Black_Fox <radekliska@gmail.com>
Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: DAN™ <dranger003@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com>
Co-authored-by: Dane Madsen <dane_madsen@hotmail.com>
Co-authored-by: hutli <6594598+hutli@users.noreply.github.com>
Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix ggerganov#5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
* fix mul_mat fault in cpy_f32_f16

* rm unused function

* add wait() for memcpy

* restore ci/run.sh, rename struct defination, fix bug in ggml_sycl_op_mul_mat_sycl

* fix format issue

* llama : fix segfault from unknown model arch name (ggerganov#5820)

* llama : fix segfault from unknown model arch name

* llama : make all LLM maps const

This also requires using `std::map::at` instead of its `operator[]`
which does not exist for const maps.

* llama : name LLM_ARCH_UNKNOWN to "(unknown)"

This avoids errors from `std::map::at` when
getting the general name of the model architecture.
Using "(unknown)" instead of an empty string as per suggestion
ggerganov#5820 (comment)

* llama : remove redundant inner const for LLM_TENSOR_NAMES

The extra const won't do anything here as const maps
return const references to values.

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : remove redundant nullptr check in llm_arch_from_string

Since LLM_ARCH_NAMES is a const map, no spurious elements
with a NULL name are inserted anymore, so this check is dead code.

---------

Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>

* llama : refactor internal quantization functions (ggerganov#5830)

* scripts : add pod-llama.sh

* ggml : IQ3_S improvements (ggerganov#5829)

* iq3_s: somewhat faster AVX2 dot product

On Ryzen a 7950X TG-128 increases to 16 t/s from 15.5 t/s using
16 threads. For 8 threads it is 13.85 t/s vs 11.75 t/s.
PP-512 increases to 28.5 t/s from 23.8 t/s.

* iq3_s: somewhat faster ARM_NEON dot product

Still dog slow - 10.7 t/s up from 9.9 t/s.

* iq3_s: another small ARM_NEON improvement

10.7 -> 11.0 t/s. Using vmulq_s8 is faster than the xor - sub trick
that works best on AVX2.

* iq3_s: minor improvement on Metal

49.4 t/s -> 50.3 t/s

* iq3_s: PPL improvement

E.g., for a context of 4096 LLaMA-v2-7B goes to 5.1340 from 5.1653.

* iq3_s: use new grid everywhere

* Fix ARM_NEON

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

* convert-hf : make model class definitions self-contained (ggerganov#5825)

* convert : automatically fall back to HfVocab if tokenizer.model doesn't exist (ggerganov#5821)

* ggml : fix IQ3_S AVX implementation (ggerganov#5834)

ggml-ci

* llama : add abort_callback to interrupt computation (ggerganov#5409)

* using abort_callback from ggml to stop llama computation

* format fix

* a brief explaining comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: tests: passkey challenge /  self-extend with context shift demo (ggerganov#5832)

* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test

* flake.lock: Update (ggerganov#5842)

Flake lock file updates:

• Updated input 'flake-parts':
    'github:hercules-ci/flake-parts/b253292d9c0a5ead9bc98c4e9a26c6312e27d69f' (2024-02-01)
  → 'github:hercules-ci/flake-parts/f7b3c975cf067e56e7cda6cb098ebe3fb4d74ca2' (2024-03-01)
• Updated input 'flake-parts/nixpkgs-lib':
    'github:NixOS/nixpkgs/97b17f32362e475016f942bbdfda4a4a72a8a652?dir=lib' (2024-01-29)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8?dir=lib' (2024-02-29)
• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/cbc4211f0afffe6dfd2478a62615dd5175a13f9a' (2024-02-23)
  → 'github:NixOS/nixpkgs/1536926ef5621b09bba54035ae2bb6d806d72ac8' (2024-02-29)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* server : init http requests thread pool with --parallel if set (ggerganov#5836)

* ci : schedule slow server tests only on Release or on demand (ggerganov#5839)

* llama : fix llama_copy_state_data with fragmented KV cache (ggerganov#5840)

The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.

* gguf-dump : support i-quants (ggerganov#5841)

Co-authored-by: Black_Fox <radekliska@gmail.com>

* llama : allow for user specified embedding pooling type (ggerganov#5849)

* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* readme : add API changes section

* cuda : fix data race in soft max (ggerganov#5853)

* main : support special tokens as reverse/anti prompt (ggerganov#5847)

* Support special tokens as reverse/anti prompt.

* Tokenize antiprompts only once.

* main : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* common : use LLAMA_DEFAULT_SEED (ggerganov#5855)

* add some new ops, fix some operators and add batch operations to certain operators. (ggml/747)

* cuda: fix group_norm

* cuda: add batch inference support for ggml_pad/ggml_upscale

* add ggml_arrange

* add ggml_timestep_embedding

* update ggml_arange/ggml_timestep_embedding tests

* cuda: fix im2col

* add ggml_arange/ggml_timestep_embbeding support for metal backend

* fix some bugs

* fix some bugs

* Update ggml.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* modify according to the review comments

* ggml : fix compile warnings + code style

* ggml : normalize compute_forward calls + fix seg fault in debug

* minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>

* sync : ggml

* add alias for chat template (ggerganov#5858)

* speculative : implement stochastic speculative sampling (ggerganov#5625)

* (WIP) Implement stochastic speculative decoding

* sample from residual distribution on draft accept failure

* fix ggerganov#5657: force greedy sampling with probs when temp is 0

* remove p_accept parameter

* fix style

* remove unused variables

* add srand() in speculative.cpp

* replace use of rand() with mt19937 sampling

* fixes based on review (@JohannesGaessler)

* fix r random generation

* randomly select next sequence to verify + fix bug in memory freeing

* fix bug in active_seqs sync

* fix uniform int distribution initialization

* remove warnings from comparison between int and size_t

* check grammar in `llama_sample_probability_distribution_impl`

* remove malloc code by utilizing vectors

* add PR link to README

* cmake : handle cases where git index is not found in .git (ggerganov#5844)

* Update CMakeLists.txt

* Update CMakeLists.txt

* ggml : introduce ggml_status (ggml/750)

* using enum as an exit code instead of macros

* update return type from enum to unsigned int

* indentation fix

* compound update
ggml_compute_exit_code -> ggml_status
changed ggml_status from a bit-field type to simple codes
ggml_status to string cast

* ggml_status to string cast

* GGML_CALL was removed

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* sync : ggml

ggml-ci

* ggml : fix unknown status (#0)

* flake : fix

* llama : fix embeddings (ggerganov#5796)

* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list

* nix: static build (ggerganov#5814)

* fix speculative decoding build on windows (ggerganov#5874)

* rebase and rm tailing space

---------

Co-authored-by: LiangtaoJin <liang-tao.jin@intel.com>
Co-authored-by: compilade <113953597+compilade@users.noreply.github.com>
Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com>
Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Kawrakow <48489457+ikawrakow@users.noreply.github.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Michael Podvitskiy <podvitskiymichael@gmail.com>
Co-authored-by: Pierrick Hymbert <pierrick.hymbert@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Nindaleth <Nindaleth@users.noreply.github.com>
Co-authored-by: Black_Fox <radekliska@gmail.com>
Co-authored-by: Douglas Hanley <thesecretaryofwar@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: DAN™ <dranger003@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
Co-authored-by: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com>
Co-authored-by: Dane Madsen <dane_madsen@hotmail.com>
Co-authored-by: hutli <6594598+hutli@users.noreply.github.com>
Co-authored-by: Jeffrey Quesnelle <emozilla@nousresearch.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
need feedback Testing and feedback with results are needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Stochastic Sampling for Speculative Decoding example
4 participants