Skip to content

Commit

Permalink
finetune : speed-up ggml_compute_forward_out_prod_f32 via BLAS (#4079)
Browse files Browse the repository at this point in the history
* Remove logically superfluous assertions and order by dimension

* Use cblas_sgemm() to implement ggml_compute_forward_out_prod()

* Remove ggml_compute_forward_out_prod_use_blas(), fix compiling errors on cmake/zig, remove trailing whitespace

* Add openBLAS support for sgemm() in compute_forward_out_prod()
  • Loading branch information
gwjr authored Nov 17, 2023
1 parent 947f64f commit 3e916a0
Showing 1 changed file with 61 additions and 8 deletions.
69 changes: 61 additions & 8 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -9611,10 +9611,12 @@ static void ggml_compute_forward_out_prod_f32(
const int ith = params->ith;
const int nth = params->nth;

GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne03 == ne13);

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == sizeof(float));
Expand All @@ -9625,18 +9627,25 @@ static void ggml_compute_forward_out_prod_f32(
// GGML_ASSERT(nb1 <= nb2);
// GGML_ASSERT(nb2 <= nb3);

GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);

// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows

// TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
// TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
// TODO: #if defined(GGML_USE_CLBLAST)

#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
bool use_blas = ggml_is_matrix(src0) &&
ggml_is_matrix(src1) &&
ggml_is_contiguous(src0) &&
(ggml_is_contiguous(src1) || ggml_is_transposed(src1));
#endif

if (params->type == GGML_TASK_INIT) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst
if (use_blas) {
return;
}
#endif
ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
return;
}
Expand All @@ -9645,6 +9654,50 @@ static void ggml_compute_forward_out_prod_f32(
return;
}

#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (use_blas) {
if (params->ith != 0) { // All threads other than the first do no work.
return;
}
// Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
// src0: (k,n)
// src1: (k,m)
// dst: (m,n)
//
// Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
// Also expressed as (major,minor)
// a: (m,k): so src1 transposed
// b: (k,n): so src0
// c: (m,n)
//
// However, if ggml_is_transposed(src1) is true, then
// src1->data already contains a transposed version, so sgemm mustn't
// transpose it further.

int n = src0->ne[0];
int k = src0->ne[1];
int m = src1->ne[0];

int transposeA, lda;

if (!ggml_is_transposed(src1)) {
transposeA = CblasTrans;
lda = m;
} else {
transposeA = CblasNoTrans;
lda = k;
}

float * a = (float *) ((char *) src1->data);
float * b = (float *) ((char *) src0->data);
float * c = (float *) ((char *) dst->data);

cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);

return;
}
#endif

// dst[:,:,:,:] = 0
// for i2,i3:
// for i1:
Expand Down

0 comments on commit 3e916a0

Please sign in to comment.