-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
40 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,67 @@ | ||
const std = @import("std"); | ||
|
||
// struct used when sorting probabilities during top-p sampling | ||
pub const ProbIndex = struct { prob: f32, index: usize }; | ||
pub const ProbabilityIndexPair = struct { | ||
probability: f32, | ||
index: usize, | ||
}; | ||
|
||
// The Curious Case of Neural Text Degeneration (https://arxiv.org/abs/1904.09751) | ||
pub fn sampleNucleus( | ||
rng_value: f32, | ||
probabilities: []const f32, | ||
probability_distribution: []const f32, | ||
top_p: f32, | ||
prob_indices_buffer: []ProbIndex, | ||
probability_index_pairs_buffer: []ProbabilityIndexPair, | ||
) usize { | ||
@setFloatMode(.Optimized); | ||
|
||
// top-p sampling (or "nucleus sampling") samples from the smallest set of | ||
// tokens whose cumulative probability exceeds top_p. This way we never sample tokens that | ||
// have very low probabilities and are less likely to go "off the rails". | ||
std.debug.assert(probability_distribution.len > 0); | ||
|
||
// elements smaller than (1 - top_p) / (probabilities.len - 1) cannot be part of the result | ||
// and can be filtered out directly | ||
// https://github.com/karpathy/llama2.c/commit/d421a95b2bfe593b2d9e5c147f3efc8d128afe0e | ||
const cutoff: f32 = (1 - top_p) / @as(f32, @floatFromInt(probabilities.len - 1)); | ||
var probability_threshold: f32 = | ||
(1 - top_p) / @as(f32, @floatFromInt(probability_distribution.len - 1)); | ||
|
||
var n0: usize = 0; | ||
var n_probability_index_pairs: usize = 0; | ||
|
||
for (probabilities, 0..) |prob, index| { | ||
if (prob >= cutoff) { | ||
prob_indices_buffer[n0].prob = prob; | ||
prob_indices_buffer[n0].index = index; | ||
n0 += 1; | ||
for (probability_distribution, 0..) |probability, index| { | ||
if (probability_threshold < probability) { | ||
probability_index_pairs_buffer[n_probability_index_pairs].probability = probability; | ||
probability_index_pairs_buffer[n_probability_index_pairs].index = index; | ||
n_probability_index_pairs += 1; | ||
} | ||
} | ||
|
||
var filtered_prob_indices = prob_indices_buffer[0..n0]; | ||
var probability_index_pairs = probability_index_pairs_buffer[0..n_probability_index_pairs]; | ||
|
||
// sort indices in descending order of probabilities | ||
std.sort.block(ProbIndex, filtered_prob_indices, {}, lessThan); | ||
std.sort.block(ProbabilityIndexPair, probability_index_pairs, {}, lessThan); | ||
|
||
// truncate the list where cumulative probability exceeds top_p | ||
var cumulative_probability: f32 = 0; | ||
var truncated_prob_indices: ?[]ProbIndex = null; | ||
|
||
for (filtered_prob_indices, 0..) |prob_index, index| { | ||
cumulative_probability += prob_index.prob; | ||
for (probability_index_pairs, 0..) |probability_index_pair, index| { | ||
cumulative_probability += probability_index_pair.probability; | ||
|
||
if (cumulative_probability > top_p) { | ||
truncated_prob_indices = filtered_prob_indices[0..(index + 1)]; | ||
probability_index_pairs = probability_index_pairs[0..(index + 1)]; | ||
|
||
break; // we've exceeded topp by including index | ||
break; | ||
} | ||
} | ||
|
||
// sample from the truncated list | ||
const probability_threshold = rng_value * cumulative_probability; | ||
|
||
probability_threshold = rng_value * cumulative_probability; | ||
cumulative_probability = 0; | ||
|
||
if (truncated_prob_indices) |prob_indices| { | ||
for (prob_indices) |prob_index| { | ||
cumulative_probability += prob_index.prob; | ||
for (probability_index_pairs) |probability_index_pair| { | ||
cumulative_probability += probability_index_pair.probability; | ||
|
||
if (probability_threshold < cumulative_probability) { | ||
return prob_index.index; | ||
} | ||
if (probability_threshold < cumulative_probability) { | ||
return probability_index_pair.index; | ||
} | ||
} | ||
|
||
return filtered_prob_indices[filtered_prob_indices.len - 1].index; | ||
return probability_index_pairs[probability_index_pairs.len - 1].index; | ||
} | ||
|
||
fn lessThan(context: void, lhs: ProbIndex, rhs: ProbIndex) bool { | ||
fn lessThan(context: void, lhs: ProbabilityIndexPair, rhs: ProbabilityIndexPair) bool { | ||
_ = context; | ||
|
||
return rhs.prob < lhs.prob; | ||
return rhs.probability < lhs.probability; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters