Skip to content

Commit 6f0b716

Browse files
committed
add bf16 test
1 parent aecf9ab commit 6f0b716

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,42 @@ module @eltwise_add attributes {gpu.container_module} {
5151
}
5252
func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
5353
}
54+
55+
// USM variant of the bf16 element-wise add test: the kernel output buffer is
// allocated with `gpu.alloc host_shared`, and the result is copied into a
// regular `memref.alloc` buffer before being returned to the caller.
module @eltwise_add_usm attributes {gpu.container_module} {
  // 10x20 bf16 constant filled with 0.5; @main uses it for both kernel inputs.
  memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01>

  // Launches @test_kernel over a 10x20x1 grid of blocks (one thread per block)
  // and returns a freshly allocated buffer holding a copy of the result.
  func.func @test(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
    %c20 = arith.constant 20 : index
    %c10 = arith.constant 10 : index
    %c1 = arith.constant 1 : index
    // Output buffer for the kernel, allocated as host_shared.
    %memref = gpu.alloc host_shared () : memref<10x20xbf16>
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%arg0 : memref<10x20xbf16>, %arg1 : memref<10x20xbf16>, %memref : memref<10x20xbf16>)
    // Copy the result out so the host_shared buffer can be released before
    // returning.
    %alloc = memref.alloc() : memref<10x20xbf16>
    memref.copy %memref, %alloc : memref<10x20xbf16> to memref<10x20xbf16>
    gpu.dealloc %memref : memref<10x20xbf16>
    return %alloc : memref<10x20xbf16>
  }

  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
    // Each block handles one element, indexed by its (x, y) block id:
    //   %arg2[x, y] = %arg0[x, y] + %arg1[x, y] + 0.5
    gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 10, 20, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %block_id_x = gpu.block_id x
      %block_id_y = gpu.block_id y
      %cst = arith.constant 0.5 : bf16
      %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %2 = arith.addf %0, %1 : bf16
      %3 = arith.addf %2, %cst : bf16
      memref.store %3, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
      gpu.return
    }
  }

  // Test driver: feeds the 0.5-filled constant in as both inputs, so every one
  // of the 10 * 20 = 200 output elements should be 0.5 + 0.5 + 0.5 = 1.5.
  func.func @main() {
    %0 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
    %1 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
    %2 = call @test(%0, %1) : (memref<10x20xbf16>, memref<10x20xbf16>) -> memref<10x20xbf16>
    %cast = memref.cast %2 : memref<10x20xbf16> to memref<*xbf16>
    // The base address is printed in hex, optionally prefixed with 0x.
    // CHECK: Unranked Memref base@ = {{(0x)?[0-9a-f]*}}
    // CHECK-COUNT-200: 1.5
    // The callee symbol must match the declaration below exactly; symbol
    // names are case-sensitive.
    call @printMemrefBF16(%cast) : (memref<*xbf16>) -> ()
    return
  }

  // External print helper, declared only; its definition is not in this file.
  func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
}

0 commit comments

Comments
 (0)