
Commit cb819ad
Initialize accumulator to bias for implicit GEMM to save an expensive `float_add` (#2383)
wingertge authored Oct 20, 2024
1 parent 454c6f0 commit cb819ad
Showing 4 changed files with 141 additions and 35 deletions.
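
For context: the implicit GEMM convolution previously wrote its output and then launched a separate elementwise `float_add` kernel to apply the bias. This commit folds the bias into the accumulator's initial value instead. A minimal plain-Rust sketch of the idea, using a hypothetical `gemm_bias_fused` helper (illustration only; the real kernel operates on CMMA fragments through CubeCL):

// Naive GEMM with the bias folded into the accumulator init.
// `a` is m x k, `b` is k x n, `bias` has length n, all row-major.
fn gemm_bias_fused(a: &[f32], b: &[f32], bias: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    for row in 0..m {
        for col in 0..n {
            // Start from bias[col] instead of 0.0, then accumulate the dot
            // product; no second pass over `out` is needed afterwards.
            let mut acc = bias[col];
            for i in 0..k {
                acc += a[row * k + i] * b[i * n + col];
            }
            out[row * n + col] = acc;
        }
    }
    out
}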
41 changes: 33 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -152,8 +152,8 @@ tch = "0.15.0"
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7854b343975c990dd8bb1b4b68b3bc9bda488c1d" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7854b343975c990dd8bb1b4b68b3bc9bda488c1d" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "609183cccb26e2ccaa834f2230b5b8f38f9ea507" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "609183cccb26e2ccaa834f2230b5b8f38f9ea507" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
2 changes: 1 addition & 1 deletion backend-comparison/Cargo.toml
@@ -34,7 +34,7 @@ burn-common = { path = "../crates/burn-common", version = "0.15.0" }
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0", optional = true }
clap = { workspace = true }
colored = { workspace = true }
-cubecl = { workspace = true, features = ["wgpu"] }
+cubecl = { workspace = true, features = ["wgpu"], default-features = true }
derive-new = { workspace = true }
dirs = { workspace = true }
github-device-flow = { workspace = true }
129 changes: 105 additions & 24 deletions crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
@@ -1,5 +1,5 @@
use burn_tensor::{
-    ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps},
+    ops::{conv::calculate_conv_output_size, ConvOptions},
Shape,
};
use cmma::{Matrix, MatrixIdent, MatrixLayout};
@@ -8,9 +8,12 @@ use half::f16;

use crate::{
kernel::{into_contiguous, slice},
-    ops::{numeric::empty_device, permute, reshape},
+    ops::{
+        numeric::{empty_device, zeros_device},
+        permute,
+    },
tensor::JitTensor,
-    FloatElement, IntElement, JitBackend, JitRuntime,
+    FloatElement, IntElement, JitRuntime,
};

/// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available.
@@ -20,6 +23,7 @@ use crate::{
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
+#[allow(clippy::extra_unused_type_parameters)]
pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
input: JitTensor<R, F>,
weight: JitTensor<R, F>,
@@ -78,6 +82,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(

let input_tile_size = cmma_m * cmma_k;
let weight_tile_size = cmma_k * cmma_n;
+    let acc_tile_size = cmma_m * cmma_n;

let warp_size = 32;
let warps_per_cube = (cube_dim_y * cube_dim_x) / warp_size;
@@ -90,6 +95,13 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
let weight_elems_per_thread = weight_tile_size / warp_size;
let weight_vectorization =
find_common_vec(in_channels, weight_elems_per_thread, supported_vecs);
+    let bias_elems_per_thread = acc_tile_size / warp_size;
+    let bias_vectorization = find_common_vec(out_channels, bias_elems_per_thread, supported_vecs);
+
+    let has_bias = bias.is_some();
+    let bias = bias.unwrap_or_else(|| {
+        zeros_device(input.client.clone(), input.device.clone(), Shape::new([1]))
+    });
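
A note on the `zeros_device` fallback above: the launch signature always binds a bias buffer, so when no bias exists a one-element zero tensor stands in for it, and the comptime `has_bias` flag compiles the real bias path in or out. A minimal sketch of the pattern (hypothetical helper, not the actual JIT API):

// Returns the buffer to bind plus the flag gating the bias code path.
// The placeholder's contents are never read when `has_bias` is false;
// only the binding itself has to exist.
fn bias_or_placeholder(bias: Option<Vec<f32>>) -> (Vec<f32>, bool) {
    match bias {
        Some(b) => (b, true),
        None => (vec![0.0], false),
    }
}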

let settings = GemmSettings {
cmma_m,
@@ -126,6 +138,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
cube_dim,
input.as_tensor_arg(input_vectorization),
weight.as_tensor_arg(weight_vectorization),
+        bias.as_tensor_arg(bias_vectorization),
out.as_tensor_arg(1),
DimensionsLaunch::new(
ScalarArg::new(gemm_m),
@@ -152,15 +165,11 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
padding_h: options.padding[0] as i32,
padding_w: options.padding[1] as i32,
aligned,
+            has_bias,
},
);

-    let mut out = slice(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]);
-
-    if let Some(bias) = bias {
-        let bias = reshape(bias, Shape::new([1, 1, 1, out_channels]));
-        out = JitBackend::<R, F, I>::float_add(out, bias);
-    }
+    let out = slice(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]);

// Reset to NCHW
permute(out, &[0, 3, 1, 2])
@@ -224,6 +233,7 @@ struct ConvSettings {
padding_h: i32,
padding_w: i32,
aligned: bool,
+    has_bias: bool,
}

#[derive(Clone, Copy, CubeType)]
@@ -247,12 +257,15 @@ struct Matrices<F: Float, FAcc: Float> {
fn implicit_gemm_kernel<F: Float, FMat: Float>(
input: &Tensor<Line<F>>,
weight: &Tensor<Line<F>>,
+    bias: &Tensor<Line<F>>,
out: &mut Tensor<F>,
dims: &Dimensions,
args: &ConvArgs,
#[comptime] gemm_settings: GemmSettings,
-    #[comptime] kernel_settings: ConvSettings,
+    #[comptime] conv_settings: ConvSettings,
) {
+    let _ = bias[0];
+
let GemmSettings {
cmma_m,
cmma_n,
@@ -283,18 +296,19 @@ fn implicit_gemm_kernel<F: Float, FMat: Float>(
let out_pos = pos.global_n + pos.global_m * dims.gemm_n;
let out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size);

-    if kernel_settings.aligned || pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n {
+    if conv_settings.aligned || pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n {
execute_gemm(
input,
weight,
+            bias,
out,
input_tile,
weight_tile,
dims,
&pos,
args,
gemm_settings,
-            kernel_settings,
+            conv_settings,
);
}
}
@@ -330,6 +344,7 @@ fn calculate_positions(#[comptime] gemm_settings: GemmSettings) -> Positions {
#[cube]
fn make_matrices<F: Float, FAcc: Float>(
#[comptime] gemm_settings: GemmSettings,
+    #[comptime] has_bias: bool,
) -> Matrices<F, FAcc> {
let GemmSettings {
cmma_m,
@@ -338,6 +353,27 @@ fn make_matrices<F: Float, FAcc: Float>(
..
} = gemm_settings;

+    let acc = if has_bias {
+        unsafe {
+            Matrix::<FAcc>::uninitialized(
+                MatrixIdent::Accumulator,
+                cmma_m,
+                cmma_n,
+                cmma_k,
+                MatrixLayout::Undefined,
+            )
+        }
+    } else {
+        Matrix::<FAcc>::from_value(
+            MatrixIdent::Accumulator,
+            cmma_m,
+            cmma_n,
+            cmma_k,
+            MatrixLayout::Undefined,
+            FAcc::new(0.0),
+        )
+    };
+
Matrices::<F, FAcc> {
a: unsafe {
Matrix::<F>::uninitialized(
Expand All @@ -357,21 +393,15 @@ fn make_matrices<F: Float, FAcc: Float>(
MatrixLayout::ColMajor,
)
},
-        acc: Matrix::<FAcc>::from_value(
-            MatrixIdent::Accumulator,
-            cmma_m,
-            cmma_n,
-            cmma_k,
-            MatrixLayout::Undefined,
-            FAcc::new(0.0),
-        ),
+        acc,
}
}

#[cube]
fn execute_gemm<F: Float, FMat: Float>(
input: &Tensor<Line<F>>,
weight: &Tensor<Line<F>>,
+    bias: &Tensor<Line<F>>,
out: &mut SliceMut<F>,
input_tile: &mut SliceMut<FMat>,
weight_tile: &mut SliceMut<FMat>,
@@ -381,9 +411,26 @@ fn execute_gemm<F: Float, FMat: Float>(
#[comptime] g_settings: GemmSettings,
#[comptime] k_settings: ConvSettings,
) {
-    let GemmSettings { cmma_n, cmma_k, .. } = g_settings;
-
-    let matrices = make_matrices::<FMat, F>(g_settings);
+    let GemmSettings {
+        cmma_m,
+        cmma_n,
+        cmma_k,
+        warps_per_cube,
+        ..
+    } = g_settings;
+    let has_bias = k_settings.has_bias;
+
+    let matrices = make_matrices::<FMat, F>(g_settings, has_bias);
+    if has_bias {
+        let mut smem_bias = SharedMemory::new(cmma_m * cmma_n * warps_per_cube);
+        load_bias_tile(bias, &mut smem_bias, pos, g_settings);
+        cmma::load_with_layout(
+            &matrices.acc,
+            smem_bias.as_slice(),
+            cmma_n,
+            MatrixLayout::RowMajor,
+        );
+    }

// Loop over the K-dimension
for k in range_stepped(0, dims.gemm_k, cmma_k) {
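
For clarity on the accumulator load above: when a bias is present, each warp stages its `cmma_m * cmma_n` bias tile in shared memory, and `cmma::load_with_layout` fills the accumulator fragment from it as a row-major matrix with leading dimension `cmma_n`. A plain-Rust model of that fill, assuming fragment element `(r, c)` comes from `slice[r * cmma_n + c]`:

// Models a row-major fragment load: row r of the m-by-n accumulator is
// taken from the slice range [r * n, (r + 1) * n).
fn load_acc_row_major(slice: &[f32], cmma_m: usize, cmma_n: usize) -> Vec<Vec<f32>> {
    (0..cmma_m)
        .map(|r| slice[r * cmma_n..(r + 1) * cmma_n].to_vec())
        .collect()
}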
@@ -472,7 +519,7 @@ fn load_input_tile<F: Float, FMat: Float>(
// Slices are always `kernel_size * channels` elements wide so we can compute where inside a slice
// we are and also which row the slice is in relative to the start of the CMMA matrix

-    // Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is repsonsible for
+    // Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is responsible for
let my_slice_idx = (slice_start_idx + (m % cmma_k)) % dims.slice_size;

let channel = my_slice_idx % channels;
@@ -564,6 +611,40 @@ fn load_weight_tile<F: Float, FMat: Float>(
}
}

+#[cube]
+fn load_bias_tile<F: Float>(
+    bias: &Tensor<Line<F>>,
+    tile: &mut SharedMemory<F>,
+    pos: &Positions,
+    #[comptime] gemm_settings: GemmSettings,
+) {
+    let GemmSettings {
+        cmma_n,
+        cmma_m,
+        warp_size,
+        ..
+    } = gemm_settings;
+
+    let vec = vectorization_of(bias);
+    let cmma_acc_tile_size = cmma_m * cmma_n;
+    let elems_per_thread = cmma_acc_tile_size / warp_size;
+    let start = pos.intra_warp_unit_idx * elems_per_thread;
+    let bias_tile_start = pos.cube_linear_warp_idx * cmma_acc_tile_size;
+
+    #[unroll]
+    for n in range_stepped(0, elems_per_thread, vec) {
+        let n = n + start;
+
+        let row = n % cmma_n + pos.global_n;
+        let value = bias[row / vec];
+
+        #[unroll]
+        for i in 0..vec {
+            tile[bias_tile_start + n + i] = value[i];
+        }
+    }
+}
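
To make the indexing in `load_bias_tile` concrete: the accumulator tile holds `cmma_m * cmma_n` elements per warp, each thread stages `cmma_m * cmma_n / warp_size` of them, and the value depends only on the column (`n % cmma_n + pos.global_n`), so every row of the tile repeats the same bias entries; that is the broadcast. A small sketch of the index math with assumed 16x16 tiles and the kernel's warp size of 32, ignoring vectorization:

// For each element a thread stages, compute (shared-memory offset, bias
// column), mirroring the loop in `load_bias_tile` with vec = 1.
fn bias_tile_plan(unit_idx: u32, warp_idx: u32, global_n: u32) -> Vec<(u32, u32)> {
    let (cmma_m, cmma_n, warp_size) = (16u32, 16u32, 32u32);
    let tile = cmma_m * cmma_n; // 256 accumulator elements per warp
    let per_thread = tile / warp_size; // 8 elements per thread
    let start = unit_idx * per_thread;
    let tile_start = warp_idx * tile;
    (start..start + per_thread)
        .map(|n| (tile_start + n, n % cmma_n + global_n))
        .collect()
}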

pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
input: &JitTensor<R, E>,
weight: &JitTensor<R, E>,
