-
Notifications
You must be signed in to change notification settings - Fork 264
Expand file tree
/
Copy pathconstrained_logits_processor.cpp
More file actions
270 lines (245 loc) · 11.7 KB
/
constrained_logits_processor.cpp
File metadata and controls
270 lines (245 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <sys/types.h>
#include "generators.h"
#include "models/model.h"
#if USE_GUIDANCE
#include "llguidance.h"
#endif
#include "constrained_logits_processor.h"
namespace Generators {
#if USE_GUIDANCE
// Builds the llguidance-backed constrained decoding state for one generator:
// validates the guidance config, wires a tokenize callback into llguidance,
// creates one constraint per batch item, and starts the first mask computation
// asynchronously so it overlaps with model inference.
// Throws std::runtime_error on invalid config or llguidance initialization failure.
GuidanceLogitsProcessor::GuidanceLogitsProcessor(const State& state)
    : params_(state.params_),
      eos_token_(state.params_->config.model.eos_token_id[0]) {
  // BUG FIX: the original condition tested guidance_type.empty() twice and
  // never validated guidance_data, so a type without data slipped through.
  if (params_->guidance_type.empty() || params_->guidance_data.empty()) {
    throw std::runtime_error("Guidance type and data must be provided together");
  }
  if (params_->guidance_type != "json_schema" && params_->guidance_type != "regex" && params_->guidance_type != "lark_grammar") {
    throw std::runtime_error("Unsupported guidance type: " + std::string(params_->guidance_type) + " (only json_schema, regex and lark_grammar are supported)");
  }
  // Tokenize callback handed to llguidance. It writes at most output_tokens_len
  // ids into output_tokens and returns the TOTAL number of ids produced (which
  // may exceed output_tokens_len — llguidance uses the return value to size a
  // retry). The leading '+' forces the capture-less lambda to decay to a
  // function pointer before the cast to LlgTokenizeFn.
  auto tokenize_fn = (LlgTokenizeFn) + [](const void* user_data, const uint8_t* bytes,
                                          size_t bytes_len, uint32_t* output_tokens, size_t output_tokens_len)
      -> unsigned long {
    const TokenizeData* tokenize_data = reinterpret_cast<const TokenizeData*>(user_data);
    auto output_ids = tokenize_partial(reinterpret_cast<const Tokenizer*>(tokenize_data->tokenizer), tokenize_data->prefix_len, bytes, bytes_len);
    size_t output_size = std::min(output_tokens_len, output_ids.size());
    for (size_t i = 0; i < output_size; i++) {
      output_tokens[i] = output_ids[i];
    }
    return static_cast<unsigned long>(output_ids.size());
  };
  // Load the HF-style tokenizer.json next to the model config; llguidance
  // parses it to learn the vocabulary byte representations.
  auto tokenizer_path = state.params_->config.config_path.string();
  fs::path tokenizer_path_fs(tokenizer_path);
  fs::path json_path(tokenizer_path_fs / kDefaultVocabFile);
  std::ifstream json_file(json_path.string());
  std::stringstream json_buffer;
  json_buffer << json_file.rdbuf();
  std::string json_data = json_buffer.str();
  tokenizer_ = state.model_.CreateTokenizer();
  // Length of the stabilizing prefix in tokens; tokenize_partial strips this
  // many ids from the front of each partial tokenization.
  auto prefix_len = tokenizer_->Encode(kTokenizePrefixStr).size();
  // tokenize_data_ must outlive llg_tokenizer_ — it is a member for that reason.
  tokenize_data_ = {tokenizer_.get(), prefix_len};
  LlgTokenizerInit tokenizer_init = {
      static_cast<uint32_t>(params_->config.model.vocab_size),  // vocab_size
      eos_token_,                                               // eos_token
      nullptr,                                                  // token_lens
      nullptr,                                                  // token_bytes
      json_data.c_str(),                                        // tokenizer_json config data
      false,                                                    // tokenize_assumes_string
      tokenize_fn,                                              // tokenize_fn
      false,                                                    // use_approximate_greedy_tokenize_fn
      &tokenize_data_,                                          // user_data
  };
  char error_buf[256];
  llg_tokenizer_ = std::unique_ptr<LlgTokenizer, LlgTokenizerDeleter>(llg_new_tokenizer(&tokenizer_init, error_buf, sizeof(error_buf)));
  if (!llg_tokenizer_) {
    throw std::runtime_error("Error creating llg_tokenizer: " + std::string(error_buf));
  }
  // One independent constraint (grammar matcher) per batch item.
  llg_constraints_.resize(params_->search.batch_size);
  for (int i = 0; i < params_->search.batch_size; i++) {
    LlgConstraintInit constraint_init;
    llg_constraint_init_set_defaults(&constraint_init, llg_tokenizer_.get());
    // Fast-forward tokens are only safe with a single sequence; beams/batches
    // would diverge if llguidance forced tokens unilaterally.
    constraint_init.ff_tokens_ok = params_->guidance_ff_tokens_enabled && params_->search.batch_size == 1 && params_->search.num_beams == 1;
    LlgConstraint* constraint_ptr = nullptr;
    if (params_->guidance_type == "json_schema") {
      constraint_ptr = llg_new_constraint_json(&constraint_init, params_->guidance_data.data());
    } else if (params_->guidance_type == "regex") {
      constraint_ptr = llg_new_constraint_regex(&constraint_init, params_->guidance_data.data());
    } else if (params_->guidance_type == "lark_grammar") {
      constraint_ptr = llg_new_constraint_lark(&constraint_init, params_->guidance_data.data());
    }
    if (llg_get_error(constraint_ptr) != nullptr) {
      std::string error_message = llg_get_error(constraint_ptr);
      llg_free_constraint(constraint_ptr);  // don't leak the failed constraint
      throw std::runtime_error("Error creating grammar: " + error_message);
    }
    llg_constraints_[i] = std::unique_ptr<LlgConstraint, LlgConstraintDeleter>(constraint_ptr);
    // create ff_tokens buffer for each batch item
    ff_tokens_batch_.push_back(std::vector<int32_t>());
  }
  // Compute the mask asynchronously to avoid blocking the model inference on device
  mask_future_ = std::async(std::launch::async, [&]() {
    return ComputeMask();
  });
}
// Computes, for each batch item, a packed bitmask over the vocabulary of the
// tokens the grammar currently allows (bit set => token permitted). Each row
// has ceil(vocab_size / 32) 32-bit words. Runs on the async worker launched by
// the constructor / CommitTokens / Reset.
std::vector<std::vector<uint32_t>> GuidanceLogitsProcessor::ComputeMask() {
  std::vector<std::vector<uint32_t>> masks;
  for (int batch_idx = 0; batch_idx < params_->search.batch_size; batch_idx++) {
    LlgMaskResult mask_result;
    auto error = llg_compute_mask(llg_constraints_[batch_idx].get(), &mask_result);
    if (error != 0) {
      // If the mask computation fails, we need to reset the constraint
      // and try again. LLGuidance needs to be reset for every new prompt.
      // NOTE(review): ResetWithoutCompute() recreates the constraints for ALL
      // batch items, but only the current batch_idx is retried here — masks
      // already collected for earlier items came from the pre-reset
      // constraints. Confirm this is intended for batch_size > 1.
      ResetWithoutCompute();
      auto retry_error = llg_compute_mask(llg_constraints_[batch_idx].get(), &mask_result);
      if (retry_error != 0) {
        std::string error_message = llg_get_error(llg_constraints_[batch_idx].get());
        throw std::runtime_error("Error computing mask: " + error_message);
      }
    }
    std::vector<uint32_t> mask;
    if (mask_result.is_stop) {
      // when logits processor decides to stop, we mask all tokens except the EOS token
      mask = std::vector<uint32_t>((params_->config.model.vocab_size - 1) / 32 + 1, 0);
      uint32_t eos_mask32 = 1 << (eos_token_ % 32);
      mask[eos_token_ / 32] = eos_mask32;
    } else {
      // Copy llguidance's sample_mask word-by-word into our own buffer so the
      // result outlives mask_result.
      mask.reserve((params_->config.model.vocab_size - 1) / 32 + 1);
      for (int i = 0; i < (params_->config.model.vocab_size - 1) / 32 + 1; i++) {
        mask.push_back(mask_result.sample_mask[i]);
      }
    }
    masks.push_back(mask);
  }
  return masks;
}
// Hands back (and drains) the fast-forward tokens queued for one batch item.
// Returns an empty vector when the index is out of range, e.g. when guidance
// is not in use for this generator.
std::vector<int32_t> GuidanceLogitsProcessor::GetFFTokens(size_t index) {
  if (index >= ff_tokens_batch_.size()) {
    // in case guidance is not being used, return empty vector
    return {};
  }
  // Move the pending tokens out, leaving the slot empty for the next commit.
  std::vector<int32_t> pending;
  pending.swap(ff_tokens_batch_[index]);
  return pending;
}
// Advances every batch item's grammar state by its newly sampled token,
// queues any forced ("fast-forward") continuation tokens llguidance reports,
// and kicks off the next mask computation in the background.
// Throws std::runtime_error if llguidance rejects a token.
void GuidanceLogitsProcessor::CommitTokens(std::span<int32_t> tokens) {
  for (int batch_idx = 0; batch_idx < params_->search.batch_size; batch_idx++) {
    LlgCommitResult commit_result;
    if (llg_commit_token(llg_constraints_[batch_idx].get(), static_cast<uint32_t>(tokens[batch_idx]), &commit_result) != 0) {
      std::string error_message = llg_get_error(llg_constraints_[batch_idx].get());
      throw std::runtime_error("Error committing tokens: " + error_message);
    }
    // Entries at index >= 1 are forced tokens, stored here so the caller can
    // process them outside of this logits processor.
    auto& pending_ff = ff_tokens_batch_[batch_idx];
    pending_ff.clear();
    for (size_t tok_idx = 1; tok_idx < commit_result.n_tokens; tok_idx++) {
      pending_ff.push_back((int32_t)commit_result.tokens[tok_idx]);
    }
  }
  // Recompute the mask asynchronously so it overlaps with model inference,
  // and drop the cached masks from the previous step.
  mask_future_ = std::async(std::launch::async, [&]() {
    return ComputeMask();
  });
  masks_.clear();
}
// Returns the per-batch token masks for the current step, blocking on the
// async computation the first time and serving the cached copy afterwards.
std::vector<std::vector<uint32_t>> GuidanceLogitsProcessor::GetMask() {
  if (!masks_.empty()) {
    return masks_;
  }
  // First call this step: wait for the background worker and cache the result.
  masks_ = mask_future_.get();
  return masks_;
}
// Applies the grammar masks to the raw logits: disallowed tokens get
// std::numeric_limits<float>::lowest() so they cannot be sampled. Uses a
// device kernel on CUDA and a CPU loop otherwise.
void GuidanceLogitsProcessor::ProcessLogits(DeviceSpan<float> logits) {
  auto masks = GetMask();
  if (params_->p_device->GetType() == DeviceType::CUDA) {
    // Flatten the per-batch masks into one contiguous buffer for the kernel.
    // NOTE(review): words_per_row uses integer division vocab_size / 32, while
    // each mask row holds (vocab_size - 1) / 32 + 1 words; when vocab_size is
    // not a multiple of 32 the final partial word is not copied. Confirm
    // against LaunchAddLogitsMask's expected layout.
    const size_t words_per_row = params_->config.model.vocab_size / 32;
    const size_t total_words = masks.size() * words_per_row;
    std::vector<uint32_t> flat_masks(total_words);
    uint32_t* dst = flat_masks.data();
    for (const auto& row : masks) {
      std::memcpy(dst, row.data(), words_per_row * sizeof(uint32_t));
      dst += words_per_row;
    }
    // Stage on host, copy to device, then let the kernel add the mask in place.
    auto cuda_logits_mask_ptr_ = params_->p_device->Allocate<uint32_t>(total_words);
    copy(std::span<const uint32_t>{flat_masks}, cuda_logits_mask_ptr_.CpuSpan());
    cuda_logits_mask_ptr_.CopyCpuToDevice();
    params_->p_device->LaunchAddLogitsMask(logits.Span().data(), params_->search.batch_size, params_->config.model.vocab_size, cuda_logits_mask_ptr_.Span().data());
    return;
  }
  // CPU path: walk each batch item's vocab_size-long slice of the logits.
  size_t vocab_index = 0;
  auto logits_span = logits.CpuSpan();
  for (int index = 0; index < params_->search.batch_size; index++) {
    auto subspan = logits_span.subspan(vocab_index, params_->config.model.vocab_size);
    auto& mask = masks[index];
    for (size_t i = 0; i < params_->config.model.vocab_size; i++) {
      // mask is a 32-bit integer, where each bit corresponds to a token in the vocabulary.
      // If the bit is set, the corresponding token is allowed; otherwise its
      // logit is set to the lowest possible value so it is never sampled.
      subspan[i] = mask[i / 32] & (1 << (i % 32)) ? subspan[i] : std::numeric_limits<float>::lowest();
    }
    vocab_index += params_->config.model.vocab_size;
  }
}
// Discards cached masks and rebuilds every per-batch llguidance constraint
// from the stored guidance configuration, clearing any queued fast-forward
// tokens. Unlike Reset(), it does NOT launch a new mask computation — callers
// that need a mask afterwards must arrange for one themselves.
// Throws std::runtime_error on an unsupported type or grammar creation error.
void GuidanceLogitsProcessor::ResetWithoutCompute() {
  masks_.clear();
  llg_constraints_.clear();
  llg_constraints_.resize(params_->search.batch_size);
  for (int i = 0; i < params_->search.batch_size; i++) {
    LlgConstraintInit constraint_init;
    llg_constraint_init_set_defaults(&constraint_init, llg_tokenizer_.get());
    // Initialize to nullptr (the constructor does; the original here left it
    // uninitialized — harmless only because the else branch throws).
    LlgConstraint* constraint_ptr = nullptr;
    if (params_->guidance_type == "json_schema") {
      constraint_ptr = llg_new_constraint_json(&constraint_init, params_->guidance_data.data());
    } else if (params_->guidance_type == "regex") {
      constraint_ptr = llg_new_constraint_regex(&constraint_init, params_->guidance_data.data());
    } else if (params_->guidance_type == "lark_grammar") {
      constraint_ptr = llg_new_constraint_lark(&constraint_init, params_->guidance_data.data());
    } else {
      throw std::runtime_error("Unsupported guidance type: " + std::string(params_->guidance_type) + " (only json_schema, regex and lark_grammar are supported)");
    }
    if (llg_get_error(constraint_ptr) != nullptr) {
      std::string error_message = llg_get_error(constraint_ptr);
      llg_free_constraint(constraint_ptr);  // don't leak the failed constraint
      throw std::runtime_error("Error creating grammar: " + error_message);
    }
    llg_constraints_[i] = std::unique_ptr<LlgConstraint, LlgConstraintDeleter>(constraint_ptr);
  }
  // Drop fast-forward tokens queued before the reset (range-for replaces the
  // original signed int index compared against the unsigned size()).
  for (auto& ff_tokens : ff_tokens_batch_) {
    ff_tokens.clear();
  }
}
// Reset the masks and llguidance constraints and then recompute the mask
void GuidanceLogitsProcessor::Reset() {
ResetWithoutCompute();
mask_future_ = std::async(std::launch::async, [&]() {
return ComputeMask();
});
}
// Tokenizes an arbitrary (possibly mid-word) byte sequence for llguidance.
// A fixed prefix is prepended before encoding so the partial input tokenizes
// stably, and the prefix's prefix_len ids are stripped from the result.
// prefix_len must equal tokenizer->Encode(kTokenizePrefixStr).size().
std::vector<int32_t> GuidanceLogitsProcessor::tokenize_partial(const Tokenizer* tokenizer, const size_t prefix_len,
                                                               const uint8_t* bytes, size_t bytes_len) {
  // add prefix to tokenize for partial tokenization, it will produce ids more stable
  std::string input_string = kTokenizePrefixStr;
  // Reserve for the prefix already present plus the payload. The original
  // reserved only bytes_len + 2, a no-op whenever the prefix is longer than
  // that, forcing a reallocation during append.
  input_string.reserve(input_string.size() + bytes_len);
  // Bulk-append the payload bytes instead of pushing them one at a time.
  input_string.append(reinterpret_cast<const char*>(bytes), bytes_len);
  std::vector<int32_t> output_ids = tokenizer->Encode(input_string.c_str());
  // Drop the ids produced by the prefix, keeping only the caller's bytes.
  return std::vector<int32_t>(output_ids.begin() + prefix_len, output_ids.end());
}
#endif
// Factory for the guidance logits processor. Returns a processor when both a
// guidance type and its payload were supplied and the build has llguidance
// support; returns nullptr otherwise (logging a warning when guidance was
// requested but support is compiled out).
std::unique_ptr<ConstrainedLogitsProcessor> CreateGuidanceLogitsProcessor(const State& state) {
  const bool guidance_requested =
      !state.params_->guidance_type.empty() && !state.params_->guidance_data.empty();
  if (guidance_requested) {
#if USE_GUIDANCE
    return std::make_unique<GuidanceLogitsProcessor>(state);
#else
    if (g_log.enabled)
      Log("warning", "No supported ConstrainedLogitsProcessor found. To use guidance, build with use_guidance=true");
#endif
  }
  return nullptr;
}
} // namespace Generators