Skip to content

[xegpu-to-vc-func] Is transpose_bit_width=16 supported? #895

Open
@dchigarev

Description

@dchigarev

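I'm trying to transpose a 16x16xf16 matrix using xegpu.load_nd %0 {transpose = array<i64: 1, 0>, transpose_bit_width = 16 : i32}, but the values are transposed in a "32-bit manner" even though transpose_bit_width = 16: each pair of adjacent f16 values is moved together as a single 32-bit element, as if a 16x8 matrix of 32-bit elements had been transposed (see the output below). Is this expected behavior or a bug?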

Reproducer
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp  \
// RUN:                                       --runner imex-cpu-runner -e main \
// RUN:                                       --entry-point-result=void \
// RUN:                                       --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime

// Reproducer module: a GPU kernel that loads a 16x16xf16 tile through a
// transposing xegpu.load_nd and stores it back unchanged, plus a host driver
// (@main) that prints the matrix before and after the kernel runs.
module attributes {gpu.container_module} {
  gpu.module @transpose_16bit_loadnd attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Bfloat16ConversionINTEL, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, StorageBuffer16BitAccess, VectorComputeINTEL, VectorAnyINTEL], [SPV_INTEL_bfloat16_conversion, SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_KHR_16bit_storage, SPV_NV_cooperative_matrix, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    // Kernel: %arg0 is the source matrix, %arg1 receives the loaded tile.
    // NOTE(review): known_grid_size declares a 2x2x1 grid, but @main launches
    // this kernel with 1x1x1 blocks -- confirm which one is intended.
    gpu.func @transpose_16bit_loadnd(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 1, 1>, known_grid_size = array<i32: 2, 2, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      // Tensor descriptors covering the full 16x16 source and destination,
      // both anchored at offset [0, 0].
      %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
      %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>

      // The operation under question: a transposing load that requests a
      // 16-bit transpose granularity.
      %2 = xegpu.load_nd %0 {transpose = array<i64: 1, 0>, transpose_bit_width = 16 : i32} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
      // Store the loaded vector as-is so the host can inspect its layout.
      xegpu.store_nd %2, %1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
      gpu.return
    }
  }

  // Host driver: fills a 16x16xf16 matrix with random values, prints it, runs
  // the kernel on a device-visible copy, and prints the kernel's output.
  func.func @main() {
    // Launch dimension used for every grid/block axis. The original
    // reproducer referenced %c1 without defining it.
    %c1 = arith.constant 1 : index
    // Arguments forwarded to fillResource1DRandomF16.
    // NOTE(review): presumably the value range and a "generate integers"
    // flag -- confirm against the runner utility's definition.
    %c_gen_int = arith.constant 0 : i1
    %cf_lower = arith.constant -0.5 : f32
    %cf_upper = arith.constant 0.5 : f32

    // Host buffers: %result holds the input, %resultc receives the output.
    %result = memref.alloc() : memref<16x16xf16>
    %resultc = memref.alloc() : memref<16x16xf16>

    // Fill the input and print it; one unranked cast serves both calls.
    %input_unranked = memref.cast %result : memref<16x16xf16> to memref<*xf16>
    call @fillResource1DRandomF16(%input_unranked, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
    call @printMemrefF16(%input_unranked) : (memref<*xf16>) -> ()

    // Host-shared buffers passed to the kernel as input and output.
    %gpu_input = gpu.alloc host_shared () : memref<16x16xf16>
    %gpu_result = gpu.alloc host_shared () : memref<16x16xf16>

    memref.copy %result, %gpu_input : memref<16x16xf16> to memref<16x16xf16>

    gpu.launch_func @transpose_16bit_loadnd::@transpose_16bit_loadnd blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)  args(%gpu_input : memref<16x16xf16>, %gpu_result : memref<16x16xf16>)

    // Copy the kernel output back to a host buffer and print it.
    memref.copy %gpu_result, %resultc : memref<16x16xf16> to memref<16x16xf16>
    %output_unranked = memref.cast %resultc : memref<16x16xf16> to memref<*xf16>
    call @printMemrefF16(%output_unranked) : (memref<*xf16>) -> ()

    // Release every buffer allocated above.
    gpu.dealloc %gpu_input : memref<16x16xf16>
    gpu.dealloc %gpu_result : memref<16x16xf16>
    memref.dealloc %result : memref<16x16xf16>
    memref.dealloc %resultc : memref<16x16xf16>

    return
  }
  // Runner-utility helpers; only declared here, defined in the shared
  // libraries passed via --shared-libs.
  func.func private @printMemrefF16(memref<*xf16>)
  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
}
Output
Original matrix:
[[-0.0335999,   -0.108459,   0.454346,   0.173096,   0.291992,   -0.437744,   0.150879,   0.243774,   -0.118896,   0.390625,   -0.337402,   0.184204,   0.148682,   0.109863,   0.131592,   0.167603], 
 [-0.383789,   0.182007,   0.157837,   0.016922,   0.403564,   -0.355957,   -0.465576,   -0.0371094,   -0.167603,   -0.0213776,   -0.107849,   -0.364502,   -0.49292,   -0.40625,   -0.474121,   0.259277], 
 [0.212158,   -0.0585938,   0.307861,   0.357178,   -0.0243073,   -0.301514,   0.157715,   0.0397949,   -0.115845,   -0.0805054,   0.354248,   0.288818,   -0.387695,   0.265137,   -0.191528,   0.23584], 
 [0.186035,   0.13623,   0.164795,   0.321777,   -0.131348,   0.189575,   0.437744,   -0.437256,   -0.488281,   0.104675,   0.223145,   0.468994,   0.471436,   0.289551,   -0.388184,   0.24231], 
 [0.152832,   -0.233521,   0.0818481,   -0.445312,   -0.0191803,   0.349854,   0.472168,   -0.358398,   -0.220459,   0.244751,   -0.0543518,   0.000132799,   0.288086,   0.0359192,   -0.0933838,   0.165527], 
 [-0.0643311,   -0.368896,   0.398438,   0.125854,   0.174438,   0.010788,   0.0161896,   -0.0637817,   -0.450928,   -0.256104,   -0.0791016,   0.197266,   -0.274658,   -0.172607,   -0.0960693,   0.376221], 
 [0.326416,   0.428223,   0.0844116,   -0.111023,   0.288574,   -0.287109,   0.147461,   0.489258,   -0.109314,   0.0188751,   -0.375732,   0.175903,   -0.309082,   -0.172852,   -0.499756,   -0.102051], 
 [-0.395508,   -0.160034,   -0.210571,   0.429688,   -0.302246,   -0.0577393,   -0.0242767,   -0.174194,   0.21228,   0.110107,   0.34082,   0.348877,   -0.255371,   0.156738,   0.143066,   -0.0538025], 
 [0.43042,   -0.496338,   0.0446472,   0.376465,   -0.153564,   -0.231934,   0.322266,   -0.2771,   -0.272949,   0.0265045,   0.293457,   -0.207886,   -0.248657,   -0.244141,   0.118164,   -0.167969], 
 [-0.318359,   0.33252,   0.192261,   -0.403564,   0.23877,   0.078064,   0.400391,   -0.290771,   -0.12323,   0.0836182,   0.265381,   -0.337891,   -0.431396,   0.262207,   0.0490723,   0.0157623], 
 [0.310059,   -0.481201,   -0.0360413,   0.371582,   0.39624,   0.413086,   0.307861,   0.499756,   0.0454102,   -0.2771,   0.352783,   -0.0714111,   0.184082,   0.4729,   -0.0998535,   -0.420166], 
 [0.445312,   -0.265381,   -0.182983,   -0.249146,   -0.437256,   -0.298828,   -0.418701,   -0.0429688,   0.0679932,   -0.256836,   -0.38208,   0.378174,   0.0784302,   -0.149658,   0.232544,   -0.249634], 
 [-0.318115,   -0.179443,   -0.33667,   -0.43335,   -0.129028,   0.0360718,   -0.48999,   0.333984,   0.356934,   0.238159,   -0.198608,   0.0809326,   0.0897827,   0.209839,   -0.0469055,   -0.409668], 
 [0.345947,   -0.0787354,   0.0486755,   0.098938,   0.0684204,   0.227295,   0.0414429,   0.465576,   0.204834,   0.419189,   0.297119,   -0.347412,   -0.0586853,   0.239746,   0.174805,   0.0572205], 
 [0.261963,   -0.0251923,   0.481201,   -0.470703,   -0.0614014,   0.305176,   -0.439209,   -0.0430603,   0.346924,   -0.411377,   0.00965118,   0.00774765,   -0.378906,   0.466309,   -0.13623,   -0.0748901], 
 [-0.241821,   0.000869751,   -0.336914,   -0.0773926,   0.469238,   -0.218994,   0.362305,   0.00957489,   -0.297852,   0.0365906,   -0.382568,   0.308594,   0.134277,   -0.322998,   -0.445557,   0.158325]]


Transposed matrix (with transpose_bit_width=16, but the result looks as if it were still transposed in 32-bit units):
[[-0.0335999,   -0.108459,   -0.383789,   0.182007,   0.212158,   -0.0585938,   0.186035,   0.13623,   0.152832,   -0.233521,   -0.0643311,   -0.368896,   0.326416,   0.428223,   -0.395508,   -0.160034], 
 [0.43042,   -0.496338,   -0.318359,   0.33252,   0.310059,   -0.481201,   0.445312,   -0.265381,   -0.318115,   -0.179443,   0.345947,   -0.0787354,   0.261963,   -0.0251923,   -0.241821,   0.000869751], 
 [0.454346,   0.173096,   0.157837,   0.016922,   0.307861,   0.357178,   0.164795,   0.321777,   0.0818481,   -0.445312,   0.398438,   0.125854,   0.0844116,   -0.111023,   -0.210571,   0.429688], 
 [0.0446472,   0.376465,   0.192261,   -0.403564,   -0.0360413,   0.371582,   -0.182983,   -0.249146,   -0.33667,   -0.43335,   0.0486755,   0.098938,   0.481201,   -0.470703,   -0.336914,   -0.0773926], 
 [0.291992,   -0.437744,   0.403564,   -0.355957,   -0.0243073,   -0.301514,   -0.131348,   0.189575,   -0.0191803,   0.349854,   0.174438,   0.010788,   0.288574,   -0.287109,   -0.302246,   -0.0577393], 
 [-0.153564,   -0.231934,   0.23877,   0.078064,   0.39624,   0.413086,   -0.437256,   -0.298828,   -0.129028,   0.0360718,   0.0684204,   0.227295,   -0.0614014,   0.305176,   0.469238,   -0.218994], 
 [0.150879,   0.243774,   -0.465576,   -0.0371094,   0.157715,   0.0397949,   0.437744,   -0.437256,   0.472168,   -0.358398,   0.0161896,   -0.0637817,   0.147461,   0.489258,   -0.0242767,   -0.174194], 
 [0.322266,   -0.2771,   0.400391,   -0.290771,   0.307861,   0.499756,   -0.418701,   -0.0429688,   -0.48999,   0.333984,   0.0414429,   0.465576,   -0.439209,   -0.0430603,   0.362305,   0.00957489], 
 [-0.118896,   0.390625,   -0.167603,   -0.0213776,   -0.115845,   -0.0805054,   -0.488281,   0.104675,   -0.220459,   0.244751,   -0.450928,   -0.256104,   -0.109314,   0.0188751,   0.21228,   0.110107], 
 [-0.272949,   0.0265045,   -0.12323,   0.0836182,   0.0454102,   -0.2771,   0.0679932,   -0.256836,   0.356934,   0.238159,   0.204834,   0.419189,   0.346924,   -0.411377,   -0.297852,   0.0365906], 
 [-0.337402,   0.184204,   -0.107849,   -0.364502,   0.354248,   0.288818,   0.223145,   0.468994,   -0.0543518,   0.000132799,   -0.0791016,   0.197266,   -0.375732,   0.175903,   0.34082,   0.348877], 
 [0.293457,   -0.207886,   0.265381,   -0.337891,   0.352783,   -0.0714111,   -0.38208,   0.378174,   -0.198608,   0.0809326,   0.297119,   -0.347412,   0.00965118,   0.00774765,   -0.382568,   0.308594], 
 [0.148682,   0.109863,   -0.49292,   -0.40625,   -0.387695,   0.265137,   0.471436,   0.289551,   0.288086,   0.0359192,   -0.274658,   -0.172607,   -0.309082,   -0.172852,   -0.255371,   0.156738], 
 [-0.248657,   -0.244141,   -0.431396,   0.262207,   0.184082,   0.4729,   0.0784302,   -0.149658,   0.0897827,   0.209839,   -0.0586853,   0.239746,   -0.378906,   0.466309,   0.134277,   -0.322998], 
 [0.131592,   0.167603,   -0.474121,   0.259277,   -0.191528,   0.23584,   -0.388184,   0.24231,   -0.0933838,   0.165527,   -0.0960693,   0.376221,   -0.499756,   -0.102051,   0.143066,   -0.0538025], 
 [0.118164,   -0.167969,   0.0490723,   0.0157623,   -0.0998535,   -0.420166,   0.232544,   -0.249634,   -0.0469055,   -0.409668,   0.174805,   0.0572205,   -0.13623,   -0.0748901,   -0.445557,   0.158325]]

Metadata

Metadata

Assignees

No one assigned

    Labels

    question: Further information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions