
Commit fdb6f26

pi6am authored and Nexesenex committed
Post-Review DRY dynamic N-gram sampler (LostRuins#982)
* Add the DRY dynamic N-gram anti-repetition sampler

  The DRY (Do not Repeat Yourself) sampler is a dynamic N-gram repetition penalty that negatively scores tokens that would extend sequences already appearing in the context. See this discussion for the motivation and an explanation of the sampler: oobabooga/text-generation-webui#5677

  This implementation of DRY mostly aligns with the oobabooga version, with a few modifications. It uses a more efficient linear scanning algorithm to identify repetitions, and it supports multi-token sequence breakers. As a limitation, this implementation reuses the rep pen range parameter rather than introducing a new range just for the DRY sampler.

  There is a separate change to lite.koboldai.net that exposes the DRY sampler parameters to KoboldAI Lite, so none of the embed files have been changed as part of this commit.

* Update default DRY parameters to match lite

* Improve DRY token debug logging

* Replace `and` with `&&` to fix MSVC compile error

  Little-known fact: the C++98 standard defines `and` as an alternative token for the `&&` operator (along with a number of other alternative tokens). MSVC does not allow these without the /Za option or the <iso646.h> header. Change to the more common operator to make this code more portable.

* Fix MSVC compile error because log is not constexpr

  Replace the compile-time computation with a floating-point approximation of log(std::numeric_limits<float>::max()).

* Remove unused llama sampler variables and clean up sequence breakers

* Remove KCPP_SAMPLER_DRY as a separate enum entry

  The DRY sampler is effectively a repetition penalty, and there are very few reasons to apply it at a different place in the sampler order than the standard single-token penalty. Multiple projects also depend on the existing sampler IDs, including KoboldAI, KoboldAI Lite, and Silly Tavern. To minimize the impact on these dependencies, it makes the most sense not to add a new ID for now and instead to piggyback on KCPP_SAMPLER_REP_PEN. If we later find a use case for splitting the application of rep pen and DRY, we can introduce a new enum entry then.

* Add dry_penalty_last_n to independently control the DRY penalty range

  This parameter follows the oobabooga semantics: it is optional, with a default value of zero. Zero means that DRY should scan the entire context; otherwise, it is the number of tokens from the end of the context that are scanned for repetitions.

* Limit sequence breaker lengths in tokens and characters

  The core DRY sampler algorithm is linear in the context length, but several parts of the sampler related to multi-token sequence breakers are potentially quadratic. Without any restrictions, a suitably crafted context and sequence breaker could mount a denial-of-service attack on a server running koboldcpp. This change limits the maximum number of characters and the maximum token length of a sequence breaker in order to bound the sampler's worst-case overhead. It also improves some comments, adding detail and rewording for clarity.
1 parent 582931e commit fdb6f26

File tree

1 file changed: +30 -12 lines changed


gpttype_adapter.cpp

Lines changed: 30 additions & 12 deletions
@@ -320,7 +320,8 @@ static void print_tok_vec_str(std::vector<int> &vec)
 // instance, the token '.\n' to be a head for both '.' and '\n'. However if a head token
 // begins a multi-token sequence, the head can only extend past `str` at the beginning. The
 // tail tokens are generated by tokenizing the remainder.
-static void GetOverlappingTokenSequences(const std::string& str, std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& token_sequences) {
+// If max_tail_len is >= 0, the maximum token length of a tail sequence is clamped to this value.
+static void GetOverlappingTokenSequences(const std::string& str, std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& token_sequences, int max_tail_len = -1) {
     for(int v=0;v<n_vocab;++v)
     {
         std::string word = FileFormatTokenizeID(v, file_format, true);
@@ -341,7 +342,8 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
             }
         } else {
             // Check whether a prefix of the string overlaps with a suffix of the token.
-            // Just do a naive O(N^2) search.
+            // Just do a naive O(N^2) search, since the worst case is limited by the
+            // maximum character length of a token in the vocabulary.
             size_t word_len = word.size(), str_len = str.size();
             size_t pos = -1;
             while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
@@ -358,6 +360,9 @@ static void GetOverlappingTokenSequences(const std::string& str, std::unordered_
                 // there must be trailing letters in `str`.
                 std::vector<gpt_vocab::id> tokenization;
                 TokenizeString(str.substr(i), tokenization, file_format, false);
+                if (max_tail_len >= 0 && tokenization.size() > max_tail_len) {
+                    tokenization.resize(max_tail_len);
+                }

                 // Ensure we don't already have a duplicate matching tokenization.
                 auto its = token_sequences.equal_range(v);
@@ -530,8 +535,9 @@ void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float pe
     // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
     // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
     //
-    // This is worst-case O(N^2) for perverse restart sequences, but typically will be O(N) since
-    // most restart sequences are a single token and we use a hash table to check for head token.
+    // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
+    // have already clamped the maximum tail sequence length when generating `restart_sequences`.
+    // With clamping, this scan is O(N) in the context length.

     int rep_limit = last_n_repeat;
     for (size_t i = 0; i < last_n_repeat; ++i) {
@@ -582,11 +588,17 @@ void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float pe
     // The code below is adapted from the public domain implementation by the same author here:
     // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
     //
-    // This step is worst case O(N), since the Z-algorithm is linear.
-    //
     // Example:
     // Last N tokens: a b c c b c y a b c
     // Repeat counts: 0 0 3 1 0 2 0 0 0 0
+    //                    ^
+    //   This `3` means that the last three tokens of the context (a b c) also appear here.
+    //
+    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
+    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
+    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
+    // ensure that the inner while loops only examine each token in the context once as the outer
+    // for loop iterates over the context.

     {
         const int last = last_n_repeat - 1;
@@ -2119,11 +2131,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         }
         dry_sequence_breakers.clear();
         if(kcpp_params->dry_sequence_breakers.size()>0) {
+            // Restrict the maximum length of sequences used as sequence breakers. There are
+            // very few use cases for a long sequence breaker, and limiting the max length
+            // prevents a potential denial of service attack in which long repetitive sequence
+            // breakers could result in slow DRY sampling with a suitably crafted context.
+            const int MAX_CHAR_LEN = 60;
+            const int MAX_SEQ_LEN = 30;
+
             if(debugmode==1) {
                 printf("\nProcessing %zu dry break strings...",kcpp_params->dry_sequence_breakers.size());
             }
-            for (const auto& sequence_break: kcpp_params->dry_sequence_breakers) {
-                GetOverlappingTokenSequences(sequence_break, dry_sequence_breakers);
+            for (auto sequence_break: kcpp_params->dry_sequence_breakers) {
+                if (sequence_break.size() > MAX_CHAR_LEN) {
+                    sequence_break.resize(MAX_CHAR_LEN);
+                }
+                GetOverlappingTokenSequences(sequence_break, dry_sequence_breakers, MAX_SEQ_LEN);
             }
             if(debugmode==1) {
                 int trivial = 0, non_trivial = 0;
@@ -2659,10 +2681,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
                 kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n,
                 sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);

-            if (llama_ctx_v4) {
-                empcats_step_post(llama_ctx_v4, id );
-            }
-
             if (grammar != nullptr) {
                 grammar_accept_token(file_format, n_vocab, grammar, id);
             }
