Description
As a follow-up to iree-org/iree#18229, it seems that some dimension information is not being captured correctly in the IR, and recovering it later in the program is fairly involved. This is the IR after `torch-finalizing-backend-type-conversion`:
// IR for @torch_jit$async as printed after torch-finalizing-backend-type-conversion.
// Takes one ?x?xi64 buffer view (%arg0) plus a wait fence (%arg1) and a signal
// fence (%arg2), and returns a ?x?x?x?xf32 buffer view.
util.func public @torch_jit$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
// Scalar and tensor constants. The dense_resource payloads are elided in this dump.
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c64_i64 = arith.constant 64 : i64
%c12_i64 = arith.constant 12 : i64
%c768_i64 = arith.constant 768 : i64
%c512_i64 = arith.constant 512 : i64
%c30522_i64 = arith.constant 30522 : i64
%cst = arith.constant dense<9.99999996E-13> : tensor<f32>
%cst_0 = arith.constant dense<2.000000e+00> : tensor<f32>
%c2_i64 = arith.constant 2 : i64
%cst_1 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___8> : tensor<1x512xi64>
%c0_i64 = arith.constant 0 : i64
%cst_2 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided__> : tensor<1x512xi64>
%cst_3 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___1> : tensor<30522x768xf32>
%cst_4 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___2> : tensor<512x768xf32>
%cst_5 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___3> : tensor<2x768xf32>
%cst_6 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___4> : tensor<768xf32>
%cst_7 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___5> : tensor<768xf32>
%cst_8 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___6> : tensor<768xf32>
%cst_9 = arith.constant dense_resource<__onnx_constant_not_found_possibly_due_to_being_elided___7> : tensor<768x768xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c-1 = arith.constant -1 : index
%c512 = arith.constant 512 : index
%cst_10 = arith.constant 0.000000e+00 : f32
%c1_i64 = arith.constant 1 : i64
%cst_11 = arith.constant 7.680000e+02 : f32
// Import %arg0 (after waiting on %arg1) as a tensor whose dims %0 x %1 are both dynamic.
%0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
%1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
%2 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<?x?xi64>{%0, %1}
// %11 = dim 1 of %2 normalized and clamped into [0, 512]: add 512 if negative,
// map any remaining negative value to -1, cap at 512, then floor at 0.
// NOTE(review): %11 equals %dim only when 0 <= %dim <= 512; the IR does not record that fact.
%dim = tensor.dim %2, %c1 : tensor<?x?xi64>
%3 = arith.cmpi slt, %dim, %c0 : index
%4 = arith.addi %dim, %c512 : index
%5 = arith.select %3, %4, %dim : index
%6 = arith.cmpi slt, %5, %c0 : index
%7 = arith.select %6, %c-1, %5 : index
%8 = arith.cmpi sgt, %7, %c512 : index
%9 = arith.select %8, %c512, %7 : index
%10 = arith.cmpi slt, %9, %c0 : index
%11 = arith.select %10, %c0, %9 : index
// Take the first %11 columns of each of the two 1x512 i64 constants.
%extracted_slice = tensor.extract_slice %cst_1[0, 0] [1, %11] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
%extracted_slice_12 = tensor.extract_slice %cst_2[0, 0] [1, %11] [1, 1] : tensor<1x512xi64> to tensor<1x?xi64>
// %16 = select(%2 < 0, %2 + 30522, %2). Presumably this normalizes negative
// indices into the 30522-row table %cst_3 gathered below (TODO confirm).
%dim_13 = tensor.dim %2, %c0 : tensor<?x?xi64>
%12 = tensor.empty(%dim_13, %dim) : tensor<?x?xi1>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<?x?xi64>) outs(%12 : tensor<?x?xi1>) {
^bb0(%in: i64, %out: i1):
%63 = arith.cmpi slt, %in, %c0_i64 : i64
linalg.yield %63 : i1
} -> tensor<?x?xi1>
%14 = tensor.empty(%dim_13, %dim) : tensor<?x?xi64>
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<?x?xi64>) outs(%14 : tensor<?x?xi64>) {
^bb0(%in: i64, %out: i64):
%63 = arith.addi %in, %c30522_i64 : i64
linalg.yield %63 : i64
} -> tensor<?x?xi64>
%16 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13, %15, %2 : tensor<?x?xi1>, tensor<?x?xi64>, tensor<?x?xi64>) outs(%14 : tensor<?x?xi64>) {
^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
%63 = arith.select %in, %in_26, %in_27 : i64
linalg.yield %63 : i64
} -> tensor<?x?xi64>
// i64 copies of dims 0 and 1 of %2, used below to build reshape shape vectors.
%17 = arith.index_cast %dim_13 : index to i64
%18 = arith.index_cast %dim : index to i64
// Flatten %16 to dim0*dim1 elements and gather one 768-wide row of %cst_3 per element.
%collapsed = tensor.collapse_shape %16 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
%19 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%dim_13, %dim]
%20 = tensor.empty(%19) : tensor<?x768xf32>
%21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%20 : tensor<?x768xf32>) {
^bb0(%in: i64, %out: f32):
%63 = arith.index_cast %in : i64 to index
%64 = linalg.index 1 : index
%extracted = tensor.extract %cst_3[%63, %64] : tensor<30522x768xf32>
linalg.yield %extracted : f32
} -> tensor<?x768xf32>
// Reshape the gathered rows back to [dim0, dim1, 768], using the dynamic dims of %2.
%from_elements = tensor.from_elements %17, %18, %c768_i64 : tensor<3xi64>
%reshape = tensor.reshape %21(%from_elements) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
// Same select-then-gather pattern on %extracted_slice: offset negative values by 2,
// then gather rows of the 2x768 table %cst_5.
%22 = tensor.empty(%11) : tensor<1x?xi1>
%23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<1x?xi64>) outs(%22 : tensor<1x?xi1>) {
^bb0(%in: i64, %out: i1):
%63 = arith.cmpi slt, %in, %c0_i64 : i64
linalg.yield %63 : i1
} -> tensor<1x?xi1>
%24 = tensor.empty(%11) : tensor<1x?xi64>
%25 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
^bb0(%in: i64, %out: i64):
%63 = arith.addi %in, %c2_i64 : i64
linalg.yield %63 : i64
} -> tensor<1x?xi64>
%26 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%23, %25, %extracted_slice : tensor<1x?xi1>, tensor<1x?xi64>, tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
%63 = arith.select %in, %in_26, %in_27 : i64
linalg.yield %63 : i64
} -> tensor<1x?xi64>
%27 = arith.index_cast %11 : index to i64
%collapsed_14 = tensor.collapse_shape %26 [[0, 1]] : tensor<1x?xi64> into tensor<?xi64>
%28 = tensor.empty(%11) : tensor<?x768xf32>
%29 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_14 : tensor<?xi64>) outs(%28 : tensor<?x768xf32>) {
^bb0(%in: i64, %out: f32):
%63 = arith.index_cast %in : i64 to index
%64 = linalg.index 1 : index
%extracted = tensor.extract %cst_5[%63, %64] : tensor<2x768xf32>
linalg.yield %extracted : f32
} -> tensor<?x768xf32>
// Reshape to [1, %11, 768]. The leading extent is the constant %c1_i64, yet the
// result type still reports it as dynamic (tensor<?x?x768xf32>).
%from_elements_15 = tensor.from_elements %c1_i64, %27, %c768_i64 : tensor<3xi64>
%reshape_16 = tensor.reshape %29(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
// %30/%31 round-trip %17/%18 back to index; %32 is sized dim0 x dim1 x 768 from them.
%30 = arith.index_cast %17 : i64 to index
%31 = arith.index_cast %18 : i64 to index
%32 = tensor.empty(%30, %31) : tensor<?x?x768xf32>
// %33 = %reshape + %reshape_16 elementwise with identical identity indexing maps,
// so the two operand shapes must agree.
// NOTE(review): that implies dim 0 of %2 is 1 (and %11 == %dim), but %32 is still
// sized from the dynamic %30/%31. This static-vs-dynamic mismatch is the
// inconsistency reported in this issue.
%33 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%reshape, %reshape_16 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.addf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
// Same pattern on %extracted_slice_12: offset negative values by 512, gather rows
// of the 512x768 table %cst_4, and reshape to [1, %11, 768] via %from_elements_15.
%34 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_12 : tensor<1x?xi64>) outs(%22 : tensor<1x?xi1>) {
^bb0(%in: i64, %out: i1):
%63 = arith.cmpi slt, %in, %c0_i64 : i64
linalg.yield %63 : i1
} -> tensor<1x?xi1>
%35 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice_12 : tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
^bb0(%in: i64, %out: i64):
%63 = arith.addi %in, %c512_i64 : i64
linalg.yield %63 : i64
} -> tensor<1x?xi64>
%36 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%34, %35, %extracted_slice_12 : tensor<1x?xi1>, tensor<1x?xi64>, tensor<1x?xi64>) outs(%24 : tensor<1x?xi64>) {
^bb0(%in: i1, %in_26: i64, %in_27: i64, %out: i64):
%63 = arith.select %in, %in_26, %in_27 : i64
linalg.yield %63 : i64
} -> tensor<1x?xi64>
%collapsed_17 = tensor.collapse_shape %36 [[0, 1]] : tensor<1x?xi64> into tensor<?xi64>
%37 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_17 : tensor<?xi64>) outs(%28 : tensor<?x768xf32>) {
^bb0(%in: i64, %out: f32):
%63 = arith.index_cast %in : i64 to index
%64 = linalg.index 1 : index
%extracted = tensor.extract %cst_4[%63, %64] : tensor<512x768xf32>
linalg.yield %extracted : f32
} -> tensor<?x768xf32>
%reshape_18 = tensor.reshape %37(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
// %38 = %33 + %reshape_18 (same identity-map elementwise add as %33).
%38 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%33, %reshape_18 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.addf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
// %41..%51 normalize %38 over its last (768-wide) dim:
//   mean = sum(x) / 768.0            (%41, %42)
//   d    = x - mean                  (%43)
//   var  = sum(powf(d, 2.0)) / 768.0 (%44, %45, %46)
//   %51  = d / sqrt(var + %cst) * %cst_6 + %cst_7   (%47..%51)
// %39/%40 hold the zero-initialized ?x?x1 reduction accumulators.
%39 = tensor.empty(%30, %31) : tensor<?x?x1xf32>
%40 = linalg.fill ins(%cst_10 : f32) outs(%39 : tensor<?x?x1xf32>) -> tensor<?x?x1xf32>
%41 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%38 : tensor<?x?x768xf32>) outs(%40 : tensor<?x?x1xf32>) {
^bb0(%in: f32, %out: f32):
%63 = arith.addf %in, %out : f32
linalg.yield %63 : f32
} -> tensor<?x?x1xf32>
%42 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%41 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
^bb0(%in: f32, %out: f32):
%63 = arith.divf %in, %cst_11 : f32
linalg.yield %63 : f32
} -> tensor<?x?x1xf32>
%43 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%38, %42 : tensor<?x?x768xf32>, tensor<?x?x1xf32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.subf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
%44 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43, %cst_0 : tensor<?x?x768xf32>, tensor<f32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = math.powf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
%45 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%44 : tensor<?x?x768xf32>) outs(%40 : tensor<?x?x1xf32>) {
^bb0(%in: f32, %out: f32):
%63 = arith.addf %in, %out : f32
linalg.yield %63 : f32
} -> tensor<?x?x1xf32>
%46 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%45 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
^bb0(%in: f32, %out: f32):
%63 = arith.divf %in, %cst_11 : f32
linalg.yield %63 : f32
} -> tensor<?x?x1xf32>
%47 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%46, %cst : tensor<?x?x1xf32>, tensor<f32>) outs(%39 : tensor<?x?x1xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.addf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x1xf32>
%48 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%47 : tensor<?x?x1xf32>) outs(%39 : tensor<?x?x1xf32>) {
^bb0(%in: f32, %out: f32):
%63 = math.sqrt %in : f32
linalg.yield %63 : f32
} -> tensor<?x?x1xf32>
%49 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%43, %48 : tensor<?x?x768xf32>, tensor<?x?x1xf32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.divf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
%50 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%49, %cst_6 : tensor<?x?x768xf32>, tensor<768xf32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.mulf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
%51 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%50, %cst_7 : tensor<?x?x768xf32>, tensor<768xf32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.addf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
// Broadcast the 768x768 constant %cst_9 across a leading dim of size %30, then
// batch-matmul %51 against it (zero-initialized accumulator %54) and add %cst_8
// along the last dim.
%52 = tensor.empty(%30) : tensor<?x768x768xf32>
%53 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_9 : tensor<768x768xf32>) outs(%52 : tensor<?x768x768xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x768x768xf32>
%54 = linalg.fill ins(%cst_10 : f32) outs(%32 : tensor<?x?x768xf32>) -> tensor<?x?x768xf32>
%55 = linalg.batch_matmul ins(%51, %53 : tensor<?x?x768xf32>, tensor<?x768x768xf32>) outs(%54 : tensor<?x?x768xf32>) -> tensor<?x?x768xf32>
%56 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_8, %55 : tensor<768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.addf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
// Reshape to [dim0, dim1, 12, 64]. The result type is fully dynamic, and a
// tensor.cast then re-asserts the static 12x64 trailing dims. The generic writes
// out[d0, d2, d1, d3] = in[d0, d1, d2, d3], i.e. swaps dims 1 and 2 into ?x12x?x64.
%from_elements_19 = tensor.from_elements %17, %18, %c12_i64, %c64_i64 : tensor<4xi64>
%reshape_20 = tensor.reshape %56(%from_elements_19) : (tensor<?x?x768xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
%57 = arith.index_cast %17 : i64 to index
%58 = arith.index_cast %18 : i64 to index
%59 = tensor.empty(%57, %58) : tensor<?x12x?x64xf32>
%cast = tensor.cast %reshape_20 : tensor<?x?x?x?xf32> to tensor<?x?x12x64xf32>
%60 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cast : tensor<?x?x12x64xf32>) outs(%59 : tensor<?x12x?x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x12x?x64xf32>
// Erase the static dims again, pass the result through hal.tensor.barrier with
// fence %arg2, then query all four dims and export the tensor as a buffer view.
%cast_21 = tensor.cast %60 : tensor<?x12x?x64xf32> to tensor<?x?x?x?xf32>
%61 = hal.tensor.barrier join(%cast_21 : tensor<?x?x?x?xf32>) => %arg2 : !hal.fence
%dim_22 = tensor.dim %61, %c0 : tensor<?x?x?x?xf32>
%dim_23 = tensor.dim %61, %c1 : tensor<?x?x?x?xf32>
%dim_24 = tensor.dim %61, %c2 : tensor<?x?x?x?xf32>
%dim_25 = tensor.dim %61, %c3 : tensor<?x?x?x?xf32>
%62 = hal.tensor.export %61 : tensor<?x?x?x?xf32>{%dim_22, %dim_23, %dim_24, %dim_25} -> !hal.buffer_view
util.return %62 : !hal.buffer_view
}
In particular, this sequence of instructions:
// Excerpt: both dims of the imported tensor %2 are dynamic (%0, %1).
%0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
%1 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[1] : index
%2 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<?x?xi64>{%0, %1}
%dim = tensor.dim %2, %c1 : tensor<?x?xi64>
%dim_13 = tensor.dim %2, %c0 : tensor<?x?xi64>
%17 = arith.index_cast %dim_13 : index to i64
%18 = arith.index_cast %dim : index to i64
// %reshape takes its shape [dim0, dim1, 768] from the dynamic dims of %2.
%from_elements = tensor.from_elements %17, %18, %c768_i64 : tensor<3xi64>
%reshape = tensor.reshape %21(%from_elements) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
// %reshape_16 takes its leading extent from the constant %c1_i64, yet its type
// is still tensor<?x?x768xf32>.
%from_elements_15 = tensor.from_elements %c1_i64, %27, %c768_i64 : tensor<3xi64>
%reshape_16 = tensor.reshape %29(%from_elements_15) : (tensor<?x768xf32>, tensor<3xi64>) -> tensor<?x?x768xf32>
// %32 sizes dim 0 from %30, which traces back to %dim_13 (dim 0 of %2) rather
// than to the constant 1.
%30 = arith.index_cast %17 : i64 to index
%31 = arith.index_cast %18 : i64 to index
%32 = tensor.empty(%30, %31) : tensor<?x?x768xf32>
// %33 combines both reshapes elementwise with identity indexing maps, which
// implicitly requires dim 0 of %2 to equal 1. The IR never states this.
%33 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%reshape, %reshape_16 : tensor<?x?x768xf32>, tensor<?x?x768xf32>) outs(%32 : tensor<?x?x768xf32>) {
^bb0(%in: f32, %in_26: f32, %out: f32):
%63 = arith.addf %in, %in_26 : f32
linalg.yield %63 : f32
} -> tensor<?x?x768xf32>
The second `tensor.reshape` operation carries the information that the outermost dimension of its result (`%reshape_16`) has size 1. For consistency, `%32` should also have an outer dimension of size 1. Tracing back through the IR, that effectively means `%0` has the value 1. IIUC, torch should already know this, and it should have been materialized in the IR. This inconsistency causes issues downstream during compilation: the same dimension is known statically in one place and is dynamic in another. The downstream compilation is being fixed, but this seems worth fixing in torch as well. If that is not possible, we will have to build a way, downstream of torch, to constraint-solve the dynamic dimensions in such cases (but I think torch already has this).
Activity