
ggml : add ggml_soft_max_ext #4256

Merged 14 commits on Dec 1, 2023
2 changes: 1 addition & 1 deletion examples/batched-bench/batched-bench.cpp
@@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
}

LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
78 changes: 51 additions & 27 deletions ggml-cuda.cu
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_SOFT_MAX_BLOCK_SIZE 512
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -4717,45 +4718,62 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
}

// the CUDA soft max implementation differs from the CPU implementation
// instead of doubles floats are used
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int block_size = blockDim.y;
const int tid = threadIdx.y;
// TODO: maybe can be improved with some warp-based primitives
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

const int block_size = blockDim.x;

__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
Owner Author:
Replaced the old WARP_SIZE kernel with a new one that uses up to CUDA_SOFT_MAX_BLOCK_SIZE=512 threads per block (based on ne00). This seems to perform better on V100.

Any reason to prefer the old kernel?

Collaborator:
Block-wide reductions are a lot more expensive than warp reductions, so most of the kernels distribute operations that need a reduction only to a single warp. A block reduction is usually done by reducing within every warp, storing each warp's result to shared memory, and then reducing the shared memory.

Owner Author:
Updated to use a warp-based block reduction with a maximum of 32 warps. Results improved, most significantly for the largest contexts.
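
A minimal sketch of that two-stage reduction (illustrative, not the exact code merged here): each warp reduces its values with shuffle intrinsics, one lane per warp writes its partial result to shared memory, and a final warp-level pass reduces those partials. With at most 1024 threads per block there are at most 32 warps, so 32 floats of shared memory suffice.

```cuda
// Illustrative sketch of a two-stage (warp -> block) max reduction; not the code in this PR.
// Assumes blockDim.x is a multiple of 32 and at most 1024 threads, i.e. at most 32 warps.
#include <math.h>

static __device__ float block_reduce_max(float val) {
    __shared__ float s_max[32]; // one partial result per warp

    const int lane_id = threadIdx.x % 32;
    const int warp_id = threadIdx.x / 32;

    // 1) butterfly reduction within each warp - every lane ends up with the warp max
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
    }

    // 2) one lane per warp publishes the warp max to shared memory
    if (lane_id == 0) {
        s_max[warp_id] = val;
    }
    __syncthreads();

    // 3) every warp reduces the per-warp partials, so all threads end up with the block max
    const int n_warps = blockDim.x / 32;
    val = lane_id < n_warps ? s_max[lane_id] : -INFINITY;
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
    }

    return val;
}
```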

Owner Author:
I'll try to do a similar optimization for the Metal kernel.


float max_val = -INFINITY;
buf[tid] = -INFINITY;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
max_val = max(max_val, x[i]);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
}

__syncthreads();

// find the max value in the block
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
for (int i = block_size/2; i > 0; i >>= 1) {
if (tid < i) {
buf[tid] = max(buf[tid], buf[tid + i]);
}
__syncthreads();
}

float tmp = 0.f;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const float val = expf(x[i] - max_val);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
tmp += val;
dst[i] = val;
dst[ix] = val;
}

__syncthreads();

buf[tid] = tmp;

__syncthreads();

// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
for (int i = block_size/2; i > 0; i >>= 1) {
if (tid < i) {
buf[tid] += buf[tid + i];
}
__syncthreads();
}

const float inv_tmp = 1.f / tmp;
const float inv_tmp = 1.f / buf[0];

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
}
}
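
For reference, the per-row computation the new kernel performs — scale the inputs, add the broadcast mask, then take a numerically stable softmax — can be written as a plain CPU loop. A sketch like the following (illustrative only, not part of the PR) is convenient for validating the CUDA output:

```cpp
// CPU reference for the fused op computed by the kernel above:
//     dst = softmax(x*scale + mask), applied row by row,
// where the mask has nrows_y rows and is broadcast over the nrows_x rows of x
// (rowy = rowx % nrows_y, same rule as the kernel). Illustrative only.
#include <math.h>

static void soft_max_ext_ref(const float * x, const float * mask, float * dst,
                             int ncols, int nrows_x, int nrows_y, float scale) {
    for (int rowx = 0; rowx < nrows_x; ++rowx) {
        const int rowy = rowx % nrows_y;

        float max_val = -INFINITY;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale + (mask ? mask[rowy*ncols + col] : 0.0f);
            max_val = fmaxf(max_val, v);
        }

        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale + (mask ? mask[rowy*ncols + col] : 0.0f);
            const float e = expf(v - max_val);
            dst[rowx*ncols + col] = e;
            sum += e;
        }

        const float inv_sum = 1.0f/sum;
        for (int col = 0; col < ncols; ++col) {
            dst[rowx*ncols + col] *= inv_sum;
        }
    }
}
```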
@@ -5792,10 +5810,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}

static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
const dim3 block_dims(1, WARP_SIZE, 1);
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}

static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6846,14 +6866,18 @@ inline void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;

soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));

soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_scale(
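For context, the scale read from dst->op_params above is set on the host side by the new fused op. A hedged sketch of how a graph-building call site is meant to use it, assuming the ggml_soft_max_ext(ctx, a, mask, scale) signature this PR introduces (variable names are placeholders, not necessarily those used in llama.cpp):

```c
// Illustrative call site: ctx, kq, kq_mask and n_embd_head are hypothetical names.
#include <math.h>
#include "ggml.h"

static struct ggml_tensor * attn_scores(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,       // raw Q*K^T attention scores
        struct ggml_tensor  * kq_mask,  // causal/padding mask, broadcast across heads
        int                   n_embd_head) {
    // one fused node instead of separate scale + mask-add + soft_max nodes
    return ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf((float) n_embd_head));
}
```

The mask becomes src1 of the op and is broadcast across rows by the kernels above, while the scale travels in op_params.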
15 changes: 10 additions & 5 deletions ggml-metal.m
@@ -1036,11 +1036,16 @@ void ggml_metal_graph_compute(
nth /= 2;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];

const float scale = ((float *) dst->op_params)[0];

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
26 changes: 16 additions & 10 deletions ggml-metal.metal
@@ -180,10 +180,12 @@ kernel void kernel_gelu(

kernel void kernel_soft_max(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@@ -194,14 +196,15 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]);
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}

float max = simd_max(lmax);
@@ -225,7 +228,7 @@ kernel void kernel_soft_max(
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp(psrc0[i00] - max);
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// wish to compute it twice.
@@ -257,10 +260,12 @@ kernel void kernel_soft_max(

kernel void kernel_soft_max_4(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@@ -271,14 +276,15 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

// parallel max
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;

for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]);
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
}

const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -303,7 +309,7 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}