Commit 0457df5

update unit test
1 parent 69b08a5 commit 0457df5

2 files changed: +26 / -1 lines

lib/gc/Transforms/Pipeline.cpp

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ void populateCPURuntimePasses(mlir::PassManager &pm) {
 }
 
 void populateLoweringToLLVMPasses(mlir::PassManager &pm) {
+  pm.addPass(createLowerAffinePass());
   pm.addPass(createFinalizeMemRefToLLVMConversionPass());
   pm.addPass(createConvertSCFToCFPass());
   pm.addPass(cpuruntime::createCPURuntimeToLLVM());
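Note: the commit message does not say why this pass is added. A plausible reading (an assumption, not stated in the commit) is that running the flash-attention test through the full CPU pipeline can leave affine-dialect ops in the IR, and createLowerAffinePass rewrites them into SCF/memref form before the SCF-to-CF and MemRef-to-LLVM passes run. A minimal, hypothetical sketch of what that lowering does (not code from this commit):

// Hypothetical input: an affine loop as it might appear mid-pipeline.
func.func @before(%buf: memref<384xf32>) {
  affine.for %i = 0 to 384 {
    %v = affine.load %buf[%i] : memref<384xf32>
  }
  return
}
// After createLowerAffinePass, the loop is expressed with SCF and memref ops,
// which createConvertSCFToCFPass and the MemRef-to-LLVM finalization can lower further.
func.func @after(%buf: memref<384xf32>) {
  %c0 = arith.constant 0 : index
  %c384 = arith.constant 384 : index
  %c1 = arith.constant 1 : index
  scf.for %i = %c0 to %c384 step %c1 {
    %v = memref.load %buf[%i] : memref<384xf32>
  }
  return
}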

test/gc/Transform/flashAttention.mlir

Lines changed: 25 additions & 1 deletion
@@ -1,7 +1,31 @@
-// RUN: gc-opt --split-input-file --flash-attention-conversion %s
+// RUN: gc-opt --split-input-file --flash-attention-conversion --gc-cpu-pipeline %s | gc-cpu-runner -e main -entry-point-result=void
 
 func.func @flash_attention(%arg0: tensor<1x16x384x64xf32>, %arg1: tensor<1x16x384x64xf32>, %arg2: tensor<1x16x384x64xf32>, %arg3: tensor<1x16x384x384xf32>) -> tensor<1x16x384x64xf32> {
   %0 = tensor.empty() : tensor<1x16x384x64xf32>
   %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>) outs(%0 : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
   return %1 : tensor<1x16x384x64xf32>
 }
+
+func.func @main() {
+  %cst = arith.constant 1.000000e+00 : f32
+
+  %QKVShape = tensor.empty() : tensor<1x16x384x64xf32>
+  %maskShape = tensor.empty() : tensor<1x16x384x384xf32>
+
+  %Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  %mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<1x16x384x384xf32>) -> tensor<1x16x384x384xf32>
+
+  %out = func.call @flash_attention(%Q, %K, %V, %mask) :
+      (tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>)
+      -> (tensor<1x16x384x64xf32>)
+
+  %idx = arith.constant 0 : index
+  %val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<1x16x384x64xf32>
+  cpuruntime.printf "output[0, 0, 0]: %f\n" %val : f32
+
+  return
+}
+// CHECK: output[0, 0, 0]: 1.0
+
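Why 1.0 is the expected value (a sanity check, assuming linalgx.scaled_dot_product_attention computes softmax(Q*K^T / sqrt(d) + mask) * V): with Q, K, V, and the mask all filled with 1.0 and head dimension d = 64, every pre-softmax score in a row equals 64/8 + 1 = 9. The softmax over 384 identical scores gives uniform weights of 1/384, and the weighted sum over an all-ones V is exactly 1.0 for every output element, including the one extracted at index [0, 0, 0, 0].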
