diff --git a/src/matrix.zig b/src/matrix.zig
index 8d8a9b6..86e368f 100644
--- a/src/matrix.zig
+++ b/src/matrix.zig
@@ -36,7 +36,7 @@ pub fn readMultipleLeaky(
     return matrices;
 }
 
-const max_thread_count = 8;
+const max_thread_count = 24;
 
 pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void {
     if (self.thread_count == 0) {
@@ -45,27 +45,30 @@ pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void {
         return;
     }
 
-    const n_threads = @min(try std.Thread.getCpuCount(), max_thread_count, self.thread_count);
-
-    if (output.values.len % n_threads != 0) {
-        return error.UnsupportedThreadCount;
-    }
-
-    const partial_length = output.values.len / n_threads;
+    const n_threads = @min(max_thread_count, self.thread_count);
+    const thread_chunk_size = output.values.len / n_threads;
 
     var threads: [max_thread_count]std.Thread = undefined;
 
     for (threads[0..n_threads], 0..) |*thread, index| {
         thread.* = try std.Thread.spawn(.{}, computeMatrixVectorMultiplication, .{
-            self.rows[index * partial_length .. (index + 1) * partial_length],
+            self.rows[index * thread_chunk_size ..][0..thread_chunk_size],
             input,
-            output.values[index * partial_length .. (index + 1) * partial_length],
+            output.values[index * thread_chunk_size ..][0..thread_chunk_size],
         });
     }
 
     for (threads[0..n_threads]) |thread| {
         thread.join();
     }
+
+    if (output.values.len % n_threads > 0) {
+        try computeMatrixVectorMultiplication(
+            self.rows[n_threads * thread_chunk_size ..],
+            input,
+            output.values[n_threads * thread_chunk_size ..],
+        );
+    }
 }
 
 fn computeMatrixVectorMultiplication(