diff --git a/src/attention.zig b/src/attention.zig
index b16d480..578baf6 100644
--- a/src/attention.zig
+++ b/src/attention.zig
@@ -2,7 +2,6 @@ const std = @import("std");
 
 const checkpoint = @import("checkpoint.zig");
 const lib = @import("lib.zig");
-const utils = @import("utils.zig");
 
 pub const Attention = struct {
     const Self = @This();
@@ -129,11 +128,11 @@ pub const Attention = struct {
 
             // calculate the attention score as the dot product of q and k
            // save the score to the attention buffer
-            attention_weights[position] = lib.dotProduct(query, key) / head_size_sqrt;
+            attention_weights[position] = lib.dot(query, key) / head_size_sqrt;
         }
 
         // softmax the scores to get attention weights, from 0..pos inclusively
-        utils.softmax(attention_weights[0..(current_position + 1)]);
+        lib.softmax(attention_weights[0..(current_position + 1)]);
 
         // weighted sum of the values, store back into intermediate_buffer
         const intermediate_buffer = self.input_buffer[query_head_offset..][0..head_size];
diff --git a/src/feed_forward.zig b/src/feed_forward.zig
index 9bc190c..d0119f1 100644
--- a/src/feed_forward.zig
+++ b/src/feed_forward.zig
@@ -2,7 +2,6 @@ const std = @import("std");
 
 const checkpoint = @import("checkpoint.zig");
 const lib = @import("lib.zig");
-const utils = @import("utils.zig");
 
 pub const FeedForward = struct {
     const Self = @This();
@@ -31,8 +30,6 @@ pub const FeedForward = struct {
         weights: *const checkpoint.Weights,
         layer: usize,
     ) !void {
-        @setFloatMode(.Optimized);
-
        const dim = self.input_buffer.len;
        const hidden_dim = self.hidden_buffer.len;
 
diff --git a/src/lib.zig b/src/lib.zig
index 56badb0..8ebf922 100644
--- a/src/lib.zig
+++ b/src/lib.zig
@@ -1,9 +1,8 @@
-const linear_algebra = @import("lib/linear_algebra.zig");
-
-pub const add = linear_algebra.add;
-pub const dotProduct = linear_algebra.dotProduct;
-pub const matmul = linear_algebra.matmul;
-pub const matmul2 = linear_algebra.matmul2;
-pub const matmul3 = linear_algebra.matmul3;
+pub const add = @import("lib/add.zig").add;
+pub const dot = @import("lib/dot.zig").dot;
+pub const matmul = @import("lib/matmul.zig").matmul;
+pub const matmul2 = @import("lib/matmul.zig").matmul2;
+pub const matmul3 = @import("lib/matmul.zig").matmul3;
 pub const rmsnorm = @import("lib/rmsnorm.zig").rmsnorm;
 pub const rope = @import("lib/rope.zig").rope;
+pub const softmax = @import("lib/softmax.zig").softmax;
diff --git a/src/lib/add.zig b/src/lib/add.zig
new file mode 100644
index 0000000..56ce89e
--- /dev/null
+++ b/src/lib/add.zig
@@ -0,0 +1,11 @@
+const std = @import("std");
+
+pub fn add(a: []f32, b: []const f32) void {
+    @setFloatMode(.Optimized);
+
+    std.debug.assert(a.len == b.len);
+
+    for (a, 0..) |*element, index| {
+        element.* += b[index];
+    }
+}
diff --git a/src/lib/dot.zig b/src/lib/dot.zig
new file mode 100644
index 0000000..f99b706
--- /dev/null
+++ b/src/lib/dot.zig
@@ -0,0 +1,41 @@
+const std = @import("std");
+
+const max_vector_len: comptime_int = 16;
+const min_vector_len: comptime_int = 4;
+
+pub fn dot(a: []const f32, b: []const f32) f32 {
+    @setFloatMode(.Optimized);
+
+    std.debug.assert(a.len == b.len);
+
+    const rest_len = a.len % max_vector_len;
+
+    std.debug.assert(rest_len % min_vector_len == 0);
+
+    var max_len_accu: @Vector(max_vector_len, f32) = @splat(0.0);
+    var index: usize = 0;
+
+    while (index < a.len - rest_len) : (index += max_vector_len) {
+        max_len_accu +=
+            @as(@Vector(max_vector_len, f32), a[index..][0..max_vector_len].*) *
+            @as(@Vector(max_vector_len, f32), b[index..][0..max_vector_len].*);
+    }
+
+    var result = @reduce(.Add, max_len_accu);
+
+    if (rest_len > 0) {
+        var min_len_accu: @Vector(min_vector_len, f32) = @splat(0.0);
+
+        index = a.len - rest_len;
+
+        while (index < a.len) : (index += min_vector_len) {
+            min_len_accu +=
+                @as(@Vector(min_vector_len, f32), a[index..][0..min_vector_len].*) *
+                @as(@Vector(min_vector_len, f32), b[index..][0..min_vector_len].*);
+        }
+
+        result += @reduce(.Add, min_len_accu);
+    }
+
+    return result;
+}
diff --git a/src/lib/linear_algebra.zig b/src/lib/linear_algebra.zig
deleted file mode 100644
index bd59a19..0000000
--- a/src/lib/linear_algebra.zig
+++ /dev/null
@@ -1,92 +0,0 @@
-const std = @import("std");
-
-pub fn add(a: []f32, b: []const f32) void {
-    @setFloatMode(.Optimized);
-
-    std.debug.assert(a.len == b.len);
-
-    for (a, 0..) |*element, index| {
-        element.* += b[index];
-    }
-}
-
-const max_vector_len: comptime_int = 16;
-const min_vector_len: comptime_int = 4;
-
-pub fn dotProduct(a: []const f32, b: []const f32) f32 {
-    @setFloatMode(.Optimized);
-
-    std.debug.assert(a.len == b.len);
-
-    const rest_len = a.len % max_vector_len;
-
-    std.debug.assert(rest_len % min_vector_len == 0);
-
-    var max_len_accu: @Vector(max_vector_len, f32) = @splat(0.0);
-    var index: usize = 0;
-
-    while (index < a.len - rest_len) : (index += max_vector_len) {
-        max_len_accu +=
-            @as(@Vector(max_vector_len, f32), a[index..][0..max_vector_len].*) *
-            @as(@Vector(max_vector_len, f32), b[index..][0..max_vector_len].*);
-    }
-
-    var result = @reduce(.Add, max_len_accu);
-
-    if (rest_len > 0) {
-        var min_len_accu: @Vector(min_vector_len, f32) = @splat(0.0);
-
-        index = a.len - rest_len;
-
-        while (index < a.len) : (index += min_vector_len) {
-            min_len_accu +=
-                @as(@Vector(min_vector_len, f32), a[index..][0..min_vector_len].*) *
-                @as(@Vector(min_vector_len, f32), b[index..][0..min_vector_len].*);
-        }
-
-        result += @reduce(.Add, min_len_accu);
-    }
-
-    return result;
-}
-
-pub fn matmul(result: []f32, a: []const f32, b: []const f32) void {
-    std.debug.assert(b.len == result.len * a.len);
-
-    for (result, 0..) |*entry, i| {
-        entry.* = dotProduct(a, b[(i * a.len)..][0..a.len]);
-    }
-}
-
-pub fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void {
-    const cpu_count = std.Thread.getCpuCount() catch 1;
-
-    if (multi_threaded and cpu_count > 2) {
-        const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
-        const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);
-
-        thread_1.join();
-        thread_2.join();
-    } else {
-        @call(.auto, matmul, args_1);
-        @call(.auto, matmul, args_2);
-    }
-}
-
-pub fn matmul3(args_1: anytype, args_2: anytype, args_3: anytype, multi_threaded: bool) !void {
-    const cpu_count = std.Thread.getCpuCount() catch 1;
-
-    if (multi_threaded and cpu_count > 3) {
-        const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
-        const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);
-        const thread_3 = try std.Thread.spawn(.{}, matmul, args_3);
-
-        thread_1.join();
-        thread_2.join();
-        thread_3.join();
-    } else {
-        @call(.auto, matmul, args_1);
-        @call(.auto, matmul, args_2);
-        @call(.auto, matmul, args_3);
-    }
-}
diff --git a/src/lib/matmul.zig b/src/lib/matmul.zig
new file mode 100644
index 0000000..148908b
--- /dev/null
+++ b/src/lib/matmul.zig
@@ -0,0 +1,44 @@
+const std = @import("std");
+
+const dot = @import("dot.zig").dot;
+
+pub fn matmul(result: []f32, a: []const f32, b: []const f32) void {
+    std.debug.assert(b.len == result.len * a.len);
+
+    for (result, 0..) |*entry, i| {
+        entry.* = dot(a, b[(i * a.len)..][0..a.len]);
+    }
+}
+
+pub fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void {
+    const cpu_count = std.Thread.getCpuCount() catch 1;
+
+    if (multi_threaded and cpu_count > 2) {
+        const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
+        const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);
+
+        thread_1.join();
+        thread_2.join();
+    } else {
+        @call(.auto, matmul, args_1);
+        @call(.auto, matmul, args_2);
+    }
+}
+
+pub fn matmul3(args_1: anytype, args_2: anytype, args_3: anytype, multi_threaded: bool) !void {
+    const cpu_count = std.Thread.getCpuCount() catch 1;
+
+    if (multi_threaded and cpu_count > 3) {
+        const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
+        const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);
+        const thread_3 = try std.Thread.spawn(.{}, matmul, args_3);
+
+        thread_1.join();
+        thread_2.join();
+        thread_3.join();
+    } else {
+        @call(.auto, matmul, args_1);
+        @call(.auto, matmul, args_2);
+        @call(.auto, matmul, args_3);
+    }
+}
diff --git a/src/lib/softmax.zig b/src/lib/softmax.zig
new file mode 100644
index 0000000..bfaa540
--- /dev/null
+++ b/src/lib/softmax.zig
@@ -0,0 +1,17 @@
+const std = @import("std");
+
+pub fn softmax(vector: []f32) void {
+    @setFloatMode(.Optimized);
+
+    var max: f32 = std.mem.max(f32, vector);
+    var sum: f32 = 0;
+
+    for (vector) |*element| {
+        element.* = std.math.exp(element.* - max);
+        sum += element.*;
+    }
+
+    for (vector) |*element| {
+        element.* /= sum;
+    }
+}
diff --git a/src/main.zig b/src/main.zig
index 0e76e45..4ece617 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -2,6 +2,7 @@ const std = @import("std");
 
 const checkpoint = @import("checkpoint.zig");
 const cli = @import("cli.zig");
+const lib = @import("lib.zig");
 const tokenizer = @import("tokenizer.zig");
 const Transformer = @import("transformer.zig").Transformer;
 const utils = @import("utils.zig");
@@ -95,7 +96,7 @@ pub fn main() !void {
         }
 
         // apply softmax to the logits to get the probabilities for next token
-        utils.softmax(transformer.logits);
+        lib.softmax(transformer.logits);
 
         if (args.top_p <= 0 or args.top_p >= 1) {
             // we sample from this distribution to get the next token
diff --git a/src/utils.zig b/src/utils.zig
index 981f81b..af3cab2 100644
--- a/src/utils.zig
+++ b/src/utils.zig
@@ -1,8 +1,6 @@
 const std = @import("std");
 
 pub fn argmax(v: []f32) usize {
-    @setFloatMode(.Optimized);
-
     // return argmax of v in elements 0..n
     var max_i: usize = 0;
     var max_p: f32 = v[0];
@@ -18,8 +16,6 @@
 }
 
 pub fn sample(rng: *std.rand.DefaultPrng, probabilities: []f32) usize {
-    @setFloatMode(.Optimized);
-
     var r = rng.random().float(f32);
     var cdf: f32 = 0.0;
 
@@ -105,23 +101,3 @@
 fn lessThan(context: void, lhs: ProbIndex, rhs: ProbIndex) bool {
     return rhs.prob < lhs.prob;
 }
-
-pub fn softmax(x: []f32) void {
-    @setFloatMode(.Optimized);
-
-    var max_val = std.mem.max(f32, x);
-
-    // exp and sum
-    var sum: f32 = 0.0;
-
-    for (x) |*item| {
-        item.* = std.math.exp(item.* - max_val);
-
-        sum += item.*;
-    }
-
-    // normalize
-    for (x) |*item| {
-        item.* /= sum;
-    }
-}
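
For a quick sanity check of the refactored lib modules, tests along these lines could be added as src/lib/test.zig. This is a hypothetical sketch, not part of this diff; the file name and test names are illustrative. Note that dot asserts rest_len % min_vector_len == 0, so input lengths must be multiples of 4, which the test vectors respect:

const std = @import("std");

const add = @import("add.zig").add;
const dot = @import("dot.zig").dot;
const softmax = @import("softmax.zig").softmax;

test "dot matches a scalar reference" {
    // Length 8: rest_len = 8 % 16 = 8, and 8 % 4 == 0, so both asserts in dot pass.
    const a = [_]f32{ 1, 2, 3, 4, 5, 6, 7, 8 };
    const b = [_]f32{ 8, 7, 6, 5, 4, 3, 2, 1 };

    var expected: f32 = 0;

    for (a, b) |x, y| expected += x * y;

    try std.testing.expectApproxEqAbs(expected, dot(&a, &b), 1e-5);
}

test "add accumulates element-wise" {
    var a = [_]f32{ 1, 2, 3, 4 };
    const b = [_]f32{ 4, 3, 2, 1 };

    add(&a, &b);

    try std.testing.expectEqualSlices(f32, &[_]f32{ 5, 5, 5, 5 }, &a);
}

test "softmax yields a probability distribution" {
    var logits = [_]f32{ 1, 2, 3 };

    softmax(&logits);

    var sum: f32 = 0;

    for (logits) |p| sum += p;

    try std.testing.expectApproxEqAbs(@as(f32, 1), sum, 1e-5);
}

Running zig test src/lib/test.zig (with Zig 0.11, matching the for (a, 0..) capture syntax used in this diff) would exercise all three functions.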
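
matmul2 and matmul3 take anytype tuples rather than plain slices because each tuple is forwarded as-is to std.Thread.spawn (or to @call on the single-threaded path), so each tuple must mirror matmul's (result, a, b) parameter list. A minimal usage sketch under that assumption, with made-up 4x4 data in a file under src/ (out_1, out_2, input, weights_1, weights_2 are illustrative names, not from the repo):

const std = @import("std");

const lib = @import("lib.zig");

test "matmul2 runs two matrix-vector products, optionally on two threads" {
    const input = [_]f32{ 1, 2, 3, 4 };

    // Two 4x4 row-major matrices; matmul's assert b.len == result.len * a.len
    // holds, since 16 == 4 * 4.
    const weights_1 = [_]f32{0.5} ** 16;
    const weights_2 = [_]f32{0.25} ** 16;

    var out_1 = [_]f32{0} ** 4;
    var out_2 = [_]f32{0} ** 4;

    // Each tuple is invoked as matmul(result, a, b) on its own thread when
    // more than two CPUs are available; otherwise both run sequentially.
    try lib.matmul2(
        .{ &out_1, &input, &weights_1 },
        .{ &out_2, &input, &weights_2 },
        true,
    );

    // Every row of weights_1 is all 0.5, so each entry is 0.5 * (1+2+3+4) = 5.
    for (out_1) |entry| try std.testing.expectApproxEqAbs(@as(f32, 5), entry, 1e-5);
    for (out_2) |entry| try std.testing.expectApproxEqAbs(@as(f32, 2.5), entry, 1e-5);
}

Since matmul computes each output row with dot, the same multiple-of-4 length constraint on a applies here as well.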