@@ -51,3 +51,42 @@ module @eltwise_add attributes {gpu.container_module} {
51
51
}
52
52
func.func private @printMemrefBF16 (memref <*xbf16 >) attributes {llvm.emit_c_interface }
53
53
}
54
0 commit comments