Refactor attention
clebert committed Aug 23, 2023
1 parent 388f6c5 commit 684c140
Showing 1 changed file with 11 additions and 10 deletions.

src/attention.zig
@@ -6,11 +6,11 @@ const lib = @import("lib.zig");
 pub const Attention = struct {
     const Self = @This();
 
-    n_heads: usize,
-    seq_len: usize,
-    n_groups: usize,
     head_size: usize,
     head_size_sqrt: f32,
+    n_groups: usize,
+    n_heads: usize,
+    seq_len: usize,
 
     input_buffer: []f32,
     output_buffer: []f32,
@@ -22,13 +22,14 @@ pub const Attention = struct {
     value_cache: []f32,
 
     pub fn init(self: *Self, allocator: std.mem.Allocator, config: *const checkpoint.Config) !void {
-        const kv_dim = (config.dim * config.n_kv_heads) / config.n_heads;
-
-        self.n_heads = config.n_heads;
-        self.seq_len = config.seq_len;
-        self.n_groups = config.n_heads / config.n_kv_heads;
         self.head_size = config.dim / config.n_heads;
         self.head_size_sqrt = std.math.sqrt(@as(f32, @floatFromInt(self.head_size)));
+        self.n_groups = config.n_heads / config.n_kv_heads;
+        self.n_heads = config.n_heads;
+        self.seq_len = config.seq_len;
+
+        const kv_dim = (config.dim * config.n_kv_heads) / config.n_heads;
+
         self.input_buffer = try allocator.alloc(f32, config.dim);
         self.output_buffer = try allocator.alloc(f32, config.dim);
         self.scores_buffer = try allocator.alloc(f32, config.n_heads * config.seq_len);
@@ -96,7 +97,7 @@ pub const Attention = struct {
         );
 
         for (0..self.n_heads) |head| {
-            self.compute_attention(pos, head, kv_cache_layer_offset);
+            self.compute_weighted_values(pos, head, kv_cache_layer_offset);
         }
 
         lib.matmul(
@@ -106,7 +107,7 @@ pub const Attention = struct {
         );
     }
 
-    fn compute_attention(
+    fn compute_weighted_values(
         self: *const Self,
         pos: usize,
         head: usize,
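
For context on the quantities the reordered init computes: n_groups and kv_dim are the usual grouped-query attention bookkeeping, where several query heads share one key/value head, so the key/value projections (and the KV cache) are narrower than dim. Below is a minimal sketch of that arithmetic in Zig, using hypothetical config values (dim = 4096, n_heads = 32, n_kv_heads = 8 — chosen for illustration, not taken from this repository):

const std = @import("std");

pub fn main() void {
    // Hypothetical model configuration, for illustration only.
    const dim: usize = 4096;
    const n_heads: usize = 32;
    const n_kv_heads: usize = 8;

    // Each key/value head is shared by a group of query heads.
    const n_groups = n_heads / n_kv_heads; // 4 query heads per KV head
    const head_size = dim / n_heads; // 128

    // Attention scores are scaled by sqrt(head_size), as in init above.
    const head_size_sqrt = std.math.sqrt(@as(f32, @floatFromInt(head_size)));

    // The KV projections are narrower than dim when n_kv_heads < n_heads,
    // which shrinks the key/value cache by the same factor.
    const kv_dim = (dim * n_kv_heads) / n_heads; // 1024 == n_kv_heads * head_size

    std.debug.print("n_groups={d} head_size={d} kv_dim={d} scale={d:.2}\n", .{
        n_groups, head_size, kv_dim, head_size_sqrt,
    });
}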
Expand Down