diff --git a/Cargo.lock b/Cargo.lock index 553a035636..9175c37266 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1420,10 +1420,11 @@ dependencies = [ [[package]] name = "cubecl" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d#7854b343975c990dd8bb1b4b68b3bc9bda488c1d" +source = "git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" dependencies = [ "cubecl-core", "cubecl-cuda", + "cubecl-hip", "cubecl-linalg", "cubecl-runtime", "cubecl-wgpu", @@ -1432,7 +1433,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d#7854b343975c990dd8bb1b4b68b3bc9bda488c1d" +source = "git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" dependencies = [ "derive-new", "embassy-futures", @@ -1448,7 +1449,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d#7854b343975c990dd8bb1b4b68b3bc9bda488c1d" +source = "git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" dependencies = [ "bytemuck", "cubecl-common", @@ -1465,7 +1466,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d#7854b343975c990dd8bb1b4b68b3bc9bda488c1d" +source = "git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" dependencies = [ "bytemuck", "cubecl-common", @@ -1477,10 +1478,34 @@ dependencies = [ "log", ] +[[package]] +name = "cubecl-hip" +version = "0.2.0" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" +dependencies = [ + "bytemuck", + "cubecl-common", + "cubecl-core", + "cubecl-hip-sys", + "cubecl-runtime", + "derive-new", + "half", + "log", +] + +[[package]] +name = "cubecl-hip-sys" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daa04c96e99d6983de13345fc8175fe1479b1076019686f6344383b3c0db8bb1" +dependencies = [ + "libc", +] + [[package]] name = "cubecl-linalg" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d#7854b343975c990dd8bb1b4b68b3bc9bda488c1d" +source = "git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" dependencies = [ "bytemuck", "cubecl-core", @@ -1491,7 +1516,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d#7854b343975c990dd8bb1b4b68b3bc9bda488c1d" +source = "git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" dependencies = [ "cubecl-common", "darling", @@ -1506,7 +1531,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d#7854b343975c990dd8bb1b4b68b3bc9bda488c1d" +source = "git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" dependencies = [ "async-channel", "cfg_aliases 0.2.1", @@ -1525,7 +1550,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.2.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d#7854b343975c990dd8bb1b4b68b3bc9bda488c1d" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=609183cccb26e2ccaa834f2230b5b8f38f9ea507#609183cccb26e2ccaa834f2230b5b8f38f9ea507" dependencies = [ "async-channel", "bytemuck", diff --git a/Cargo.toml b/Cargo.toml index b42bbdbfdc..2066a32420 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -152,8 +152,8 @@ tch = "0.15.0" portable-atomic-util = { version = "0.2.2", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7854b343975c990dd8bb1b4b68b3bc9bda488c1d" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7854b343975c990dd8bb1b4b68b3bc9bda488c1d" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "609183cccb26e2ccaa834f2230b5b8f38f9ea507" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "609183cccb26e2ccaa834f2230b5b8f38f9ea507" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index bb1c2eca26..3f930a1174 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -34,7 +34,7 @@ burn-common = { path = "../crates/burn-common", version = "0.15.0" } burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0", optional = true } clap = { workspace = true } colored = { workspace = true } -cubecl = { workspace = true, features = ["wgpu"] } +cubecl = { workspace = true, features = ["wgpu"], default-features = true } derive-new = { workspace = true } dirs = { workspace = true } github-device-flow = { workspace = true } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index d02ca11ce4..617833ec11 100644 --- 
a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -1,5 +1,5 @@ use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps}, + ops::{conv::calculate_conv_output_size, ConvOptions}, Shape, }; use cmma::{Matrix, MatrixIdent, MatrixLayout}; @@ -8,9 +8,12 @@ use half::f16; use crate::{ kernel::{into_contiguous, slice}, - ops::{numeric::empty_device, permute, reshape}, + ops::{ + numeric::{empty_device, zeros_device}, + permute, + }, tensor::JitTensor, - FloatElement, IntElement, JitBackend, JitRuntime, + FloatElement, IntElement, JitRuntime, }; /// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available. @@ -20,6 +23,7 @@ use crate::{ /// * `bias` - The bias added to each channel /// * `options` - The options to use for the convolution /// +#[allow(clippy::extra_unused_type_parameters)] pub fn conv2d_implicit_gemm( input: JitTensor, weight: JitTensor, @@ -78,6 +82,7 @@ pub fn conv2d_implicit_gemm( let input_tile_size = cmma_m * cmma_k; let weight_tile_size = cmma_k * cmma_n; + let acc_tile_size = cmma_m * cmma_n; let warp_size = 32; let warps_per_cube = (cube_dim_y * cube_dim_x) / warp_size; @@ -90,6 +95,13 @@ pub fn conv2d_implicit_gemm( let weight_elems_per_thread = weight_tile_size / warp_size; let weight_vectorization = find_common_vec(in_channels, weight_elems_per_thread, supported_vecs); + let bias_elems_per_thread = acc_tile_size / warp_size; + let bias_vectorization = find_common_vec(out_channels, bias_elems_per_thread, supported_vecs); + + let has_bias = bias.is_some(); + let bias = bias.unwrap_or_else(|| { + zeros_device(input.client.clone(), input.device.clone(), Shape::new([1])) + }); let settings = GemmSettings { cmma_m, @@ -126,6 +138,7 @@ pub fn conv2d_implicit_gemm( cube_dim, input.as_tensor_arg(input_vectorization), weight.as_tensor_arg(weight_vectorization), + bias.as_tensor_arg(bias_vectorization), 
out.as_tensor_arg(1), DimensionsLaunch::new( ScalarArg::new(gemm_m), @@ -152,15 +165,11 @@ pub fn conv2d_implicit_gemm( padding_h: options.padding[0] as i32, padding_w: options.padding[1] as i32, aligned, + has_bias, }, ); - let mut out = slice(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]); - - if let Some(bias) = bias { - let bias = reshape(bias, Shape::new([1, 1, 1, out_channels])); - out = JitBackend::::float_add(out, bias); - } + let out = slice(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]); // Reset to NCHW permute(out, &[0, 3, 1, 2]) @@ -224,6 +233,7 @@ struct ConvSettings { padding_h: i32, padding_w: i32, aligned: bool, + has_bias: bool, } #[derive(Clone, Copy, CubeType)] @@ -247,12 +257,15 @@ struct Matrices { fn implicit_gemm_kernel( input: &Tensor>, weight: &Tensor>, + bias: &Tensor>, out: &mut Tensor, dims: &Dimensions, args: &ConvArgs, #[comptime] gemm_settings: GemmSettings, - #[comptime] kernel_settings: ConvSettings, + #[comptime] conv_settings: ConvSettings, ) { + let _ = bias[0]; + let GemmSettings { cmma_m, cmma_n, @@ -283,10 +296,11 @@ fn implicit_gemm_kernel( let out_pos = pos.global_n + pos.global_m * dims.gemm_n; let out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size); - if kernel_settings.aligned || pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n { + if conv_settings.aligned || pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n { execute_gemm( input, weight, + bias, out, input_tile, weight_tile, @@ -294,7 +308,7 @@ fn implicit_gemm_kernel( &pos, args, gemm_settings, - kernel_settings, + conv_settings, ); } } @@ -330,6 +344,7 @@ fn calculate_positions(#[comptime] gemm_settings: GemmSettings) -> Positions { #[cube] fn make_matrices( #[comptime] gemm_settings: GemmSettings, + #[comptime] has_bias: bool, ) -> Matrices { let GemmSettings { cmma_m, @@ -338,6 +353,27 @@ fn make_matrices( .. 
} = gemm_settings; + let acc = if has_bias { + unsafe { + Matrix::::uninitialized( + MatrixIdent::Accumulator, + cmma_m, + cmma_n, + cmma_k, + MatrixLayout::Undefined, + ) + } + } else { + Matrix::::from_value( + MatrixIdent::Accumulator, + cmma_m, + cmma_n, + cmma_k, + MatrixLayout::Undefined, + FAcc::new(0.0), + ) + }; + Matrices:: { a: unsafe { Matrix::::uninitialized( @@ -357,14 +393,7 @@ fn make_matrices( MatrixLayout::ColMajor, ) }, - acc: Matrix::::from_value( - MatrixIdent::Accumulator, - cmma_m, - cmma_n, - cmma_k, - MatrixLayout::Undefined, - FAcc::new(0.0), - ), + acc, } } @@ -372,6 +401,7 @@ fn make_matrices( fn execute_gemm( input: &Tensor>, weight: &Tensor>, + bias: &Tensor>, out: &mut SliceMut, input_tile: &mut SliceMut, weight_tile: &mut SliceMut, @@ -381,9 +411,26 @@ fn execute_gemm( #[comptime] g_settings: GemmSettings, #[comptime] k_settings: ConvSettings, ) { - let GemmSettings { cmma_n, cmma_k, .. } = g_settings; - - let matrices = make_matrices::(g_settings); + let GemmSettings { + cmma_m, + cmma_n, + cmma_k, + warps_per_cube, + .. 
+ } = g_settings; + let has_bias = k_settings.has_bias; + + let matrices = make_matrices::(g_settings, has_bias); + if has_bias { + let mut smem_bias = SharedMemory::new(cmma_m * cmma_n * warps_per_cube); + load_bias_tile(bias, &mut smem_bias, pos, g_settings); + cmma::load_with_layout( + &matrices.acc, + smem_bias.as_slice(), + cmma_n, + MatrixLayout::RowMajor, + ); + } // Loop over the K-dimension for k in range_stepped(0, dims.gemm_k, cmma_k) { @@ -472,7 +519,7 @@ fn load_input_tile( // Slices are always `kernel_size * channels` elements wide so we can compute where inside a slice // we are and also which row the slice is in relative to the start of the CMMA matrix - // Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is responsible for + // Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is responsible for let my_slice_idx = (slice_start_idx + (m % cmma_k)) % dims.slice_size; let channel = my_slice_idx % channels; @@ -564,6 +611,40 @@ fn load_weight_tile( } } +#[cube] +fn load_bias_tile( + bias: &Tensor>, + tile: &mut SharedMemory, + pos: &Positions, + #[comptime] gemm_settings: GemmSettings, +) { + let GemmSettings { + cmma_n, + cmma_m, + warp_size, + .. + } = gemm_settings; + + let vec = vectorization_of(bias); + let cmma_acc_tile_size = cmma_m * cmma_n; + let elems_per_thread = cmma_acc_tile_size / warp_size; + let start = pos.intra_warp_unit_idx * elems_per_thread; + let bias_tile_start = pos.cube_linear_warp_idx * cmma_acc_tile_size; + + #[unroll] + for n in range_stepped(0, elems_per_thread, vec) { + let n = n + start; + + let row = n % cmma_n + pos.global_n; + let value = bias[row / vec]; + + #[unroll] + for i in 0..vec { + tile[bias_tile_start + n + i] = value[i]; + } + } +} + pub(crate) fn can_do_implicit_gemm( input: &JitTensor, weight: &JitTensor,