 #include "singleton.h"
 #include "timeline.h"
 
-extern bool enableCATMLP;
-extern bool enableCBLASMLP;
-void setMLPOPTConfig();
 // C++ implementation for the python code in modeling_llama.py:
 // residual = hidden_states
 // hidden_states = self.post_attention_layernorm(hidden_states)
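
For context, the comment block above refers to the SwiGLU MLP of the Llama decoder layer: the normalized hidden states go through gate and up projections, silu(gate) is multiplied element-wise with up, the result goes through the down projection, and the residual is added back. The sketch below is a minimal, unoptimized reference for that math only; the names (swigluMlp, wGate, wUp, wDown) are illustrative and not part of this file, which instead runs the projections through ctx->mmHelper with packed, possibly quantized weights.

// Minimal sketch (not part of this change) of the math LlamaMLP implements:
// out = residual + down( silu(x * Wgate) * (x * Wup) ), all row-major float.
#include <cmath>
#include <vector>

static void matmul(const float *a, const float *b, float *c, int m, int k, int n) {
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j) {
            float acc = 0.f;
            for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
            c[i * n + j] = acc;
        }
}

// x, residual: [m, hidden]; wGate, wUp: [hidden, im]; wDown: [im, hidden]
static std::vector<float> swigluMlp(const std::vector<float> &x, const std::vector<float> &wGate,
        const std::vector<float> &wUp, const std::vector<float> &wDown,
        const std::vector<float> &residual, int m, int hidden, int im) {
    std::vector<float> gate(m * im), up(m * im), out(m * hidden);
    matmul(x.data(), wGate.data(), gate.data(), m, hidden, im);
    matmul(x.data(), wUp.data(), up.data(), m, hidden, im);
    for (int i = 0; i < m * im; ++i) {
        float g = gate[i];
        gate[i] = g / (1.f + std::exp(-g)) * up[i]; // silu(gate) * up, reusing the gate buffer
    }
    matmul(gate.data(), wDown.data(), out.data(), m, im, hidden);
    for (int i = 0; i < m * hidden; ++i) out[i] += residual[i]; // residual add
    return out;
}
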
@@ -65,8 +62,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         ctx->mmHelper->convertWeight(ctx, trans, hiddenSize, imSize, upW, upS, upZ, true, quantizedUpWeight,
                 upWeightScale, upWeightZero, upWeightSum);
 
-        setMLPOPTConfig();
-        if (!enableCATMLP) {
+        if (!enableCATMLP()) {
             gateWeight.Resize(hiddenSize, it.second - it.first);
             upWeight.Resize(hiddenSize, it.second - it.first);
             ctx->mmHelper->packWeight(trans, quantizedGateWeight, gateWeight);
@@ -82,14 +78,9 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
             ctx->mmHelper->packWeight(trans, quantizedCatWeights, catWeights);
         }
         // Horizontally split the down weight
-        if (enableCBLASMLP && std::is_same_v<WeiT, bfloat16_t>) {
-            ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false, downWeight,
-                    downWeightScale, downWeightZero, downWeightSum);
-        } else {
-            ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false,
-                    quantizedDownWeight, downWeightScale, downWeightZero, downWeightSum);
-            ctx->mmHelper->packWeight(trans, quantizedDownWeight, downWeight);
-        }
+        ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false,
+                quantizedDownWeight, downWeightScale, downWeightZero, downWeightSum);
+        ctx->mmHelper->packWeight(trans, quantizedDownWeight, downWeight);
 
 #ifdef DEBUG
         dbg.debugPrint("quantizedGateWeight:\n");
@@ -137,7 +128,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         dbg.dumpMatrix(normBuffer);
 #endif
 
-        if (!enableCATMLP) {
+        if (!enableCATMLP()) {
             hpj::Matrix<ImT> imBuffer(
                     (ImT *)ctx->imOut.Data(), ctx->imOut.Rows(), ctx->imOut.Cols(), ctx->imOut.Stride());
             gateProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer);
@@ -165,31 +156,19 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
             hpj::Matrix<ImT> imBuffer((ImT *)ctx->imOut.Data(), M, N, N);
 
             // Need to allocate extra buffer as oneDNN does not support the case of stride > cols
-            if constexpr (std::is_same_v<ImT, bfloat16_t>) {
-                const int cols = N / 2;
-                auto bufSize = M * cols * sizeof(ImT);
-                ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize);
-                hpj::Matrix<ImT> siluBuf(t, M, cols, cols);
-
-                catGateUpProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer, siluBuf);
-#ifdef DEBUG
-                dbg.debugPrint("gateUp output:\n");
-                dbg.dumpMatrix(siluBuf);
-#endif
-                downProj(ctx, siluBuf, outBuffer, inBuffer, ctx->splitIdx == 0);
-            }
+            const int cols = N / 2;
+            auto bufSize = M * cols * sizeof(ImT);
+            ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize);
+            hpj::Matrix<ImT> siluBuf(t, M, cols, cols);
 
-            // Use imBuffer as silu buffer
-            else {
-                catGateUpProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer, imBuffer);
+            catGateUpProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer, siluBuf);
 #ifdef DEBUG
-                dbg.debugPrint("catWeights:\n");
-                dbg.dumpMatrix(catWeights);
-                dbg.debugPrint("gateUp output:\n");
-                dbg.dumpMatrix(imBuffer);
+            dbg.debugPrint("catWeights:\n");
+            dbg.dumpMatrix(catWeights);
+            dbg.debugPrint("gateUp output:\n");
+            dbg.dumpMatrix(siluBuf);
 #endif
-                downProj(ctx, imBuffer, outBuffer, inBuffer, ctx->splitIdx == 0);
-            }
+            downProj(ctx, siluBuf, outBuffer, inBuffer, ctx->splitIdx == 0);
         }
 
 #ifdef DEBUG
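
Note on the change above: the separate siluBuf (N / 2 columns, with stride == cols so oneDNN accepts it) is now used for every ImT, not only bfloat16_t, and the imBuffer-as-silu-buffer fallback is gone. Presumably imBuffer holds the concatenated [gate | up] GEMM output and siluBuf receives silu(gate) * up before downProj; the sketch below illustrates that fuse step under that assumption (the buffer layout and the helper name fuseSiluMul are inferred, not stated by this diff).

// Hedged sketch of the assumed fuse step between the concatenated GEMM and downProj:
// imBuffer is [M x N] with gate outputs in columns [0, N/2) and up outputs in [N/2, N);
// siluBuf is [M x N/2] and receives silu(gate) * up.
#include <cmath>

static void fuseSiluMul(const float *imBuffer, float *siluBuf, int M, int N) {
    const int half = N / 2;
    for (int i = 0; i < M; ++i) {
        const float *gate = imBuffer + i * N;      // first half of the row
        const float *up = imBuffer + i * N + half; // second half of the row
        float *out = siluBuf + i * half;           // stride == cols, as the comment above requires
        for (int j = 0; j < half; ++j) {
            float g = gate[j];
            out[j] = g / (1.f + std::exp(-g)) * up[j];
        }
    }
}
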
@@ -248,7 +227,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         TimeLine t("DownProj");
 
         assert(input.Rows() == output.Rows());
-        if (!enableCATMLP)
+        if (!enableCATMLP())
             assert(input.Cols() == downWeight.Rows());
         else
             assert(input.Cols() == 2 * downWeight.Rows());
@@ -266,62 +245,10 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
         const InT *R = residential.Data();
 
         if (isMaster) {
-            // TODO: enable below code (currently disabled as hard to get tmpBuf from pre-alloced memory)
-            // if (enableCBLASMLP && std::is_same_v<WeiT, bfloat16_t>) {
-            //     computeProjBF16(A, B, C, M, N, K, lda, ldc, ldc, R, ldr, tmpBuf, ldt);
-            // }
-            {
-                ctx->mmHelper->compute_residential(
-                        false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, NULL, R, ldr);
-            }
+            ctx->mmHelper->compute_residential(
+                    false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc, NULL, R, ldr);
         } else {
-            // if (enableCBLASMLP && std::is_same_v<WeiT, bfloat16_t>) {
-            //     computeProjBF16(A, B, C, M, N, K, lda, ldc, ldc, nullptr, 0, tmpBuf, ldt);
-            // }
-            {
-                ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc);
-            }
-        }
-    }
-
-    // C = (R == nullptr ? A * B : A * B + R)
-    // T: temporary buffer if C is not in float
-    void computeProjBF16(const ImT *A, const WeiT *B, OutT *C, int M, int N, int K, int lda, int ldb, int ldc,
-            const InT *R, int ldr, float *T, int ldt) {
-        int alpha = 1.0;
-        int beta = 0.0;
-
-        // MKL needs float as output, use T (temporary buffer) as output if C is not in float
-        float *D = std::is_same_v<OutT, float> ? (float *)C : T;
-        int ldd = std::is_same_v<OutT, float> ? ldc : ldt;
-
-        REQUIRES(D != nullptr, "Incorrect parameter in computeProjBF16.");
-
-        if (R != nullptr) {
-#pragma omp parallel for
-            for (uint64_t i = 0; i < M; ++i) {
-                xft::copy(D + i * ldd, R + i * ldr, N);
-            }
-            beta = 1.0;
-        }
-
-        int ldaH = lda * sizeof(ImT) / sizeof(bfloat16_t); // stride in bf16
-        if constexpr (std::is_same_v<ImT, float>) {
-#pragma omp parallel for
-            for (uint64_t i = 0; i < M; ++i) {
-                bfloat16_t::cvt_float_to_bfloat16(A + i * lda, (bfloat16_t *)A + i * ldaH, K);
-            }
-        }
-
-        cblas_gemm_bf16bf16f32(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha, (const MKL_BF16 *)(A), ldaH,
-                (const MKL_BF16 *)(B), ldb, beta, D, ldd);
-
-        // Convert result from float to OutT
-        if constexpr (!std::is_same_v<OutT, float>) {
-#pragma omp parallel for
-            for (uint64_t i = 0; i < M; ++i) {
-                xft::copy(C + i * ldc, D + i * ldd, N);
-            }
+            ctx->mmHelper->compute(false, M, N, K, 1.0f, A, lda, B, scaleB, zeroB, sumB, 0.0f, C, ldc);
         }
     }
 
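
The definition of enableCATMLP() is not included in this diff; replacing the extern enableCATMLP flag plus the explicit setMLPOPTConfig() call with a function keeps the toggle self-initializing at its call sites. One plausible shape, assuming an environment-variable switch cached in a static local (the variable name ENABLE_CAT_MLP and the default value are guesses, not taken from this PR):

// Hypothetical sketch of a runtime toggle; the real definition lives elsewhere in the repo.
#include <cstdlib>
#include <cstring>

inline bool enableCATMLP() {
    static const bool enabled = [] {
        const char *v = std::getenv("ENABLE_CAT_MLP");
        return v == nullptr || std::strcmp(v, "0") != 0; // assumed default: enabled unless set to 0
    }();
    return enabled;
}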