Refactor nucleus sampling
clebert committed Aug 23, 2023
1 parent 7dfa4e9 commit 75f8105
Showing 3 changed files with 40 additions and 40 deletions.
src/lib.zig (2 changes: 1 addition & 1 deletion)
@@ -8,6 +8,6 @@ pub const random = @import("lib/random.zig").random;
 pub const rmsnorm = @import("lib/rmsnorm.zig").rmsnorm;
 pub const rope = @import("lib/rope.zig").rope;
 pub const sampleMultinomial = @import("lib/sample_multinomial.zig").sampleMultinomial;
-pub const ProbIndex = @import("lib/sample_nucleus.zig").ProbIndex;
+pub const ProbabilityIndexPair = @import("lib/sample_nucleus.zig").ProbabilityIndexPair;
 pub const sampleNucleus = @import("lib/sample_nucleus.zig").sampleNucleus;
 pub const softmax = @import("lib/softmax.zig").softmax;
src/lib/sample_nucleus.zig (66 changes: 29 additions & 37 deletions)
@@ -1,75 +1,67 @@
 const std = @import("std");

-// struct used when sorting probabilities during top-p sampling
-pub const ProbIndex = struct { prob: f32, index: usize };
+pub const ProbabilityIndexPair = struct {
+    probability: f32,
+    index: usize,
+};

-// The Curious Case of Neural Text Degeneration (https://arxiv.org/abs/1904.09751)
 pub fn sampleNucleus(
     rng_value: f32,
-    probabilities: []const f32,
+    probability_distribution: []const f32,
     top_p: f32,
-    prob_indices_buffer: []ProbIndex,
+    probability_index_pairs_buffer: []ProbabilityIndexPair,
 ) usize {
     @setFloatMode(.Optimized);

-    // top-p sampling (or "nucleus sampling") samples from the smallest set of
-    // tokens that exceed probability topp. This way we never sample tokens that
-    // have very low probabilities and are less likely to go "off the rails".
+    std.debug.assert(probability_distribution.len > 0);

-    // elements smaller than (1 - top_p) / (probabilities.len - 1) cannot be part of the result
-    // and can be filtered out directly
-    // https://github.com/karpathy/llama2.c/commit/d421a95b2bfe593b2d9e5c147f3efc8d128afe0e
-    const cutoff: f32 = (1 - top_p) / @as(f32, @floatFromInt(probabilities.len - 1));
+    var probability_threshold: f32 =
+        (1 - top_p) / @as(f32, @floatFromInt(probability_distribution.len - 1));

-    var n0: usize = 0;
+    var n_probability_index_pairs: usize = 0;

-    for (probabilities, 0..) |prob, index| {
-        if (prob >= cutoff) {
-            prob_indices_buffer[n0].prob = prob;
-            prob_indices_buffer[n0].index = index;
-            n0 += 1;
+    for (probability_distribution, 0..) |probability, index| {
+        if (probability_threshold < probability) {
+            probability_index_pairs_buffer[n_probability_index_pairs].probability = probability;
+            probability_index_pairs_buffer[n_probability_index_pairs].index = index;
+            n_probability_index_pairs += 1;
         }
     }

-    var filtered_prob_indices = prob_indices_buffer[0..n0];
+    var probability_index_pairs = probability_index_pairs_buffer[0..n_probability_index_pairs];

-    // sort indices in descending order of probabilities
-    std.sort.block(ProbIndex, filtered_prob_indices, {}, lessThan);
+    std.sort.block(ProbabilityIndexPair, probability_index_pairs, {}, lessThan);

-    // truncate the list where cumulative probability exceeds topp
     var cumulative_probability: f32 = 0;
-    var truncated_prob_indices: ?[]ProbIndex = null;

-    for (filtered_prob_indices, 0..) |prob_index, index| {
-        cumulative_probability += prob_index.prob;
+    for (probability_index_pairs, 0..) |probability_index_pair, index| {
+        cumulative_probability += probability_index_pair.probability;

         if (cumulative_probability > top_p) {
-            truncated_prob_indices = filtered_prob_indices[0..(index + 1)];
+            probability_index_pairs = probability_index_pairs[0..(index + 1)];

-            break; // we've exceeded topp by including index
+            break;
         }
     }

-    // sample from the truncated list
-    const probability_threshold = rng_value * cumulative_probability;
-
+    probability_threshold = rng_value * cumulative_probability;
     cumulative_probability = 0;

-    if (truncated_prob_indices) |prob_indices| {
-        for (prob_indices) |prob_index| {
-            cumulative_probability += prob_index.prob;
+    for (probability_index_pairs) |probability_index_pair| {
+        cumulative_probability += probability_index_pair.probability;

-            if (probability_threshold < cumulative_probability) {
-                return prob_index.index;
-            }
+        if (probability_threshold < cumulative_probability) {
+            return probability_index_pair.index;
         }
     }

-    return filtered_prob_indices[filtered_prob_indices.len - 1].index;
+    return probability_index_pairs[probability_index_pairs.len - 1].index;
 }

-fn lessThan(context: void, lhs: ProbIndex, rhs: ProbIndex) bool {
+fn lessThan(context: void, lhs: ProbabilityIndexPair, rhs: ProbabilityIndexPair) bool {
     _ = context;

-    return rhs.prob < lhs.prob;
+    return rhs.probability < lhs.probability;
 }
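
For reference, a minimal usage sketch (not part of this commit) of the refactored sampleNucleus, assuming Zig 0.11 and a test file placed next to src/lib.zig; the toy distribution, the expected index, and the test name are illustrative assumptions only:

const std = @import("std");
const lib = @import("lib.zig");

test "sampleNucleus samples from the high-probability head" {
    // Toy distribution over 4 tokens (sums to 1).
    const probability_distribution = [_]f32{ 0.1, 0.2, 0.6, 0.1 };

    // The caller provides scratch space with room for one pair per token.
    var probability_index_pairs_buffer: [4]lib.ProbabilityIndexPair = undefined;

    // Pre-filter threshold: (1 - top_p) / (len - 1) = (1 - 0.5) / 3, about 0.167,
    // so tokens 0 and 3 can never enter the nucleus. With top_p = 0.5 the nucleus
    // is just token 2 (probability 0.6), so any rng_value in [0, 1) returns index 2.
    const token = lib.sampleNucleus(
        0.5, // rng_value
        &probability_distribution,
        0.5, // top_p
        &probability_index_pairs_buffer,
    );

    try std.testing.expectEqual(@as(usize, 2), token);
}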
src/main.zig (12 changes: 10 additions & 2 deletions)
@@ -59,7 +59,10 @@ pub fn main() !void {
     var token: usize = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
     var next: usize = 1; // TODO
     var rng_state = args.random_seed;
-    var prob_indices: []lib.ProbIndex = try allocator.alloc(lib.ProbIndex, config.vocab_size);
+
+    var probability_index_pairs_buffer: []lib.ProbabilityIndexPair =
+        try allocator.alloc(lib.ProbabilityIndexPair, config.vocab_size);
+
     var n_steps: usize = 0;

     var start_time: i64 = 0;
@@ -102,7 +105,12 @@ pub fn main() !void {
             next = lib.sampleMultinomial(lib.random(&rng_state), transformer.logits);
         } else {
             // top-p (nucleus) sampling, clamping the least likely tokens to zero
-            next = lib.sampleNucleus(lib.random(&rng_state), transformer.logits, args.top_p, prob_indices);
+            next = lib.sampleNucleus(
+                lib.random(&rng_state),
+                transformer.logits,
+                args.top_p,
+                probability_index_pairs_buffer,
+            );
         }
     }

