Commit d5585bc

Implement threaded BLAKE3 (#25587)
Allows BLAKE3 to be computed using multiple threads.
1 parent 5a38dd2 commit d5585bc
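
For context, the new Blake3.hashParallel entry point takes the input and output buffers and an Options struct like Blake3.hash, plus an allocator (for chaining-value scratch buffers) and a std.Io instance used to schedule worker tasks. Below is a minimal usage sketch assembled from the benchmark code in this commit; the page_allocator choice, the 16 MiB buffer size, and the hex printing are illustrative assumptions, not part of the change itself.

// Usage sketch (assumptions noted above; not part of the commit).
const std = @import("std");
const Blake3 = std.crypto.hash.Blake3;

pub fn main() !void {
    const allocator = std.heap.page_allocator;

    // hashParallel schedules its workers through a std.Io instance;
    // std.Io.Threaded is the thread-pool-backed implementation used
    // by the benchmark in this commit.
    var io_threaded = std.Io.Threaded.init(allocator);
    defer io_threaded.deinit();
    const io = io_threaded.io();

    // Inputs below the 3 MiB parallel_threshold fall back to the
    // sequential hash, so use a large buffer to exercise the threads.
    const data = try allocator.alloc(u8, 16 * 1024 * 1024);
    defer allocator.free(data);
    @memset(data, 0xAB);

    var digest: [Blake3.digest_length]u8 = undefined;
    try Blake3.hashParallel(data, &digest, .{}, allocator, io);

    const hex = std.fmt.bytesToHex(digest, .lower);
    std.debug.print("blake3: {s}\n", .{&hex});
}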

File tree

2 files changed: +263 -3 lines changed

lib/std/crypto/benchmark.zig
lib/std/crypto/blake3.zig

lib/std/crypto/benchmark.zig

Lines changed: 35 additions & 0 deletions
@@ -35,6 +35,10 @@ const hashes = [_]Crypto{
     Crypto{ .ty = crypto.hash.Blake3, .name = "blake3" },
 };
 
+const parallel_hashes = [_]Crypto{
+    Crypto{ .ty = crypto.hash.Blake3, .name = "blake3-parallel" },
+};
+
 const block_size: usize = 8 * 8192;
 
 pub fn benchmarkHash(comptime Hash: anytype, comptime bytes: comptime_int) !u64 {
@@ -61,6 +65,25 @@ pub fn benchmarkHash(comptime Hash: anytype, comptime bytes: comptime_int) !u64
     return throughput;
 }
 
+pub fn benchmarkHashParallel(comptime Hash: anytype, comptime bytes: comptime_int, allocator: mem.Allocator, io: std.Io) !u64 {
+    const data: []u8 = try allocator.alloc(u8, bytes);
+    defer allocator.free(data);
+    random.bytes(data);
+
+    var timer = try Timer.start();
+    const start = timer.lap();
+    var final: [Hash.digest_length]u8 = undefined;
+    try Hash.hashParallel(data, &final, .{}, allocator, io);
+    std.mem.doNotOptimizeAway(final);
+
+    const end = timer.read();
+
+    const elapsed_s = @as(f64, @floatFromInt(end - start)) / time.ns_per_s;
+    const throughput = @as(u64, @intFromFloat(bytes / elapsed_s));
+
+    return throughput;
+}
+
 const macs = [_]Crypto{
     Crypto{ .ty = crypto.onetimeauth.Ghash, .name = "ghash" },
     Crypto{ .ty = crypto.onetimeauth.Polyval, .name = "polyval" },
@@ -512,6 +535,18 @@ pub fn main() !void {
         }
     }
 
+    var io_threaded = std.Io.Threaded.init(arena_allocator);
+    defer io_threaded.deinit();
+    const io = io_threaded.io();
+
+    inline for (parallel_hashes) |H| {
+        if (filter == null or std.mem.indexOf(u8, H.name, filter.?) != null) {
+            const throughput = try benchmarkHashParallel(H.ty, mode(128 * MiB), arena_allocator, io);
+            try stdout.print("{s:>17}: {:10} MiB/s\n", .{ H.name, throughput / (1 * MiB) });
+            try stdout.flush();
+        }
+    }
+
     inline for (macs) |M| {
         if (filter == null or std.mem.indexOf(u8, M.name, filter.?) != null) {
             const throughput = try benchmarkMac(M.ty, mode(128 * MiB));

lib/std/crypto/blake3.zig

Lines changed: 228 additions & 3 deletions
@@ -2,6 +2,8 @@ const std = @import("std");
 const builtin = @import("builtin");
 const fmt = std.fmt;
 const mem = std.mem;
+const Io = std.Io;
+const Thread = std.Thread;
 
 const Vec4 = @Vector(4, u32);
 const Vec8 = @Vector(8, u32);
@@ -14,6 +16,11 @@ pub const simd_degree = std.simd.suggestVectorLength(u32) orelse 1;
 pub const max_simd_degree = simd_degree;
 const max_simd_degree_or_2 = if (max_simd_degree > 2) max_simd_degree else 2;
 
+/// Threshold for switching to parallel processing.
+/// Below this size, sequential hashing is used.
+/// Benchmarks generally show significant speedup starting at 3 MiB.
+const parallel_threshold = 3 * 1024 * 1024;
+
 const iv: [8]u32 = .{
     0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A,
     0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
@@ -666,6 +673,95 @@ fn leftSubtreeLen(input_len: usize) usize {
     return @intCast(roundDownToPowerOf2(full_chunks) * chunk_length);
 }
 
+const ChunkBatch = struct {
+    input: []const u8,
+    start_chunk: usize,
+    end_chunk: usize,
+    cvs: [][8]u32,
+    key: [8]u32,
+    flags: Flags,
+
+    fn process(ctx: ChunkBatch) void {
+        var cv_buffer: [max_simd_degree * Blake3.digest_length]u8 = undefined;
+        var chunk_idx = ctx.start_chunk;
+
+        while (chunk_idx < ctx.end_chunk) {
+            const remaining = ctx.end_chunk - chunk_idx;
+            const batch_size = @min(remaining, max_simd_degree);
+            const offset = chunk_idx * chunk_length;
+            const batch_len = @as(usize, batch_size) * chunk_length;
+
+            const num_cvs = compressChunksParallel(
+                ctx.input[offset..][0..batch_len],
+                ctx.key,
+                chunk_idx,
+                ctx.flags,
+                &cv_buffer,
+            );
+
+            for (0..num_cvs) |i| {
+                const cv_bytes = cv_buffer[i * Blake3.digest_length ..][0..Blake3.digest_length];
+                ctx.cvs[chunk_idx + i] = loadCvWords(cv_bytes.*);
+            }
+
+            chunk_idx += batch_size;
+        }
+    }
+};
+
+const ParentBatchContext = struct {
+    input_cvs: [][8]u32,
+    output_cvs: [][8]u32,
+    start_idx: usize,
+    end_idx: usize,
+    key: [8]u32,
+    flags: Flags,
+};
+
+fn processParentBatch(ctx: ParentBatchContext) void {
+    for (ctx.start_idx..ctx.end_idx) |i| {
+        const output = parentOutputFromCvs(ctx.input_cvs[i * 2], ctx.input_cvs[i * 2 + 1], ctx.key, ctx.flags);
+        ctx.output_cvs[i] = output.chainingValue();
+    }
+}
+
+fn buildMerkleTreeLayerParallel(
+    input_cvs: [][8]u32,
+    output_cvs: [][8]u32,
+    key: [8]u32,
+    flags: Flags,
+    io: Io,
+) void {
+    const num_parents = input_cvs.len / 2;
+
+    if (num_parents <= 16) {
+        for (0..num_parents) |i| {
+            const output = parentOutputFromCvs(input_cvs[i * 2], input_cvs[i * 2 + 1], key, flags);
+            output_cvs[i] = output.chainingValue();
+        }
+        return;
+    }
+
+    const num_workers = Thread.getCpuCount() catch 1;
+    const parents_per_worker = (num_parents + num_workers - 1) / num_workers;
+    var group: Io.Group = .init;
+
+    for (0..num_workers) |worker_id| {
+        const start_idx = worker_id * parents_per_worker;
+        if (start_idx >= num_parents) break;
+
+        group.async(io, processParentBatch, .{ParentBatchContext{
+            .input_cvs = input_cvs,
+            .output_cvs = output_cvs,
+            .start_idx = start_idx,
+            .end_idx = @min(start_idx + parents_per_worker, num_parents),
+            .key = key,
+            .flags = flags,
+        }});
+    }
+    group.wait(io);
+}
+
 fn parentOutput(parent_block: []const u8, key: [8]u32, flags: Flags) Output {
     var block: [Blake3.block_length]u8 = undefined;
     @memcpy(&block, parent_block[0..Blake3.block_length]);
@@ -705,7 +801,7 @@ const ChunkState = struct {
         return ChunkState{
             .cv = key,
             .chunk_counter = 0,
-            .buf = [_]u8{0} ** Blake3.block_length,
+            .buf = @splat(0),
            .buf_len = 0,
            .blocks_compressed = 0,
            .flags = flags,
@@ -716,7 +812,7 @@
         self.cv = key;
         self.chunk_counter = chunk_counter;
         self.blocks_compressed = 0;
-        self.buf = [_]u8{0} ** Blake3.block_length;
+        self.buf = @splat(0);
         self.buf_len = 0;
     }
 
@@ -742,7 +838,7 @@
             if (self.buf_len == Blake3.block_length) {
                 compressInPlace(&self.cv, &self.buf, Blake3.block_length, self.chunk_counter, self.flags.with(self.maybeStartFlag()));
                 self.blocks_compressed += 1;
-                self.buf = [_]u8{0} ** Blake3.block_length;
+                self.buf = @splat(0);
                 self.buf_len = 0;
             }
 
@@ -849,6 +945,90 @@ pub const Blake3 = struct {
         d.final(out);
     }
 
+    pub fn hashParallel(b: []const u8, out: []u8, options: Options, allocator: std.mem.Allocator, io: Io) !void {
+        if (b.len < parallel_threshold) {
+            return hash(b, out, options);
+        }
+
+        const key_words = if (options.key) |key| loadKeyWords(key) else iv;
+        const flags: Flags = if (options.key != null) .{ .keyed_hash = true } else .{};
+
+        const num_full_chunks = b.len / chunk_length;
+        const thread_count = Thread.getCpuCount() catch 1;
+        if (thread_count <= 1 or num_full_chunks == 0) {
+            return hash(b, out, options);
+        }
+
+        const cvs = try allocator.alloc([8]u32, num_full_chunks);
+        defer allocator.free(cvs);
+
+        // Process chunks in parallel
+        const num_workers = thread_count;
+        const chunks_per_worker = (num_full_chunks + num_workers - 1) / num_workers;
+        var group: Io.Group = .init;
+
+        for (0..num_workers) |worker_id| {
+            const start_chunk = worker_id * chunks_per_worker;
+            if (start_chunk >= num_full_chunks) break;
+
+            group.async(io, ChunkBatch.process, .{ChunkBatch{
+                .input = b,
+                .start_chunk = start_chunk,
+                .end_chunk = @min(start_chunk + chunks_per_worker, num_full_chunks),
+                .cvs = cvs,
+                .key = key_words,
+                .flags = flags,
+            }});
+        }
+        group.wait(io);
+
+        // Build Merkle tree in parallel layers using ping-pong buffers
+        const max_intermediate_size = (num_full_chunks + 1) / 2;
+        const buffer0 = try allocator.alloc([8]u32, max_intermediate_size);
+        defer allocator.free(buffer0);
+        const buffer1 = try allocator.alloc([8]u32, max_intermediate_size);
+        defer allocator.free(buffer1);
+
+        var current_level = cvs;
+        var next_level_buf = buffer0;
+        var toggle = false;
+
+        while (current_level.len > 8) {
+            const num_parents = current_level.len / 2;
+            const has_odd = current_level.len % 2 == 1;
+            const next_level_size = num_parents + @intFromBool(has_odd);
+
+            buildMerkleTreeLayerParallel(
+                current_level[0 .. num_parents * 2],
+                next_level_buf[0..num_parents],
+                key_words,
+                flags,
+                io,
+            );
+
+            if (has_odd) {
+                next_level_buf[num_parents] = current_level[current_level.len - 1];
+            }
+
+            current_level = next_level_buf[0..next_level_size];
+            next_level_buf = if (toggle) buffer0 else buffer1;
+            toggle = !toggle;
+        }
+
+        // Finalize remaining small tree sequentially
+        var hasher = init_internal(key_words, flags);
+        for (current_level, 0..) |cv, i| hasher.pushCv(cv, i);
+
+        hasher.chunk.chunk_counter = num_full_chunks;
+        const remaining_bytes = b.len % chunk_length;
+        if (remaining_bytes > 0) {
+            hasher.chunk.update(b[num_full_chunks * chunk_length ..]);
+            hasher.mergeCvStack(hasher.chunk.chunk_counter);
+        }
+
+        hasher.final(out);
+    }
+
     fn init_internal(key: [8]u32, flags: Flags) Blake3 {
         return Blake3{
             .key = key,
@@ -1182,3 +1362,48 @@ test "BLAKE3 reference test cases" {
         try testBlake3(derive_key, t.input_len, t.derive_key.*);
     }
 }
+
+test "BLAKE3 parallel vs sequential" {
+    const allocator = std.testing.allocator;
+    const io = std.testing.io;
+
+    // Test various sizes including those above the parallelization threshold
+    const test_sizes = [_]usize{
+        0, // Empty
+        64, // One block
+        1024, // One chunk
+        1024 * 10, // Multiple chunks
+        1024 * 100, // 100KB
+        1024 * 1000, // 1MB
+        1024 * 5000, // 5MB (above threshold)
+        1024 * 10000, // 10MB (above threshold)
+    };
+
+    for (test_sizes) |size| {
+        // Allocate and fill test data with a pattern
+        const input = try allocator.alloc(u8, size);
+        defer allocator.free(input);
+        for (input, 0..) |*byte, i| {
+            byte.* = @truncate(i);
+        }
+
+        // Test regular hash
+        var expected: [32]u8 = undefined;
+        Blake3.hash(input, &expected, .{});
+
+        var actual: [32]u8 = undefined;
+        try Blake3.hashParallel(input, &actual, .{}, allocator, io);
+
+        try std.testing.expectEqualSlices(u8, &expected, &actual);
+
+        // Test keyed hash
+        const key: [32]u8 = @splat(0x42);
+        var expected_keyed: [32]u8 = undefined;
+        Blake3.hash(input, &expected_keyed, .{ .key = key });
+
+        var actual_keyed: [32]u8 = undefined;
+        try Blake3.hashParallel(input, &actual_keyed, .{ .key = key }, allocator, io);
+
+        try std.testing.expectEqualSlices(u8, &expected_keyed, &actual_keyed);
+    }
+}
