
Commit 0369d50

Add an optimized kernel for attention
1 parent bcbbcdf commit 0369d50

2 files changed (+176 −3 lines)

src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 4 additions & 3 deletions
@@ -296,8 +296,9 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
 // state.positionHolder,
 // layerIndex);

+// global size = numberOfHeads * 8 = 16 * 8 = 128
 unifiedLayer.task("parallel-attention",
-        TransformerComputeKernelsLayered::processHeadsFlashAttention,
+        TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt,
         context,
         state.wrapQ,
         state.wrapKeyCache,
@@ -471,8 +472,8 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
 // Parallel attention worker configuration
 WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); // qwen ok
 // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4
-parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 8, 1, 1);
-parallelAttentionWorker.setLocalWork(8, 1, 1); // Set local work size to 4 (for parallel attention)
+parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 32, 1, 1);
+parallelAttentionWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for parallel attention)

 int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
 WorkerGrid matmul1Worker = new WorkerGrid1D(matmul1Global);
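
For orientation, a minimal standalone sketch of how these worker-grid sizes compose with the kernel's indexing, assuming the usual TornadoVM WorkerGrid1D/GridScheduler API; the head count of 16 and the task name "layer.parallel-attention" are illustrative, not taken from this commit. A local work size of 32 gives one 32-thread workgroup per head, so the global work size is numberOfHeads * 32 (e.g. 16 * 32 = 512), and context.groupIdx in the kernel ranges over the heads.

    import uk.ac.manchester.tornado.api.GridScheduler;
    import uk.ac.manchester.tornado.api.WorkerGrid;
    import uk.ac.manchester.tornado.api.WorkerGrid1D;

    public class AttentionGridSketch {
        // Builds a 1D worker grid with one 32-thread workgroup per attention head.
        // The task name below is a placeholder, not the name used in this repository.
        public static GridScheduler buildScheduler(int numberOfHeads) {
            int localWorkGroupSize = 32;
            WorkerGrid attentionWorker = new WorkerGrid1D(numberOfHeads);
            attentionWorker.setGlobalWork(numberOfHeads * localWorkGroupSize, 1, 1); // e.g. 16 * 32 = 512 threads
            attentionWorker.setLocalWork(localWorkGroupSize, 1, 1);                  // 32 threads per workgroup

            GridScheduler scheduler = new GridScheduler();
            scheduler.setWorkerGrid("layer.parallel-attention", attentionWorker);
            return scheduler;
        }
    }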

src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 172 additions & 0 deletions
@@ -417,6 +417,178 @@ public static void processHeadsFlashAttention(KernelContext context, FloatArray
         }
     }

+    /**
+     * Same as processHeadsFlashAttention but with some optimizations
+     * that seem to lower attention's execution time, especially in larger models.
+     */
+    public static void processHeadsFlashAttentionOpt(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul,
+            IntArray positionHolder, int layer, int contextLength) {
+
+        // Thread and workgroup information
+        int tid = context.localIdx;
+        int h = context.groupIdx; // Each workgroup processes one head
+        int localSize = context.localGroupSizeX;
+
+        // Early exit if this workgroup is beyond our head count
+        // This relies on the kernel being launched with nHeads workgroups.
+        if (h >= nHeads) {
+            return;
+        }
+
+        int pos = positionHolder.get(0);
+        int loff = layer * contextLength * kvDim;
+        int kvHeadIdx = h / kvMul;
+        int BLOCK_SIZE_C = 32;
+
+        // Allocate shared memory for tiled computation
+        float[] q_shared = context.allocateFloatLocalArray(headSize);
+        float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+        float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+        float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C);
+        float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max
+
+        // Thread-local accumulators for online softmax
+        float maxScore = Float.NEGATIVE_INFINITY;
+        float sumExp = 0.0f;
+
+        // Thread-local output accumulation
+        float[] output = new float[headSize];
+        for (int i = 0; i < headSize; i++) {
+            output[i] = 0.0f;
+        }
+
+        // Load query vector into shared memory
+        for (int i = tid; i < headSize; i += localSize) {
+            q_shared[i] = q.get(h * headSize + i);
+        }
+
+        context.localBarrier();
+
+        // Process sequence in tiles
+        for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
+            int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);
+
+            // Load key and value vectors for this tile
+            // Each thread loads a contiguous block of elements
+            int totalElements = (tileEnd - tileC + 1) * headSize;
+            int elementsPerThread = (totalElements + localSize - 1) / localSize;
+            int startElem = tid * elementsPerThread;
+            int endElem = Math.min(startElem + elementsPerThread, totalElements);
+
+            for (int globalElemIdx = startElem; globalElemIdx < endElem; globalElemIdx++) {
+                // Convert flat index to (sequence_pos, dimension)
+                int seqIdx = globalElemIdx / headSize;
+                int dimIdx = globalElemIdx % headSize;
+
+                int tIdxInSeq = tileC + seqIdx;
+                int tileMemOffset = seqIdx * headSize + dimIdx;
+
+                int kvCacheAbsolutePos = tIdxInSeq;
+                int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + dimIdx;
+
+                k_tile[tileMemOffset] = key_cache.get(kvOffset);
+                v_tile[tileMemOffset] = value_cache.get(kvOffset);
+            }
+
+            context.localBarrier();
+
+            // Compute attention scores for this tile
+            // Each thread computes one score for the tile
+            for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
+                int score_idx_in_tile = tIdxInSeq - tileC; // 0 .. BLOCK_SIZE_C - 1 within this tile
+
+                float score = 0.0f;
+                for (int d = 0; d < headSize; d++) {
+                    score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
+                }
+                score /= TornadoMath.sqrt(headSize);
+                s_tile[score_idx_in_tile] = score;
+            }
+
+            context.localBarrier();
+
+            // Allocate shared memory for reduction (needs to be power of 2)
+            int reductionSize = 1024; // Should be >= localSize (and >= BLOCK_SIZE_C) and a power of 2
+            float[] reduction_shared = context.allocateFloatLocalArray(reductionSize);
+
+            // Step 1: Each thread finds max of its assigned subset
+            int itemsPerThread = (BLOCK_SIZE_C + localSize - 1) / localSize;
+            int startIdx = tid * itemsPerThread;
+            int endIdx = Math.min(startIdx + itemsPerThread, tileEnd - tileC + 1);
+
+            float threadLocalMax = Float.NEGATIVE_INFINITY;
+            for (int i = startIdx; i < endIdx; i++) {
+                if (s_tile[i] > threadLocalMax) {
+                    threadLocalMax = s_tile[i];
+                }
+            }
+
+            // Step 2: Store each thread's local max in shared memory
+            reduction_shared[tid] = threadLocalMax;
+            context.localBarrier();
+
+            // Step 3: Parallel reduction tree
+            for (int stride = localSize / 2; stride > 0; stride /= 2) {
+                if (tid < stride && tid + stride < localSize) {
+                    reduction_shared[tid] = Math.max(reduction_shared[tid], reduction_shared[tid + stride]);
+                }
+                context.localBarrier();
+            }
+
+            // Step 4: Thread 0 now has the final max
+            float currentTileMax = reduction_shared[0];
+
+            // Determine if we need to rescale previous results
+            float newMax = Math.max(maxScore, currentTileMax);
+            if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
+                float scale = TornadoMath.exp(maxScore - newMax);
+                sumExp *= scale;
+                for (int d = 0; d < headSize; d++) {
+                    output[d] *= scale;
+                }
+            }
+            maxScore = newMax;
+
+            // Process each key-value pair using original scores from s_tile
+            // All threads iterate over all scores in the current tile
+            for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) {
+                // s_tile[t_idx_in_s_tile] now correctly refers to the original score
+                float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore);
+                sumExp += expScore;
+
+                for (int d = 0; d < headSize; d++) {
+                    output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d];
+                }
+            }
+            context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load
+        }
+
+        float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
+
+        int dimsPerThread = (headSize + localSize - 1) / localSize;
+        int startDim = tid * dimsPerThread;
+        int endDim = Math.min(startDim + dimsPerThread, headSize);
+        int baseOffset = h * headSize + startDim;
+
+        // Process 4 elements at a time when possible
+        int vectorEnd = startDim + ((endDim - startDim) & ~3); // Round down to multiple of 4
+
+        // Unrolled loop for better instruction-level parallelism
+        for (int d = startDim; d < vectorEnd; d += 4) {
+            int offset = d - startDim;
+            xb.set(baseOffset + offset, output[d] * normFactor);
+            xb.set(baseOffset + offset + 1, output[d + 1] * normFactor);
+            xb.set(baseOffset + offset + 2, output[d + 2] * normFactor);
+            xb.set(baseOffset + offset + 3, output[d + 3] * normFactor);
+        }
+
+        // Handle remaining elements (0-3 elements)
+        for (int d = vectorEnd; d < endDim; d++) {
+            xb.set(h * headSize + d, output[d] * normFactor);
+        }
+    }
+
     /**
      * Performs optimized matrix-vector multiplication where each work group
      * processes one row of the matrix.

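The kernel above is a flash-attention-style pass: each workgroup holds one head's query in local memory, walks the key/value cache in tiles of BLOCK_SIZE_C positions, reduces each tile's maximum score across the workgroup, and rescales the running softmax statistics (maxScore, sumExp) and the partial output by exp(oldMax − newMax) whenever the running maximum grows. The following single-threaded Java sketch is not part of the commit; it uses hypothetical per-tile inputs and simply mirrors that accumulation so the rescaling step can be checked in isolation.

    // Illustrative scalar reference for the online-softmax accumulation used by
    // processHeadsFlashAttentionOpt (one head, tiles of scores and value vectors).
    public final class OnlineSoftmaxSketch {

        // scoreTiles[t][i]   : raw score for position i of tile t (q . k / sqrt(headSize))
        // valueTiles[t][i][d]: value vector for the same position
        public static float[] attend(float[][] scoreTiles, float[][][] valueTiles, int headSize) {
            float maxScore = Float.NEGATIVE_INFINITY;
            float sumExp = 0.0f;
            float[] output = new float[headSize];

            for (int t = 0; t < scoreTiles.length; t++) {
                // Tile max (the GPU kernel obtains this via a parallel reduction in local memory)
                float tileMax = Float.NEGATIVE_INFINITY;
                for (float s : scoreTiles[t]) {
                    tileMax = Math.max(tileMax, s);
                }

                // Rescale previously accumulated sum and output if the running max grows
                float newMax = Math.max(maxScore, tileMax);
                if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
                    float scale = (float) Math.exp(maxScore - newMax);
                    sumExp *= scale;
                    for (int d = 0; d < headSize; d++) {
                        output[d] *= scale;
                    }
                }
                maxScore = newMax;

                // Accumulate this tile's contribution using the updated max
                for (int i = 0; i < scoreTiles[t].length; i++) {
                    float e = (float) Math.exp(scoreTiles[t][i] - maxScore);
                    sumExp += e;
                    for (int d = 0; d < headSize; d++) {
                        output[d] += e * valueTiles[t][i][d];
                    }
                }
            }

            // Final normalization, mirroring normFactor in the kernel
            float norm = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
            for (int d = 0; d < headSize; d++) {
                output[d] *= norm;
            }
            return output;
        }
    }

Running this reference on the same scores and values as the GPU path is one way to sanity-check the kernel's per-head output.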