Skip to content

Missing dimensionality information in torch #3651

Closed
@MaheshRavishankar

Description

As a follow-up to iree-org/iree#18229: some dimension information does not seem to be captured correctly in the IR, and recovering it later in the program is quite involved. This is the IR after `torch-finalizing-backend-type-conversion`:

// IR of @torch_jit$async after torch-finalizing-backend-type-conversion.
// Imports one ?x?xi64 tensor from %arg0 (after waiting on fence %arg1),
// computes a ?x12x?x64xf32 result, joins it with fence %arg2 via
// hal.tensor.barrier, and exports it as a fully dynamic rank-4 buffer view.
util.func public @torch_jit$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
  // Constants. The dense_resource names suggest their payloads were elided
  // from this dump; only their types are known here.
  %c3 = arith.constant 3 : index
  %c2 = arith.constant 2 : index
  %c64_i64 = arith.constant 64 : i64
  %c12_i64 = arith.constant 12 : i64
  %c768_i64 = arith.constant 768 : i64
  %c512_i64 = arith.constant 512 : i64
  %c30522_i64 = arith.constant 30522 : i64
  %cst = arith.constant dense<9.99999996E-13> : tensor<f32>
  %cst_0 = arith.constant dense<2.000000e+00> : tensor<f32>
  %c2_i64 = arith.constant 2 : i64
  %cst_1 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___8> : tensor<1x512xi64>
  %c0_i64 = arith.constant 0 : i64
  %cst_2 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided__> : tensor<1x512xi64>
  %cst_3 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___1> : tensor<30522x768xf32>
  %cst_4 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___2> : tensor<512x768xf32>
  %cst_5 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___3> : tensor<2x768xf32>
  %cst_6 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___4> : tensor<768xf32>
  %cst_7 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___5> : tensor<768xf32>
  %cst_8 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___6> : tensor<768xf32>
  %cst_9 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___7> : tensor<768x768xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c-1 = arith.constant -1 : index
  %c512 = arith.constant 512 : index
  %cst_10 = arith.constant 0.000000e+00 : f32
  %c1_i64 = arith.constant 1 : i64
  %cst_11 = arith.constant 7.680000e+02 : f32
  // Import the rank-2 i64 input; both of its dims (%0, %1) are runtime values.
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<?x?xi64>{%0, %1}
  // %11 = dim 1 of the input (wrapped by +512 if negative), clamped to
  // [0, 512]; used as the slice length for the 1x512 constants below.
  %dim = tensor.dim %2, %c1 : tensor<?x?xi64>
  %3 = arith.cmpi slt, %dim, %c0 : index
  %4 = arith.addi %dim, %c512 : index
  %5 = arith.select %3, %4, %dim : index
  %6 = arith.cmpi slt, %5, %c0 : index
  %7 = arith.select %6, %c-1, %5 : index
  %8 = arith.cmpi sgt, %7, %c512 : index
  %9 = arith.select %8, %c512, %7 : index
  %10 = arith.cmpi slt, %9, %c0 : index
  %11 = arith.select %10, %c0, %9 : index
  // Take the first %11 columns of the two 1x512 i64 constants.
  %extracted_slice = tensor.extract_slice %cst_1[0, 0] [1, %11] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
  %extracted_slice_12 = tensor.extract_slice %cst_2[0, 0] [1, %11] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
  // %16 = (%2 < 0) ? %2 + 30522 : %2, i.e. negative indices are wrapped by
  // the row count of %cst_3.
  %dim_13 = tensor.dim %2, %c0 : tensor<?x?xi64>
  %12 = tensor.empty(%dim_13, %dim) : tensor<?x?xi1>
  %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<?x?xi64>) outs(%12 : tensor<?x?xi1>) {
  ^bb0(%in: i64, %out: i1):
    %63 = arith.cmpi slt, %in, %c0_i64 : i64
    linalg.yield %63 : i1
  } -> tensor<?x?xi1>
  %14 = tensor.empty(%dim_13, %dim) : tensor<?x?xi64>
  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<?x?xi64>) outs(%14 : tensor<?x?xi64>) {
  ^bb0(%in: i64, %out: i64):
    %63 = arith.addi %in, %c30522_i64 : i64
    linalg.yield %63 : i64
  } -> tensor<?x?xi64>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13, %15, %2 : tensor<?x?xi1>, tensor<?x?xi64>, tensor<?x?xi64>) outs(%14 : tensor<?x?xi64>) {
  ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
    %63 = arith.select %in, %in_26, %in_27 : i64
    linalg.yield %63 : i64
  } -> tensor<?x?xi64>
  // Flatten the indices, gather rows of %cst_3 (30522x768), and reshape the
  // result to [dim0, dim1, 768] using the input's runtime dims (%17, %18).
  %17 = arith.index_cast %dim_13 : index to i64
  %18 = arith.index_cast %dim : index to i64
  %collapsed = tensor.collapse_shape %16 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
  %19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%dim_13, %dim]
  %20 = tensor.empty(%19) : tensor<?x768xf32>
  %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%20 : tensor<?x768xf32>) {
  ^bb0(%in: i64, %out: f32):
    %63 = arith.index_cast %in : i64 to index
    %64 = linalg.index 1 : index
    %extracted = tensor.extract %cst_3[%63, %64] : tensor<30522x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x768xf32>
  %from_elements = tensor.from_elements %17, %18, %c768_i64 : tensor<3xi64>
  %reshape = tensor.reshape %21(%from_elements) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  // Same wrap-and-gather pattern for %extracted_slice: wrap negative values
  // by 2 and gather rows of %cst_5 (2x768).
  %22 = tensor.empty(%11) : tensor<1x?xi1>
  %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<1x?xi64>) outs(%22 : tensor<1x?xi1>) {
  ^bb0(%in: i64, %out: i1):
    %63 = arith.cmpi slt, %in, %c0_i64 : i64
    linalg.yield %63 : i1
  } -> tensor<1x?xi1>
  %24 = tensor.empty(%11) : tensor<1x?xi64>
  %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
  ^bb0(%in: i64, %out: i64):
    %63 = arith.addi %in, %c2_i64 : i64
    linalg.yield %63 : i64
  } -> tensor<1x?xi64>
  %26 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%23, %25, %extracted_slice : tensor<1x?xi1>, tensor<1x?xi64>, tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
  ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
    %63 = arith.select %in, %in_26, %in_27 : i64
    linalg.yield %63 : i64
  } -> tensor<1x?xi64>
  %27 = arith.index_cast %11 : index to i64
  %collapsed_14 = tensor.collapse_shape %26 [[0, 1]] : tensor<1x?xi64> into tensor<?xi64>
  %28 = tensor.empty(%11) : tensor<?x768xf32>
  %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_14 : tensor<?xi64>) outs(%28 : tensor<?x768xf32>) {
  ^bb0(%in: i64, %out: f32):
    %63 = arith.index_cast %in : i64 to index
    %64 = linalg.index 1 : index
    %extracted = tensor.extract %cst_5[%63, %64] : tensor<2x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x768xf32>
  // The shape operand is [1, %11, 768]: dim 0 is the constant 1, yet the
  // result type still spells it as '?'.
  %from_elements_15 = tensor.from_elements %c1_i64, %27, %c768_i64 : tensor<3xi64>
  %reshape_16 = tensor.reshape %29(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  // %32 is sized from the input's runtime dims (%17/%18 cast back to index),
  // not from the static 1 above. %33 adds %reshape and %reshape_16 with
  // identity maps, so at runtime %30 must equal 1 and %31 must equal %11.
  // This static-vs-dynamic mismatch on the same dimension is what the issue
  // reports.
  %30 = arith.index_cast %17 : i64 to index
  %31 = arith.index_cast %18 : i64 to index
  %32 = tensor.empty(%30, %31) : tensor<?x?x768xf32>
  %33 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%reshape, %reshape_16 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  // Same wrap-and-gather pattern for %extracted_slice_12: wrap negative values
  // by 512 and gather rows of %cst_4 (512x768); reuses the [1, %11, 768] shape.
  %34 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_12 : tensor<1x?xi64>) outs(%22 : tensor<1x?xi1>) {
  ^bb0(%in: i64, %out: i1):
    %63 = arith.cmpi slt, %in, %c0_i64 : i64
    linalg.yield %63 : i1
  } -> tensor<1x?xi1>
  %35 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_12 : tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
  ^bb0(%in: i64, %out: i64):
    %63 = arith.addi %in, %c512_i64 : i64
    linalg.yield %63 : i64
  } -> tensor<1x?xi64>
  %36 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%34, %35, %extracted_slice_12 : tensor<1x?xi1>, tensor<1x?xi64>, tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
  ^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
    %63 = arith.select %in, %in_26, %in_27 : i64
    linalg.yield %63 : i64
  } -> tensor<1x?xi64>
  %collapsed_17 = tensor.collapse_shape %36 [[0, 1]] : tensor<1x?xi64> into tensor<?xi64>
  %37 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_17 : tensor<?xi64>) outs(%28 : tensor<?x768xf32>) {
  ^bb0(%in: i64, %out: f32):
    %63 = arith.index_cast %in : i64 to index
    %64 = linalg.index 1 : index
    %extracted = tensor.extract %cst_4[%63, %64] : tensor<512x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x768xf32>
  %reshape_18 = tensor.reshape %37(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  // %38 = %33 + %reshape_18 (elementwise).
  %38 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%33, %reshape_18 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  // Normalization over the last (768) dim, matching a layer-norm pattern:
  // (x - mean) / sqrt(var + %cst) * %cst_6 + %cst_7.
  // Mean: sum-reduce d2 into a zero-filled ?x?x1 buffer, then divide by 768.
  %39 = tensor.empty(%30, %31) : tensor<?x?x1xf32>
  %40 = linalg.fill ins(%cst_10 : f32) outs(%39 : tensor<?x?x1xf32>) -> tensor<?x?x1xf32>
  %41 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%38 : tensor<?x?x768xf32>) outs(%40 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = arith.addf %in, %out : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %42 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%41 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = arith.divf %in, %cst_11 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  // Variance: center (%43), square via powf 2.0 (%44), then sum / 768 (%46).
  %43 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%38, %42 : tensor<?x?x768xf32>, tensor<?x?x1xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.subf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %44 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43, %cst_0 : tensor<?x?x768xf32>, tensor<f32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = math.powf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %45 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%44 : tensor<?x?x768xf32>) outs(%40 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = arith.addf %in, %out : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %46 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%45 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = arith.divf %in, %cst_11 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  // Add epsilon (%cst), take sqrt, and divide the centered values by it.
  %47 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%46, %cst : tensor<?x?x1xf32>, tensor<f32>) outs(%39 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %48 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%47 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %63 = math.sqrt %in : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x1xf32>
  %49 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43, %48 : tensor<?x?x768xf32>, tensor<?x?x1xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.divf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  // Per-element scale (%cst_6) and shift (%cst_7) along the 768 dim.
  %50 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%49, %cst_6 : tensor<?x?x768xf32>, tensor<768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.mulf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  %51 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%50, %cst_7 : tensor<?x?x768xf32>, tensor<768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  // Broadcast the 768x768 weight %cst_9 along batch dim %30, then
  // batch-matmul %51 with it into a zero-filled accumulator.
  %52 = tensor.empty(%30) : tensor<?x768x768xf32>
  %53 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_9 : tensor<768x768xf32>) outs(%52 : tensor<?x768x768xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x768x768xf32>
  %54 = linalg.fill ins(%cst_10 : f32) outs(%32 : tensor<?x?x768xf32>) -> tensor<?x?x768xf32>
  %55 = linalg.batch_matmul ins(%51, %53 : tensor<?x?x768xf32>, tensor<?x768x768xf32>) outs(%54 : tensor<?x?x768xf32>) -> tensor<?x?x768xf32>
  // Add the 768-element bias %cst_8.
  %56 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_8, %55 : tensor<768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>
  // Reshape to [dim0, dim1, 12, 64] (12 * 64 = 768); the result type is fully
  // dynamic and is later refined to ?x?x12x64 by tensor.cast.
  %from_elements_19 = tensor.from_elements %17, %18, %c12_i64, %c64_i64 : tensor<4xi64>
  %reshape_20 = tensor.reshape %56(%from_elements_19) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
  // %57/%58 re-derive the same index values as %30/%31 from %17/%18.
  %57 = arith.index_cast %17 : i64 to index
  %58 = arith.index_cast %18 : i64 to index
  %59 = tensor.empty(%57, %58) : tensor<?x12x?x64xf32>
  %cast = tensor.cast %reshape_20 : tensor<?x?x?x?xf32> to tensor<?x?x12x64xf32>
  // Swap dims 1 and 2: ?x?x12x64 -> ?x12x?x64.
  %60 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast : tensor<?x?x12x64xf32>) outs(%59 : tensor<?x12x?x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x12x?x64xf32>
  // Erase the static dims, join the result with fence %arg2 via
  // hal.tensor.barrier, and export it with its runtime dims.
  %cast_21 = tensor.cast %60 : tensor<?x12x?x64xf32> to tensor<?x?x?x?xf32>
  %61 = hal.tensor.barrier join(%cast_21 : tensor<?x?x?x?xf32>) => %arg2 : !hal.fence
  %dim_22 = tensor.dim %61, %c0 : tensor<?x?x?x?xf32>
  %dim_23 = tensor.dim %61, %c1 : tensor<?x?x?x?xf32>
  %dim_24 = tensor.dim %61, %c2 : tensor<?x?x?x?xf32>
  %dim_25 = tensor.dim %61, %c3 : tensor<?x?x?x?xf32>
  %62 = hal.tensor.export %61 : tensor<?x?x?x?xf32>{%dim_22, %dim_23, %dim_24, %dim_25} -> !hal.buffer_view
  util.return %62 : !hal.buffer_view
}

In particular, consider this sequence of instructions:

  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
  %2 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<?x?xi64>{%0, %1}
  %dim = tensor.dim %2, %c1 : tensor<?x?xi64>
  %dim_13 = tensor.dim %2, %c0 : tensor<?x?xi64>
  %17 = arith.index_cast %dim_13 : index to i64
  %18 = arith.index_cast %dim : index to i64
  %from_elements = tensor.from_elements %17, %18, %c768_i64 : tensor<3xi64>
  %reshape = tensor.reshape %21(%from_elements) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
 %from_elements_15 = tensor.from_elements %c1_i64, %27, %c768_i64 : tensor<3xi64>
  %reshape_16 = tensor.reshape %29(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
  %30 = arith.index_cast %17 : i64 to index
  %31 = arith.index_cast %18 : i64 to index
  %32 = tensor.empty(%30, %31) : tensor<?x?x768xf32>
  %33 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%reshape, %reshape_16 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
  ^bb0(%in: f32, %in_26: f32, %out: f32):
    %63 = arith.addf %in, %in_26 : f32
    linalg.yield %63 : f32
  } -> tensor<?x?x768xf32>

The second `tensor.reshape` operation carries the information that the outermost dimension of its result (`%reshape_16`) has size 1, because its shape operand `%from_elements_15` starts with the constant `%c1_i64`. `%33` is an elementwise `linalg.generic` with identity indexing maps over `%reshape` and `%reshape_16`, so the two operands must agree in shape. For consistency, `%32` (and `%reshape`) should therefore also have an outer dimension of size 1. Tracing back through the IR (`%30` ← `%17` ← `%dim_13` ← dimension 0 of `%2`), that means `%0` must have the value 1. IIUC, torch should already know this, and it should have been materialized in the IR. This inconsistency, with the same dimension known statically in one place and dynamic in another, causes issues downstream during compilation. The downstream compilation is being fixed, but this seems worth fixing in torch as well. If that is not possible, we will have to build a way, downstream of torch, to constraint-solve the dynamic dimensions in such cases (but I think Torch already has this).

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions