@@ -417,6 +417,178 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray
417417 }
418418 }
419419
420+ /**
421+ * Same as processHeadsFlashAttention but with some optimizations
422+ * that seem to lower attention's execution time, especially in larger models.
423+ */
424+ public static void processHeadsFlashAttentionOpt (KernelContext context , FloatArray q , FloatArray key_cache , FloatArray value_cache , FloatArray xb , int nHeads , int headSize , int kvDim , int kvMul ,
425+ IntArray positionHolder , int layer , int contextLength ) {
426+
427+ // Thread and workgroup information
428+ int tid = context .localIdx ;
429+ int h = context .groupIdx ; // Each workgroup processes one head
430+ int localSize = context .localGroupSizeX ;
431+
432+ // Early exit if this workgroup is beyond our head count
433+ // This relies on the kernel being launched with nHeads workgroups.
434+ if (h >= nHeads ) {
435+ return ;
436+ }
437+
438+ int pos = positionHolder .get (0 );
439+ int loff = layer * contextLength * kvDim ;
440+ int kvHeadIdx = h / kvMul ;
441+ int BLOCK_SIZE_C = 32 ;
442+
443+ // Allocate shared memory for tiled computation
444+ float [] q_shared = context .allocateFloatLocalArray (headSize );
445+ float [] k_tile = context .allocateFloatLocalArray (BLOCK_SIZE_C * headSize );
446+ float [] v_tile = context .allocateFloatLocalArray (BLOCK_SIZE_C * headSize );
447+ float [] s_tile = context .allocateFloatLocalArray (BLOCK_SIZE_C );
448+ float [] shared_tile_max_holder = context .allocateFloatLocalArray (1 ); // FIX: For broadcasting tile max
449+
450+ // Thread-local accumulators for online softmax
451+ float maxScore = Float .NEGATIVE_INFINITY ;
452+ float sumExp = 0.0f ;
453+
454+ // Thread-local output accumulation
455+ float [] output = new float [headSize ];
456+ for (int i = 0 ; i < headSize ; i ++) {
457+ output [i ] = 0.0f ;
458+ }
459+
460+ // Load query vector into shared memory
461+ for (int i = tid ; i < headSize ; i += localSize ) {
462+ q_shared [i ] = q .get (h * headSize + i );
463+ }
464+
465+ context .localBarrier ();
466+
467+ // Process sequence in tiles
468+ for (int tileC = 0 ; tileC <= pos ; tileC += BLOCK_SIZE_C ) {
469+ int tileEnd = Math .min (tileC + BLOCK_SIZE_C - 1 , pos );
470+
471+ // Load key and value vectors for this tile
472+ // Each thread loads a contiguous block of elements
473+ int totalElements = (tileEnd - tileC + 1 ) * headSize ;
474+ int elementsPerThread = (totalElements + localSize - 1 ) / localSize ;
475+ int startElem = tid * elementsPerThread ;
476+ int endElem = Math .min (startElem + elementsPerThread , totalElements );
477+
478+ for (int globalElemIdx = startElem ; globalElemIdx < endElem ; globalElemIdx ++) {
479+ // Convert flat index to (sequence_pos, dimension)
480+ int seqIdx = globalElemIdx / headSize ;
481+ int dimIdx = globalElemIdx % headSize ;
482+
483+ int tIdxInSeq = tileC + seqIdx ;
484+ int tileMemOffset = seqIdx * headSize + dimIdx ;
485+
486+ int kvCacheAbsolutePos = tIdxInSeq ;
487+ int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + dimIdx ;
488+
489+ k_tile [tileMemOffset ] = key_cache .get (kvOffset );
490+ v_tile [tileMemOffset ] = value_cache .get (kvOffset );
491+ }
492+
493+ context .localBarrier ();
494+
495+ // Compute attention scores for this tile
496+ // Each thread computes one score for the tile
497+ for (int tIdxInSeq = tileC + tid ; tIdxInSeq <= tileEnd ; tIdxInSeq += localSize ) {
498+ int score_idx_in_tile = tIdxInSeq - tileC ; // 0, 1, 2, or 3 for this tile
499+
500+ float score = 0.0f ;
501+ for (int d = 0 ; d < headSize ; d ++) {
502+ score += q_shared [d ] * k_tile [score_idx_in_tile * headSize + d ];
503+ }
504+ score /= TornadoMath .sqrt (headSize );
505+ s_tile [score_idx_in_tile ] = score ;
506+ }
507+
508+ context .localBarrier ();
509+
510+ // Allocate shared memory for reduction (needs to be power of 2)
511+ int reductionSize = 1024 ; // Should be >= BLOCK_SIZE_C and power of 2
512+ float [] reduction_shared = context .allocateFloatLocalArray (reductionSize );
513+
514+ // Step 1: Each thread finds max of its assigned subset
515+ int itemsPerThread = (BLOCK_SIZE_C + localSize - 1 ) / localSize ;
516+ int startIdx = tid * itemsPerThread ;
517+ int endIdx = Math .min (startIdx + itemsPerThread , tileEnd - tileC + 1 );
518+
519+ float threadLocalMax = Float .NEGATIVE_INFINITY ;
520+ for (int i = startIdx ; i < endIdx ; i ++) {
521+ if (s_tile [i ] > threadLocalMax ) {
522+ threadLocalMax = s_tile [i ];
523+ }
524+ }
525+
526+ // Step 2: Store each thread's local max in shared memory
527+ reduction_shared [tid ] = threadLocalMax ;
528+ context .localBarrier ();
529+
530+ // Step 3: Parallel reduction tree
531+ for (int stride = localSize / 2 ; stride > 0 ; stride /= 2 ) {
532+ if (tid < stride && tid + stride < localSize ) {
533+ reduction_shared [tid ] = Math .max (reduction_shared [tid ], reduction_shared [tid + stride ]);
534+ }
535+ context .localBarrier ();
536+ }
537+
538+ // Step 4: Thread 0 now has the final max
539+ float currentTileMax = reduction_shared [0 ];
540+
541+
542+ // Determine if we need to rescale previous results
543+ float newMax = Math .max (maxScore , currentTileMax );
544+ if (newMax != maxScore && maxScore != Float .NEGATIVE_INFINITY ) {
545+ float scale = TornadoMath .exp (maxScore - newMax );
546+ sumExp *= scale ;
547+ for (int d = 0 ; d < headSize ; d ++) {
548+ output [d ] *= scale ;
549+ }
550+ }
551+ maxScore = newMax ;
552+
553+ // Process each key-value pair using original scores from s_tile
554+ // All threads iterate over all scores in the current tile
555+ for (int t_idx_in_s_tile = 0 ; t_idx_in_s_tile <= tileEnd - tileC ; t_idx_in_s_tile ++) {
556+ // s_tile[t_idx_in_s_tile] now correctly refers to the original score
557+ float expScore = TornadoMath .exp (s_tile [t_idx_in_s_tile ] - maxScore );
558+ sumExp += expScore ;
559+
560+ for (int d = 0 ; d < headSize ; d ++) {
561+ output [d ] += expScore * v_tile [t_idx_in_s_tile * headSize + d ];
562+ }
563+ }
564+ context .localBarrier (); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load
565+ }
566+
567+ float normFactor = (sumExp > 0.0f ) ? (1.0f / sumExp ) : 0.0f ;
568+
569+ int dimsPerThread = (headSize + localSize - 1 ) / localSize ;
570+ int startDim = tid * dimsPerThread ;
571+ int endDim = Math .min (startDim + dimsPerThread , headSize );
572+ int baseOffset = h * headSize + startDim ;
573+
574+ // Process 4 elements at a time when possible
575+ int vectorEnd = startDim + ((endDim - startDim ) & ~3 ); // Round down to multiple of 4
576+
577+ // Unrolled loop for better instruction-level parallelism
578+ for (int d = startDim ; d < vectorEnd ; d += 4 ) {
579+ int offset = d - startDim ;
580+ xb .set (baseOffset + offset , output [d ] * normFactor );
581+ xb .set (baseOffset + offset + 1 , output [d + 1 ] * normFactor );
582+ xb .set (baseOffset + offset + 2 , output [d + 2 ] * normFactor );
583+ xb .set (baseOffset + offset + 3 , output [d + 3 ] * normFactor );
584+ }
585+
586+ // Handle remaining elements (0-3 elements)
587+ for (int d = vectorEnd ; d < endDim ; d ++) {
588+ xb .set (h * headSize + d , output [d ] * normFactor );
589+ }
590+ }
591+
420592 /**
421593 * Performs optimized matrix-vector multiplication where each work group
422594 * processes one row of the matrix.
0 commit comments