Skip to content

Extra memref.copy introduced by bufferization #288

Closed
@zhczhong

Description

@zhczhong
gc-opt --gc-cpu-pipeline test.mlir

bufferization.materialize_in_destination

bufferization.materialize_in_destination is materialized as a memref.copy instead of as an in-place write into the output buffer. The memref.copy %alloc_1, %arg3 : memref<128x1024xf32> to memref<128x1024xf32> could be replaced by writing the result directly into %arg3.

// Reproducer input (test.mlir): matmul -> add -> max against a zero-filled
// tensor, computed on tensors. The final result %10 is written into the
// caller-provided buffer %arg3 via bufferization.materialize_in_destination.
module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 1 : i32>>>} {
  func.func @entry(%arg0: memref<1024x1024xf32>, %arg1: memref<128x1024xf32>, %arg2: memref<128x1024xf32>, %arg3: memref<128x1024xf32>) {
    %0 = bufferization.to_tensor %arg0 restrict : memref<1024x1024xf32>
    %1 = bufferization.to_tensor %arg1 restrict : memref<128x1024xf32>
    %2 = bufferization.to_tensor %arg2 restrict : memref<128x1024xf32>
    %3 = tensor.empty() : tensor<128x1024xf32>
    %cst = arith.constant 0.000000e+00 : f32
    // Zero-initialized accumulator for the matmul %1 x %0.
    %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
    %5 = linalg.matmul ins(%1, %0 : tensor<128x1024xf32>, tensor<1024x1024xf32>) outs(%4 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
    %6 = tensor.empty() : tensor<128x1024xf32>
    %7 = linalg.add ins(%5, %2 : tensor<128x1024xf32>, tensor<128x1024xf32>) outs(%6 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
    %8 = tensor.empty() : tensor<128x1024xf32>
    %cst_0 = arith.constant 0.000000e+00 : f32
    // Zero tensor used as the second operand of linalg.max below.
    %9 = linalg.fill ins(%cst_0 : f32) outs(%8 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
    %10 = linalg.max ins(%7, %9 : tensor<128x1024xf32>, tensor<128x1024xf32>) outs(%8 : tensor<128x1024xf32>) -> tensor<128x1024xf32>
    // %arg3 is marked restrict writable, so %10 could be computed in place in
    // %arg3. The bufferized output instead computes it in a temporary buffer
    // and copies that buffer into %arg3.
    bufferization.materialize_in_destination %10 in restrict writable %arg3 : (tensor<128x1024xf32>, memref<128x1024xf32>) -> ()
    return
  }
}

After bufferization

// Affine map aliases referenced by the bufferized @entry IR that follows.
#map = affine_map<()[s0, s1] -> (s0 + s1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map3 = affine_map<(d0, d1) -> ()>
#map4 = affine_map<(d0, d1) -> (d0, d1)>
#map5 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map6 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map7 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// Bufferized form of the @entry reproducer. The final result is computed into
// the temporary %alloc_1 and only then copied into %arg3 (see the memref.copy
// near the end of the function).
module @fragment_name attributes {"#dlti.sys_spec" = #dlti.target_system_spec<"CPU" : #dlti.target_device_spec<#dlti.dl_entry<"tile_size", 32 : i32>, #dlti.dl_entry<"num_threads", 1 : i32>>>} {
  func.func @entry(%arg0: memref<1024x1024xf32>, %arg1: memref<128x1024xf32>, %arg2: memref<128x1024xf32>, %arg3: memref<128x1024xf32>) {
    %c16 = arith.constant 16 : index
    %c128 = arith.constant 128 : index
    %c256 = arith.constant 256 : index
    %c1024 = arith.constant 1024 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    // %alloc: holds the matmul accumulator, and then the add result (via %subview -> %subview_6 -> %subview_9).
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x1024xf32>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<16x16x16xf32>
    // %alloc_1: receives the final max result (via %subview_4 -> %subview_7 -> %subview_13),
    // even though %arg3 is the intended destination.
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<128x1024xf32>
    // %alloc_2: zero-filled tiles used as the second max operand.
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<128x256xf32>
    scf.for %arg4 = %c0 to %c1024 step %c256 {
      %subview = memref.subview %alloc[0, %arg4] [128, 256] [1, 1] : memref<128x1024xf32> to memref<128x256xf32, strided<[1024, 1], offset: ?>>
      %subview_3 = memref.subview %arg2[0, %arg4] [128, 256] [1, 1] : memref<128x1024xf32> to memref<128x256xf32, strided<[1024, 1], offset: ?>>
      %subview_4 = memref.subview %alloc_1[0, %arg4] [128, 256] [1, 1] : memref<128x1024xf32> to memref<128x256xf32, strided<[1024, 1], offset: ?>>
      scf.for %arg5 = %c0 to %c1024 step %c256 {
        // True only for the first reduction tile (%arg5 == 0).
        %0 = arith.cmpi eq, %arg5, %c0 : index
        scf.for %arg6 = %c0 to %c128 step %c16 {
          %subview_5 = memref.subview %arg1[%arg6, %arg5] [16, 256] [1, 1] : memref<128x1024xf32> to memref<16x256xf32, strided<[1024, 1], offset: ?>>
          %subview_6 = memref.subview %subview[%arg6, 0] [16, 256] [1, 1] : memref<128x256xf32, strided<[1024, 1], offset: ?>> to memref<16x256xf32, strided<[1024, 1], offset: ?>>
          %subview_7 = memref.subview %subview_4[%arg6, 0] [16, 256] [1, 1] : memref<128x256xf32, strided<[1024, 1], offset: ?>> to memref<16x256xf32, strided<[1024, 1], offset: ?>>
          %expand_shape = memref.expand_shape %subview_5 [[0], [1, 2]] output_shape [16, 16, 16] : memref<16x256xf32, strided<[1024, 1], offset: ?>> into memref<16x16x16xf32, strided<[1024, 16, 1], offset: ?>>
          scf.for %arg7 = %c0 to %c256 step %c16 {
            %1 = affine.apply #map()[%arg7, %arg4]
            %subview_8 = memref.subview %arg0[%arg5, %1] [256, 16] [1, 1] : memref<1024x1024xf32> to memref<256x16xf32, strided<[1024, 1], offset: ?>>
            %subview_9 = memref.subview %subview_6[0, %arg7] [16, 16] [1, 1] : memref<16x256xf32, strided<[1024, 1], offset: ?>> to memref<16x16xf32, strided<[1024, 1], offset: ?>>
            // Copy the LHS tile into %alloc_0 with its first two dims transposed (#map1 -> #map2).
            linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expand_shape : memref<16x16x16xf32, strided<[1024, 16, 1], offset: ?>>) outs(%alloc_0 : memref<16x16x16xf32>) {
            ^bb0(%in: f32, %out: f32):
              linalg.yield %in : f32
            }
            %expand_shape_10 = memref.expand_shape %subview_8 [[0, 1], [2]] output_shape [16, 16, 16] : memref<256x16xf32, strided<[1024, 1], offset: ?>> into memref<16x16x16xf32, strided<[16384, 1024, 1], offset: ?>>
            // On the first reduction tile, zero-fill the accumulator before accumulating;
            // otherwise accumulate into the existing values.
            scf.if %0 {
              linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%subview_9 : memref<16x16xf32, strided<[1024, 1], offset: ?>>) {
              ^bb0(%in: f32, %out: f32):
                linalg.yield %in : f32
              }
              linalg.generic {indexing_maps = [#map5, #map6, #map7], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%alloc_0, %expand_shape_10 : memref<16x16x16xf32>, memref<16x16x16xf32, strided<[16384, 1024, 1], offset: ?>>) outs(%subview_9 : memref<16x16xf32, strided<[1024, 1], offset: ?>>) {
              ^bb0(%in: f32, %in_14: f32, %out: f32):
                %2 = arith.mulf %in, %in_14 : f32
                %3 = arith.addf %out, %2 : f32
                linalg.yield %3 : f32
              }
            } else {
              linalg.generic {indexing_maps = [#map5, #map6, #map7], iterator_types = ["reduction", "parallel", "parallel", "reduction"]} ins(%alloc_0, %expand_shape_10 : memref<16x16x16xf32>, memref<16x16x16xf32, strided<[16384, 1024, 1], offset: ?>>) outs(%subview_9 : memref<16x16xf32, strided<[1024, 1], offset: ?>>) {
              ^bb0(%in: f32, %in_14: f32, %out: f32):
                %2 = arith.mulf %in, %in_14 : f32
                %3 = arith.addf %out, %2 : f32
                linalg.yield %3 : f32
              }
            }
            %subview_11 = memref.subview %subview_3[%arg6, %arg7] [16, 16] [1, 1] : memref<128x256xf32, strided<[1024, 1], offset: ?>> to memref<16x16xf32, strided<[1024, 1], offset: ?>>
            // Elementwise add of the %arg2 tile, written in place over the accumulator tile.
            linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%subview_9, %subview_11 : memref<16x16xf32, strided<[1024, 1], offset: ?>>, memref<16x16xf32, strided<[1024, 1], offset: ?>>) outs(%subview_9 : memref<16x16xf32, strided<[1024, 1], offset: ?>>) {
            ^bb0(%in: f32, %in_14: f32, %out: f32):
              %2 = arith.addf %in, %in_14 : f32
              linalg.yield %2 : f32
            }
            %subview_12 = memref.subview %alloc_2[%arg6, %arg7] [16, 16] [1, 1] : memref<128x256xf32> to memref<16x16xf32, strided<[256, 1], offset: ?>>
            linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%subview_12 : memref<16x16xf32, strided<[256, 1], offset: ?>>) {
            ^bb0(%in: f32, %out: f32):
              linalg.yield %in : f32
            }
            // The max result is written into a tile of %alloc_1, not of %arg3.
            %subview_13 = memref.subview %subview_7[0, %arg7] [16, 16] [1, 1] : memref<16x256xf32, strided<[1024, 1], offset: ?>> to memref<16x16xf32, strided<[1024, 1], offset: ?>>
            // Select-based max of the add result and the zero tile (cmpf ugt + select, with an
            // additional "uno" check on the second operand).
            linalg.generic {indexing_maps = [#map4, #map4, #map4], iterator_types = ["parallel", "parallel"]} ins(%subview_9, %subview_12 : memref<16x16xf32, strided<[1024, 1], offset: ?>>, memref<16x16xf32, strided<[256, 1], offset: ?>>) outs(%subview_13 : memref<16x16xf32, strided<[1024, 1], offset: ?>>) {
            ^bb0(%in: f32, %in_14: f32, %out: f32):
              %2 = arith.cmpf ugt, %in, %in_14 : f32
              %3 = arith.select %2, %in, %in_14 : f32
              %4 = arith.cmpf uno, %in_14, %in_14 : f32
              %5 = arith.select %4, %in_14, %3 : f32
              linalg.yield %5 : f32
            }
          }
        }
      }
    }
    memref.dealloc %alloc_2 : memref<128x256xf32>
    memref.dealloc %alloc_0 : memref<16x16x16xf32>
    memref.dealloc %alloc : memref<128x1024xf32>
    // The extra copy reported in this issue: the whole result is copied from %alloc_1
    // into %arg3 instead of being produced in %arg3 directly.
    memref.copy %alloc_1, %arg3 : memref<128x1024xf32> to memref<128x1024xf32>
    memref.dealloc %alloc_1 : memref<128x1024xf32>
    return
  }
}

extract_slice

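memref.copy %alloc_3, %subview_1 : memref<?x4x32x32xbf16> to memref<?x4x32x32xbf16, strided<[32768, 1024, 32, 1], offset: ?>> should be redundant. %alloc_3 is filled from %subview_1 at the start of each iteration (memref.copy %subview_1, %alloc_3) and then copied back into it. The computation could therefore write into %subview_1 directly.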

// Reproducer input for the extract_slice case: nested scf.forall / scf.for tiling
// over a shared output tensor, using tensor.extract_slice / tensor.insert_slice.
func.func @main_entry(%arg0: tensor<32x32x32x32xbf16>, %arg1: tensor<32x32x16x32x2xbf16>) -> tensor<32x32x32x32xbf16> attributes {llvm.emit_c_interface} {
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : bf16
  %0 = tensor.empty() : tensor<32x32x32x32xbf16>
  %1 = scf.forall (%arg2) = (0) to (32) step (5) shared_outs(%arg3 = %0) -> (tensor<32x32x32x32xbf16>) {
    %2 = affine.min affine_map<(d0) -> (-d0 + 32, 5)>(%arg2)
    %extracted_slice = tensor.extract_slice %arg3[%arg2, 0, 0, 0] [%2, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
    %3 = scf.forall (%arg4) = (0) to (32) step (4) shared_outs(%arg5 = %extracted_slice) -> (tensor<?x32x32x32xbf16>) {
      %extracted_slice_0 = tensor.extract_slice %arg5[0, %arg4, 0, 0] [%2, 4, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x4x32x32xbf16>
      %4 = scf.for %arg6 = %c0 to %2 step %c4 iter_args(%arg7 = %extracted_slice_0) -> (tensor<?x4x32x32xbf16>) {
        %5 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%arg6)[%2]
        // %extracted_slice_1, %extracted_slice_2 and %extracted_slice_4 all take the
        // identical slice of %arg7 (same offsets, sizes and strides).
        %extracted_slice_1 = tensor.extract_slice %arg7[%arg6, 0, 0, 0] [%5, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xbf16> to tensor<?x4x32x32xbf16>
        %extracted_slice_2 = tensor.extract_slice %arg7[%arg6, 0, 0, 0] [%5, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xbf16> to tensor<?x4x32x32xbf16>
        %6 = tensor.empty(%5) : tensor<?x4x32x32xf32>
        %extracted_slice_3 = tensor.extract_slice %6[0, 0, 0, 0] [%5, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xf32> to tensor<?x4x32x32xf32>
        %extracted_slice_4 = tensor.extract_slice %arg7[%arg6, 0, 0, 0] [%5, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xbf16> to tensor<?x4x32x32xbf16>
        // Two loop-carried values: an f32 accumulator slice and the bf16 output slice.
        %7:2 = scf.for %arg8 = %c0 to %5 step %c1 iter_args(%arg9 = %extracted_slice_3, %arg10 = %extracted_slice_4) -> (tensor<?x4x32x32xf32>, tensor<?x4x32x32xbf16>) {
          %8 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>()[%arg2, %arg8, %arg6]
          %extracted_slice_7 = tensor.extract_slice %arg0[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<1x32x32x32xbf16>
          %extracted_slice_8 = tensor.extract_slice %arg9[%arg8, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xf32> to tensor<1x4x32x32xf32>
          %extracted_slice_9 = tensor.extract_slice %arg10[%arg8, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xbf16> to tensor<1x4x32x32xbf16>
          %9:2 = scf.for %arg11 = %c0 to %c4 step %c1 iter_args(%arg12 = %extracted_slice_8, %arg13 = %extracted_slice_9) -> (tensor<1x4x32x32xf32>, tensor<1x4x32x32xbf16>) {
            %collapsed = tensor.collapse_shape %extracted_slice_7 [[0, 1], [2], [3]] : tensor<1x32x32x32xbf16> into tensor<32x32x32xbf16>
            %10 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg4, %arg11]
            %extracted_slice_12 = tensor.extract_slice %arg1[%10, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : tensor<32x32x16x32x2xbf16> to tensor<1x32x16x32x2xbf16>
            %collapsed_13 = tensor.collapse_shape %extracted_slice_12 [[0, 1], [2], [3], [4]] : tensor<1x32x16x32x2xbf16> into tensor<32x16x32x2xbf16>
            %extracted_slice_14 = tensor.extract_slice %arg12[0, %arg11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x4x32x32xf32> to tensor<1x1x32x32xf32>
            %collapsed_15 = tensor.collapse_shape %extracted_slice_14 [[0, 1, 2], [3]] : tensor<1x1x32x32xf32> into tensor<32x32xf32>
            %extracted_slice_16 = tensor.extract_slice %arg13[0, %arg11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x4x32x32xbf16> to tensor<1x1x32x32xbf16>
            %collapsed_17 = tensor.collapse_shape %extracted_slice_16 [[0, 1, 2], [3]] : tensor<1x1x32x32xbf16> into tensor<32x32xbf16>
            %11 = linalg.fill ins(%cst : bf16) outs(%collapsed_15 : tensor<32x32xf32>) -> tensor<32x32xf32>
            %12 = linalgx.batch_reduce_matmul_vnni ins(%collapsed, %collapsed_13 : tensor<32x32x32xbf16>, tensor<32x16x32x2xbf16>) outs(%11 : tensor<32x32xf32>) -> tensor<32x32xf32>
            // Convert the f32 result tile into the bf16 output tile.
            %13 = linalg.copy ins(%12 : tensor<32x32xf32>) outs(%collapsed_17 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
            %expanded = tensor.expand_shape %12 [[0, 1, 2], [3]] output_shape [1, 1, 32, 32] : tensor<32x32xf32> into tensor<1x1x32x32xf32>
            %inserted_slice_18 = tensor.insert_slice %expanded into %arg12[0, %arg11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x1x32x32xf32> into tensor<1x4x32x32xf32>
            %expanded_19 = tensor.expand_shape %13 [[0, 1, 2], [3]] output_shape [1, 1, 32, 32] : tensor<32x32xbf16> into tensor<1x1x32x32xbf16>
            %inserted_slice_20 = tensor.insert_slice %expanded_19 into %arg13[0, %arg11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<1x1x32x32xbf16> into tensor<1x4x32x32xbf16>
            scf.yield %inserted_slice_18, %inserted_slice_20 : tensor<1x4x32x32xf32>, tensor<1x4x32x32xbf16>
          }
          %inserted_slice_10 = tensor.insert_slice %9#0 into %arg9[%arg8, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : tensor<1x4x32x32xf32> into tensor<?x4x32x32xf32>
          %inserted_slice_11 = tensor.insert_slice %9#1 into %arg10[%arg8, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : tensor<1x4x32x32xbf16> into tensor<?x4x32x32xbf16>
          scf.yield %inserted_slice_10, %inserted_slice_11 : tensor<?x4x32x32xf32>, tensor<?x4x32x32xbf16>
        }
        // The bf16 result is inserted through %extracted_slice_2 and %extracted_slice_1
        // (full-size inserts at offset 0) before landing in %arg7 at the slice it was
        // extracted from.
        %inserted_slice = tensor.insert_slice %7#1 into %extracted_slice_2[0, 0, 0, 0] [%5, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xbf16> into tensor<?x4x32x32xbf16>
        %inserted_slice_5 = tensor.insert_slice %inserted_slice into %extracted_slice_1[0, 0, 0, 0] [%5, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xbf16> into tensor<?x4x32x32xbf16>
        %inserted_slice_6 = tensor.insert_slice %inserted_slice_5 into %arg7[%arg6, 0, 0, 0] [%5, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xbf16> into tensor<?x4x32x32xbf16>
        scf.yield %inserted_slice_6 : tensor<?x4x32x32xbf16>
      }
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg5[0, %arg4, 0, 0] [%2, 4, 32, 32] [1, 1, 1, 1] : tensor<?x4x32x32xbf16> into tensor<?x32x32x32xbf16>
      }
    }
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %arg3[%arg2, 0, 0, 0] [%2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
    }
  }
  return %1 : tensor<32x32x32x32xbf16>
}

After bufferization

// Affine map aliases referenced by the bufferized @main_entry IR that follows.
#map = affine_map<(d0) -> (-d0 + 32, 5)>
#map1 = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
#map2 = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
#map3 = affine_map<()[s0, s1] -> (s0 + s1)>
#map4 = affine_map<(d0, d1) -> ()>
#map5 = affine_map<(d0, d1) -> (d0, d1)>
#map6 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d0, d3 * 2 + d4)>
#map7 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1, d4)>
#map8 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
// Bufferized form of @main_entry. The output slice %subview_1 is copied into the
// temporary %alloc_3, all writes go to %alloc_3, and %alloc_3 is then copied back
// into %subview_1.
module {
  func.func @main_entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<32x32x32x32xbf16>) attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : bf16
    scf.forall (%arg3) = (0) to (32) step (5) {
      %0 = affine.min #map(%arg3)
      %subview = memref.subview %arg2[%arg3, 0, 0, 0] [%0, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<?x32x32x32xbf16, strided<[32768, 1024, 32, 1], offset: ?>>
      scf.forall (%arg4) = (0) to (32) step (4) {
        %subview_0 = memref.subview %subview[0, %arg4, 0, 0] [%0, 4, 32, 32] [1, 1, 1, 1] : memref<?x32x32x32xbf16, strided<[32768, 1024, 32, 1], offset: ?>> to memref<?x4x32x32xbf16, strided<[32768, 1024, 32, 1], offset: ?>>
        scf.for %arg5 = %c0 to %0 step %c4 {
          %1 = affine.min #map1(%arg5)[%0]
          // %subview_1: the slice of the function output %arg2 this iteration produces.
          %subview_1 = memref.subview %subview_0[%arg5, 0, 0, 0] [%1, 4, 32, 32] [1, 1, 1, 1] : memref<?x4x32x32xbf16, strided<[32768, 1024, 32, 1], offset: ?>> to memref<?x4x32x32xbf16, strided<[32768, 1024, 32, 1], offset: ?>>
          // f32 accumulator buffer; deallocated after the loops below.
          %alloc = memref.alloc(%1) {alignment = 64 : i64} : memref<?x4x32x32xf32>
          %subview_2 = memref.subview %alloc[0, 0, 0, 0] [%1, 4, 32, 32] [1, 1, 1, 1] : memref<?x4x32x32xf32> to memref<?x4x32x32xf32, strided<[4096, 1024, 32, 1]>>
          // Temporary bf16 buffer, initialized by copying %subview_1 into it.
          %alloc_3 = memref.alloc(%1) {alignment = 64 : i64} : memref<?x4x32x32xbf16>
          memref.copy %subview_1, %alloc_3 : memref<?x4x32x32xbf16, strided<[32768, 1024, 32, 1], offset: ?>> to memref<?x4x32x32xbf16>
          scf.for %arg6 = %c0 to %1 step %c1 {
            %2 = affine.apply #map2()[%arg3, %arg6, %arg5]
            %subview_4 = memref.subview %subview_2[%arg6, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<?x4x32x32xf32, strided<[4096, 1024, 32, 1]>> to memref<1x4x32x32xf32, strided<[4096, 1024, 32, 1], offset: ?>>
            // bf16 results are written into %alloc_3 rather than into %subview_1.
            %subview_5 = memref.subview %alloc_3[%arg6, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<?x4x32x32xbf16> to memref<1x4x32x32xbf16, strided<[4096, 1024, 32, 1], offset: ?>>
            %subview_6 = memref.subview %arg0[%2, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
            scf.for %arg7 = %c0 to %c4 step %c1 {
              %3 = affine.apply #map3()[%arg4, %arg7]
              %subview_7 = memref.subview %arg1[%3, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
              %subview_8 = memref.subview %subview_4[0, %arg7, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<1x4x32x32xf32, strided<[4096, 1024, 32, 1], offset: ?>> to memref<32x32xf32, strided<[32, 1], offset: ?>>
              %subview_9 = memref.subview %subview_5[0, %arg7, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<1x4x32x32xbf16, strided<[4096, 1024, 32, 1], offset: ?>> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
              // Zero-fill the f32 accumulator tile (bf16 constant extended to f32).
              linalg.generic {indexing_maps = [#map4, #map5], iterator_types = ["parallel", "parallel"]} ins(%cst : bf16) outs(%subview_8 : memref<32x32xf32, strided<[32, 1], offset: ?>>) {
              ^bb0(%in: bf16, %out: f32):
                %4 = arith.extf %in : bf16 to f32
                linalg.yield %4 : f32
              }
              // Lowered batch_reduce_matmul_vnni: extend both bf16 inputs, multiply, accumulate in f32.
              linalg.generic {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%subview_6, %subview_7 : memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_8 : memref<32x32xf32, strided<[32, 1], offset: ?>>) {
              ^bb0(%in: bf16, %in_10: bf16, %out: f32):
                %4 = arith.extf %in : bf16 to f32
                %5 = arith.extf %in_10 : bf16 to f32
                %6 = arith.mulf %4, %5 : f32
                %7 = arith.addf %out, %6 : f32
                linalg.yield %7 : f32
              }
              // Truncate the f32 tile into the bf16 tile of %alloc_3.
              linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel"]} ins(%subview_8 : memref<32x32xf32, strided<[32, 1], offset: ?>>) outs(%subview_9 : memref<32x32xbf16, strided<[32, 1], offset: ?>>) {
              ^bb0(%in: f32, %out: bf16):
                %4 = arith.truncf %in : f32 to bf16
                linalg.yield %4 : bf16
              }
            }
          }
          memref.dealloc %alloc : memref<?x4x32x32xf32>
          // The copy reported as redundant: %alloc_3 is written back to %subview_1, the
          // same slice it was copied from before the loops above.
          memref.copy %alloc_3, %subview_1 : memref<?x4x32x32xbf16> to memref<?x4x32x32xbf16, strided<[32768, 1024, 32, 1], offset: ?>>
          memref.dealloc %alloc_3 : memref<?x4x32x32xbf16>
        }
      }
    }
    return
  }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions