Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml : add ggml_soft_max_ext #4256

Merged
merged 14 commits into from
Dec 1, 2023
Merged
2 changes: 1 addition & 1 deletion examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
}

LOG_TEE("\n");
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
LOG_TEE("\n");

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
Expand Down
2 changes: 1 addition & 1 deletion ggml-alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {

#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
size_t cur_max = (char*)addr - (char*)alloc->data + size;
size_t cur_max = (char*)addr - (char*)alloc->base + size;
if (cur_max > alloc->max_size) {
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) {
Expand Down
130 changes: 87 additions & 43 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
Expand Down Expand Up @@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;

static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
}

static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
}

static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

Expand Down Expand Up @@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] * x[i];
}

static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}

template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
}
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
}

template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
}

// the CUDA soft max implementation differs from the CPU implementation
// instead of doubles floats are used
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int block_size = blockDim.y;
const int tid = threadIdx.y;
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension

const int block_size = blockDim.x;

const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;

__shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];

float max_val = -INFINITY;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
max_val = max(max_val, x[i]);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
}

// find the max value in the block
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = -INFINITY;
}
__syncthreads();

if (lane_id == 0) {
buf[warp_id] = max_val;
}
__syncthreads();

max_val = buf[lane_id];
max_val = warp_reduce_max(max_val);
}

float tmp = 0.f;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const float val = expf(x[i] - max_val);
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
tmp += val;
dst[i] = val;
dst[ix] = val;
}

// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf[lane_id] = 0.f;
}
__syncthreads();

if (lane_id == 0) {
buf[warp_id] = tmp;
}
__syncthreads();

tmp = buf[lane_id];
tmp = warp_reduce_sum(tmp);
}

const float inv_tmp = 1.f / tmp;

for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col;
const int i = rowx*ncols + col;
dst[i] *= inv_tmp;
}
}
Expand Down Expand Up @@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}

static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
const dim3 block_dims(1, WARP_SIZE, 1);
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}

static void im2col_f32_f16_cuda(const float * x, half * dst,
Expand Down Expand Up @@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;

soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
float scale = 1.0f;
memcpy(&scale, dst->op_params, sizeof(float));

soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_scale(
Expand Down
43 changes: 27 additions & 16 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
int nth = 32; // SIMD width

if (ne00%4 == 0) {
while (nth < ne00/4 && nth < 256) {
nth *= 2;
}
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
} else {
do {
while (nth < ne00 && nth < 1024) {
nth *= 2;
} while (nth <= ne00 && nth <= 1024);
nth /= 2;
}
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];

const float scale = ((float *) dst->op_params)[0];

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
Expand Down Expand Up @@ -1351,15 +1358,19 @@ void ggml_metal_graph_compute(
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

const int nth = MIN(512, ne00);
int nth = 32; // SIMD width

while (nth < ne00/4 && nth < 1024) {
nth *= 2;
}

[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

const int64_t nrows = ggml_nrows(src0);

Expand Down
Loading
Loading