-// RUN: gc-opt --split-input-file --flash-attention-conversion %s
+// RUN: gc-opt --split-input-file --flash-attention-conversion --gc-cpu-pipeline %s | gc-cpu-runner -e main -entry-point-result=void | FileCheck %s
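+// The conversion pass rewrites linalgx.scaled_dot_product_attention into a
+// flash-attention form; the CPU pipeline then compiles the module and
+// gc-cpu-runner executes @main, whose output is verified by FileCheck.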
 
 func.func @flash_attention(%arg0: tensor<1x16x384x64xf32>, %arg1: tensor<1x16x384x64xf32>, %arg2: tensor<1x16x384x64xf32>, %arg3: tensor<1x16x384x384xf32>) -> tensor<1x16x384x64xf32> {
   %0 = tensor.empty() : tensor<1x16x384x64xf32>
   %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3 : tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>) outs(%0 : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
   return %1 : tensor<1x16x384x64xf32>
 }
+
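+// Runtime entry point: fills Q, K, V, and the mask with ones, invokes the
+// attention kernel, and prints one element of the result for FileCheck.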
+func.func @main() {
+  %cst = arith.constant 1.000000e+00 : f32
+
+  %QKVShape = tensor.empty() : tensor<1x16x384x64xf32>
+  %maskShape = tensor.empty() : tensor<1x16x384x384xf32>
+
+  %Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<1x16x384x384xf32>) -> tensor<1x16x384x384xf32>
+
+  %out = func.call @flash_attention(%Q, %K, %V, %mask) :
+      (tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>)
+      -> (tensor<1x16x384x64xf32>)
+
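+  // With all-ones Q, K, V, and mask, every attention score in a row is
+  // identical, so softmax yields uniform weights and each output element is
+  // the weighted average of all-ones V, i.e. exactly 1.0.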
+  %idx = arith.constant 0 : index
+  %val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<1x16x384x64xf32>
+  cpuruntime.printf "output[0, 0, 0, 0]: %f\n" %val : f32
+
+  return
+}
+// CHECK: output[0, 0, 0, 0]: 1.0