Open
Description
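I'm trying to transpose a 16x16xf16 matrix using `xegpu.load_nd %0 {transpose = array<i64: 1, 0>, transpose_bit_width = 16 : i32}`. However, the values are transposed as if they were 32-bit elements, even though `transpose_bit_width = 16`. In the output below, each pair of adjacent f16 values in a row stays together:
- The first output row is `A[0][0], A[0][1], A[1][0], A[1][1], …, A[7][0], A[7][1]`.
- I expected the first column of the input, `A[0][0], A[1][0], …, A[15][0]`.

Is this expected behavior or a bug?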
Reproducer
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime
module attributes {gpu.container_module} {
  // Device side: load a 16x16xf16 tile through a 2D block descriptor with a
  // (1, 0) transpose and store the loaded vector unchanged to the output buffer.
  gpu.module @transpose_16bit_loadnd attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Bfloat16ConversionINTEL, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, StorageBuffer16BitAccess, VectorComputeINTEL, VectorAnyINTEL], [SPV_INTEL_bfloat16_conversion, SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_KHR_16bit_storage, SPV_NV_cooperative_matrix, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    // %arg0: input matrix, %arg1: output matrix.
    // @main launches this with a 1x1x1 grid of 1x1x1 blocks, so both known
    // sizes are 1x1x1 (the grid size previously claimed 2x2x1).
    gpu.func @transpose_16bit_loadnd(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 1, 1>, known_grid_size = array<i32: 1, 1, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
      %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
      // NOTE(review): the reported output matches a transpose of 32-bit
      // elements (adjacent f16 pairs stay together), not of 16-bit elements,
      // even though transpose_bit_width = 16 is requested here.
      %2 = xegpu.load_nd %0 {transpose = array<i64: 1, 0>, transpose_bit_width = 16 : i32} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
      xegpu.store_nd %2, %1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
      gpu.return
    }
  }

  // Host driver. It prints three matrices in order:
  //   1. the random input,
  //   2. the GPU result,
  //   3. a host-computed element-wise transpose, to compare against (2).
  func.func @main() {
    %c0 = arith.constant 0 : index
    // %c1 was previously used by gpu.launch_func without being defined.
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %c_gen_int = arith.constant 0 : i1
    %cf_lower = arith.constant -0.5 : f32
    %cf_upper = arith.constant 0.5 : f32

    // Host input, filled by the runtime helper using the bounds above.
    %input = memref.alloc() : memref<16x16xf16>
    %input_cast = memref.cast %input : memref<16x16xf16> to memref<*xf16>
    call @fillResource1DRandomF16(%input_cast, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
    call @printMemrefF16(%input_cast) : (memref<*xf16>) -> ()

    // Buffers allocated with host_shared so the host can copy into / out of them.
    %gpu_input = gpu.alloc host_shared () : memref<16x16xf16>
    %gpu_output = gpu.alloc host_shared () : memref<16x16xf16>
    memref.copy %input, %gpu_input : memref<16x16xf16> to memref<16x16xf16>
    gpu.launch_func @transpose_16bit_loadnd::@transpose_16bit_loadnd blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%gpu_input : memref<16x16xf16>, %gpu_output : memref<16x16xf16>)

    %output = memref.alloc() : memref<16x16xf16>
    memref.copy %gpu_output, %output : memref<16x16xf16> to memref<16x16xf16>
    %output_cast = memref.cast %output : memref<16x16xf16> to memref<*xf16>
    call @printMemrefF16(%output_cast) : (memref<*xf16>) -> ()

    // Host reference: ref[j][i] = input[i][j] (element-wise 16-bit transpose).
    %ref = memref.alloc() : memref<16x16xf16>
    scf.for %i = %c0 to %c16 step %c1 {
      scf.for %j = %c0 to %c16 step %c1 {
        %v = memref.load %input[%i, %j] : memref<16x16xf16>
        memref.store %v, %ref[%j, %i] : memref<16x16xf16>
      }
    }
    %ref_cast = memref.cast %ref : memref<16x16xf16> to memref<*xf16>
    call @printMemrefF16(%ref_cast) : (memref<*xf16>) -> ()

    // Release every buffer allocated above.
    gpu.dealloc %gpu_input : memref<16x16xf16>
    gpu.dealloc %gpu_output : memref<16x16xf16>
    memref.dealloc %input : memref<16x16xf16>
    memref.dealloc %output : memref<16x16xf16>
    memref.dealloc %ref : memref<16x16xf16>
    return
  }

  func.func private @printMemrefF16(memref<*xf16>)
  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
}
Output
Original matrix:
[[-0.0335999, -0.108459, 0.454346, 0.173096, 0.291992, -0.437744, 0.150879, 0.243774, -0.118896, 0.390625, -0.337402, 0.184204, 0.148682, 0.109863, 0.131592, 0.167603],
[-0.383789, 0.182007, 0.157837, 0.016922, 0.403564, -0.355957, -0.465576, -0.0371094, -0.167603, -0.0213776, -0.107849, -0.364502, -0.49292, -0.40625, -0.474121, 0.259277],
[0.212158, -0.0585938, 0.307861, 0.357178, -0.0243073, -0.301514, 0.157715, 0.0397949, -0.115845, -0.0805054, 0.354248, 0.288818, -0.387695, 0.265137, -0.191528, 0.23584],
[0.186035, 0.13623, 0.164795, 0.321777, -0.131348, 0.189575, 0.437744, -0.437256, -0.488281, 0.104675, 0.223145, 0.468994, 0.471436, 0.289551, -0.388184, 0.24231],
[0.152832, -0.233521, 0.0818481, -0.445312, -0.0191803, 0.349854, 0.472168, -0.358398, -0.220459, 0.244751, -0.0543518, 0.000132799, 0.288086, 0.0359192, -0.0933838, 0.165527],
[-0.0643311, -0.368896, 0.398438, 0.125854, 0.174438, 0.010788, 0.0161896, -0.0637817, -0.450928, -0.256104, -0.0791016, 0.197266, -0.274658, -0.172607, -0.0960693, 0.376221],
[0.326416, 0.428223, 0.0844116, -0.111023, 0.288574, -0.287109, 0.147461, 0.489258, -0.109314, 0.0188751, -0.375732, 0.175903, -0.309082, -0.172852, -0.499756, -0.102051],
[-0.395508, -0.160034, -0.210571, 0.429688, -0.302246, -0.0577393, -0.0242767, -0.174194, 0.21228, 0.110107, 0.34082, 0.348877, -0.255371, 0.156738, 0.143066, -0.0538025],
[0.43042, -0.496338, 0.0446472, 0.376465, -0.153564, -0.231934, 0.322266, -0.2771, -0.272949, 0.0265045, 0.293457, -0.207886, -0.248657, -0.244141, 0.118164, -0.167969],
[-0.318359, 0.33252, 0.192261, -0.403564, 0.23877, 0.078064, 0.400391, -0.290771, -0.12323, 0.0836182, 0.265381, -0.337891, -0.431396, 0.262207, 0.0490723, 0.0157623],
[0.310059, -0.481201, -0.0360413, 0.371582, 0.39624, 0.413086, 0.307861, 0.499756, 0.0454102, -0.2771, 0.352783, -0.0714111, 0.184082, 0.4729, -0.0998535, -0.420166],
[0.445312, -0.265381, -0.182983, -0.249146, -0.437256, -0.298828, -0.418701, -0.0429688, 0.0679932, -0.256836, -0.38208, 0.378174, 0.0784302, -0.149658, 0.232544, -0.249634],
[-0.318115, -0.179443, -0.33667, -0.43335, -0.129028, 0.0360718, -0.48999, 0.333984, 0.356934, 0.238159, -0.198608, 0.0809326, 0.0897827, 0.209839, -0.0469055, -0.409668],
[0.345947, -0.0787354, 0.0486755, 0.098938, 0.0684204, 0.227295, 0.0414429, 0.465576, 0.204834, 0.419189, 0.297119, -0.347412, -0.0586853, 0.239746, 0.174805, 0.0572205],
[0.261963, -0.0251923, 0.481201, -0.470703, -0.0614014, 0.305176, -0.439209, -0.0430603, 0.346924, -0.411377, 0.00965118, 0.00774765, -0.378906, 0.466309, -0.13623, -0.0748901],
[-0.241821, 0.000869751, -0.336914, -0.0773926, 0.469238, -0.218994, 0.362305, 0.00957489, -0.297852, 0.0365906, -0.382568, 0.308594, 0.134277, -0.322998, -0.445557, 0.158325]]
Transposed matrix (requested with `transpose_bit_width = 16`, but the result looks like a 32-bit transpose):
[[-0.0335999, -0.108459, -0.383789, 0.182007, 0.212158, -0.0585938, 0.186035, 0.13623, 0.152832, -0.233521, -0.0643311, -0.368896, 0.326416, 0.428223, -0.395508, -0.160034],
[0.43042, -0.496338, -0.318359, 0.33252, 0.310059, -0.481201, 0.445312, -0.265381, -0.318115, -0.179443, 0.345947, -0.0787354, 0.261963, -0.0251923, -0.241821, 0.000869751],
[0.454346, 0.173096, 0.157837, 0.016922, 0.307861, 0.357178, 0.164795, 0.321777, 0.0818481, -0.445312, 0.398438, 0.125854, 0.0844116, -0.111023, -0.210571, 0.429688],
[0.0446472, 0.376465, 0.192261, -0.403564, -0.0360413, 0.371582, -0.182983, -0.249146, -0.33667, -0.43335, 0.0486755, 0.098938, 0.481201, -0.470703, -0.336914, -0.0773926],
[0.291992, -0.437744, 0.403564, -0.355957, -0.0243073, -0.301514, -0.131348, 0.189575, -0.0191803, 0.349854, 0.174438, 0.010788, 0.288574, -0.287109, -0.302246, -0.0577393],
[-0.153564, -0.231934, 0.23877, 0.078064, 0.39624, 0.413086, -0.437256, -0.298828, -0.129028, 0.0360718, 0.0684204, 0.227295, -0.0614014, 0.305176, 0.469238, -0.218994],
[0.150879, 0.243774, -0.465576, -0.0371094, 0.157715, 0.0397949, 0.437744, -0.437256, 0.472168, -0.358398, 0.0161896, -0.0637817, 0.147461, 0.489258, -0.0242767, -0.174194],
[0.322266, -0.2771, 0.400391, -0.290771, 0.307861, 0.499756, -0.418701, -0.0429688, -0.48999, 0.333984, 0.0414429, 0.465576, -0.439209, -0.0430603, 0.362305, 0.00957489],
[-0.118896, 0.390625, -0.167603, -0.0213776, -0.115845, -0.0805054, -0.488281, 0.104675, -0.220459, 0.244751, -0.450928, -0.256104, -0.109314, 0.0188751, 0.21228, 0.110107],
[-0.272949, 0.0265045, -0.12323, 0.0836182, 0.0454102, -0.2771, 0.0679932, -0.256836, 0.356934, 0.238159, 0.204834, 0.419189, 0.346924, -0.411377, -0.297852, 0.0365906],
[-0.337402, 0.184204, -0.107849, -0.364502, 0.354248, 0.288818, 0.223145, 0.468994, -0.0543518, 0.000132799, -0.0791016, 0.197266, -0.375732, 0.175903, 0.34082, 0.348877],
[0.293457, -0.207886, 0.265381, -0.337891, 0.352783, -0.0714111, -0.38208, 0.378174, -0.198608, 0.0809326, 0.297119, -0.347412, 0.00965118, 0.00774765, -0.382568, 0.308594],
[0.148682, 0.109863, -0.49292, -0.40625, -0.387695, 0.265137, 0.471436, 0.289551, 0.288086, 0.0359192, -0.274658, -0.172607, -0.309082, -0.172852, -0.255371, 0.156738],
[-0.248657, -0.244141, -0.431396, 0.262207, 0.184082, 0.4729, 0.0784302, -0.149658, 0.0897827, 0.209839, -0.0586853, 0.239746, -0.378906, 0.466309, 0.134277, -0.322998],
[0.131592, 0.167603, -0.474121, 0.259277, -0.191528, 0.23584, -0.388184, 0.24231, -0.0933838, 0.165527, -0.0960693, 0.376221, -0.499756, -0.102051, 0.143066, -0.0538025],
[0.118164, -0.167969, 0.0490723, 0.0157623, -0.0998535, -0.420166, 0.232544, -0.249634, -0.0469055, -0.409668, 0.174805, 0.0572205, -0.13623, -0.0748901, -0.445557, 0.158325]]