Skip to content

Commit

Permalink
More files
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Aug 22, 2023
1 parent 93c90ce commit 77f91c5
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 130 deletions.
5 changes: 2 additions & 3 deletions src/attention.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ const std = @import("std");

const checkpoint = @import("checkpoint.zig");
const lib = @import("lib.zig");
const utils = @import("utils.zig");

pub const Attention = struct {
const Self = @This();
Expand Down Expand Up @@ -129,11 +128,11 @@ pub const Attention = struct {

// calculate the attention score as the dot product of q and k
// save the score to the attention buffer
attention_weights[position] = lib.dotProduct(query, key) / head_size_sqrt;
attention_weights[position] = lib.dot(query, key) / head_size_sqrt;
}

// softmax the scores to get attention weights, from 0..pos inclusively
utils.softmax(attention_weights[0..(current_position + 1)]);
lib.softmax(attention_weights[0..(current_position + 1)]);

// weighted sum of the values, store back into intermediate_buffer
const intermediate_buffer = self.input_buffer[query_head_offset..][0..head_size];
Expand Down
3 changes: 0 additions & 3 deletions src/feed_forward.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ const std = @import("std");

const checkpoint = @import("checkpoint.zig");
const lib = @import("lib.zig");
const utils = @import("utils.zig");

pub const FeedForward = struct {
const Self = @This();
Expand Down Expand Up @@ -31,8 +30,6 @@ pub const FeedForward = struct {
weights: *const checkpoint.Weights,
layer: usize,
) !void {
@setFloatMode(.Optimized);

const dim = self.input_buffer.len;
const hidden_dim = self.hidden_buffer.len;

Expand Down
13 changes: 6 additions & 7 deletions src/lib.zig
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
const linear_algebra = @import("lib/linear_algebra.zig");

pub const add = linear_algebra.add;
pub const dotProduct = linear_algebra.dotProduct;
pub const matmul = linear_algebra.matmul;
pub const matmul2 = linear_algebra.matmul2;
pub const matmul3 = linear_algebra.matmul3;
pub const add = @import("lib/add.zig").add;
pub const dot = @import("lib/dot.zig").dot;
pub const matmul = @import("lib/matmul.zig").matmul;
pub const matmul2 = @import("lib/matmul.zig").matmul2;
pub const matmul3 = @import("lib/matmul.zig").matmul3;
pub const rmsnorm = @import("lib/rmsnorm.zig").rmsnorm;
pub const rope = @import("lib/rope.zig").rope;
pub const softmax = @import("lib/softmax.zig").softmax;
11 changes: 11 additions & 0 deletions src/lib/add.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
const std = @import("std");

/// Adds `b` to `a` element-wise, mutating `a` in place.
/// Both slices must have the same length.
pub fn add(a: []f32, b: []const f32) void {
    @setFloatMode(.Optimized);

    std.debug.assert(a.len == b.len);

    // parallel iteration over both slices of equal length
    for (a, b) |*target, addend| {
        target.* += addend;
    }
}
41 changes: 41 additions & 0 deletions src/lib/dot.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
const std = @import("std");

const max_vector_len: comptime_int = 16;
const min_vector_len: comptime_int = 4;

/// Computes the dot product of two equal-length f32 slices.
///
/// SIMD strategy: 16-wide vector chunks over the bulk of the data, then
/// 4-wide vector chunks, then a scalar tail — so any slice length is
/// supported. (The previous version asserted that `len % 16` was a
/// multiple of 4 and aborted otherwise; for lengths that were already
/// valid, the summation order is unchanged.)
pub fn dot(a: []const f32, b: []const f32) f32 {
    @setFloatMode(.Optimized);

    std.debug.assert(a.len == b.len);

    var index: usize = 0;
    var result: f32 = 0.0;

    // 16-wide SIMD over the largest prefix that is a multiple of 16
    const max_chunk_end = a.len - a.len % max_vector_len;

    if (index < max_chunk_end) {
        var max_len_accu: @Vector(max_vector_len, f32) = @splat(0.0);

        while (index < max_chunk_end) : (index += max_vector_len) {
            max_len_accu +=
                @as(@Vector(max_vector_len, f32), a[index..][0..max_vector_len].*) *
                @as(@Vector(max_vector_len, f32), b[index..][0..max_vector_len].*);
        }

        result += @reduce(.Add, max_len_accu);
    }

    // 4-wide SIMD over the remainder that is a multiple of 4
    const min_chunk_end = a.len - a.len % min_vector_len;

    if (index < min_chunk_end) {
        var min_len_accu: @Vector(min_vector_len, f32) = @splat(0.0);

        while (index < min_chunk_end) : (index += min_vector_len) {
            min_len_accu +=
                @as(@Vector(min_vector_len, f32), a[index..][0..min_vector_len].*) *
                @as(@Vector(min_vector_len, f32), b[index..][0..min_vector_len].*);
        }

        result += @reduce(.Add, min_len_accu);
    }

    // scalar tail for the last `a.len % 4` elements
    while (index < a.len) : (index += 1) {
        result += a[index] * b[index];
    }

    return result;
}
92 changes: 0 additions & 92 deletions src/lib/linear_algebra.zig

This file was deleted.

44 changes: 44 additions & 0 deletions src/lib/matmul.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
const std = @import("std");

const dot = @import("dot.zig").dot;

/// Matrix-vector product: result[i] = dot(a, row i of b).
/// b is row-major with result.len rows of a.len elements each.
pub fn matmul(result: []f32, a: []const f32, b: []const f32) void {
    std.debug.assert(b.len == result.len * a.len);

    var row_offset: usize = 0;

    for (result) |*entry| {
        entry.* = dot(a, b[row_offset..][0..a.len]);
        row_offset += a.len;
    }
}

/// Runs two matmul calls, concurrently on separate threads when
/// `multi_threaded` is set and more than two CPUs are available,
/// otherwise sequentially on the calling thread.
///
/// Fix: if spawning the second thread fails, the first thread is now
/// joined before the error propagates (previously it was leaked).
pub fn matmul2(args_1: anytype, args_2: anytype, multi_threaded: bool) !void {
    const cpu_count = std.Thread.getCpuCount() catch 1;

    if (multi_threaded and cpu_count > 2) {
        const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
        // don't leak thread_1 if the second spawn fails
        errdefer thread_1.join();

        const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);

        thread_1.join();
        thread_2.join();
    } else {
        @call(.auto, matmul, args_1);
        @call(.auto, matmul, args_2);
    }
}

/// Runs three matmul calls, concurrently on separate threads when
/// `multi_threaded` is set and more than three CPUs are available,
/// otherwise sequentially on the calling thread.
///
/// Fix: if a later spawn fails, the already-spawned threads are now
/// joined before the error propagates (previously they were leaked).
pub fn matmul3(args_1: anytype, args_2: anytype, args_3: anytype, multi_threaded: bool) !void {
    const cpu_count = std.Thread.getCpuCount() catch 1;

    if (multi_threaded and cpu_count > 3) {
        const thread_1 = try std.Thread.spawn(.{}, matmul, args_1);
        // don't leak thread_1 if a later spawn fails
        errdefer thread_1.join();

        const thread_2 = try std.Thread.spawn(.{}, matmul, args_2);
        // don't leak thread_2 if the third spawn fails
        errdefer thread_2.join();

        const thread_3 = try std.Thread.spawn(.{}, matmul, args_3);

        thread_1.join();
        thread_2.join();
        thread_3.join();
    } else {
        @call(.auto, matmul, args_1);
        @call(.auto, matmul, args_2);
        @call(.auto, matmul, args_3);
    }
}
17 changes: 17 additions & 0 deletions src/lib/softmax.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
const std = @import("std");

/// Applies softmax to `vector` in place: each element becomes
/// exp(x - max) / sum(exp(x - max)). Subtracting the maximum first
/// keeps the exponentials from overflowing.
pub fn softmax(vector: []f32) void {
    @setFloatMode(.Optimized);

    const max_value = std.mem.max(f32, vector);
    var total: f32 = 0;

    // exponentiate (shifted by the max) and accumulate the sum
    for (vector) |*entry| {
        const exponentiated = std.math.exp(entry.* - max_value);

        entry.* = exponentiated;
        total += exponentiated;
    }

    // normalize so the elements sum to 1
    for (vector) |*entry| {
        entry.* /= total;
    }
}
3 changes: 2 additions & 1 deletion src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ const std = @import("std");

const checkpoint = @import("checkpoint.zig");
const cli = @import("cli.zig");
const lib = @import("lib.zig");
const tokenizer = @import("tokenizer.zig");
const Transformer = @import("transformer.zig").Transformer;
const utils = @import("utils.zig");
Expand Down Expand Up @@ -95,7 +96,7 @@ pub fn main() !void {
}

// apply softmax to the logits to get the probabilities for next token
utils.softmax(transformer.logits);
lib.softmax(transformer.logits);

if (args.top_p <= 0 or args.top_p >= 1) {
// we sample from this distribution to get the next token
Expand Down
24 changes: 0 additions & 24 deletions src/utils.zig
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
const std = @import("std");

pub fn argmax(v: []f32) usize {
@setFloatMode(.Optimized);

// return argmax of v in elements 0..n
var max_i: usize = 0;
var max_p: f32 = v[0];
Expand All @@ -18,8 +16,6 @@ pub fn argmax(v: []f32) usize {
}

pub fn sample(rng: *std.rand.DefaultPrng, probabilities: []f32) usize {
@setFloatMode(.Optimized);

var r = rng.random().float(f32);
var cdf: f32 = 0.0;

Expand Down Expand Up @@ -105,23 +101,3 @@ fn lessThan(context: void, lhs: ProbIndex, rhs: ProbIndex) bool {

return rhs.prob < lhs.prob;
}

/// In-place softmax over `x`: exponentiates each element relative to
/// the slice maximum (for numerical stability), then normalizes so the
/// elements sum to 1.
pub fn softmax(x: []f32) void {
    @setFloatMode(.Optimized);

    const max_val = std.mem.max(f32, x);
    var sum: f32 = 0.0;

    // shift by the maximum, exponentiate, and accumulate the total
    for (x) |*item| {
        const shifted = std.math.exp(item.* - max_val);

        item.* = shifted;
        sum += shifted;
    }

    // normalize into a probability distribution
    for (x) |*item| {
        item.* /= sum;
    }
}

0 comments on commit 77f91c5

Please sign in to comment.