[mlir][linalg] Fix for invalid IR in eliminate_empty_tensors (#73513)
The transform.structured.eliminate_empty_tensors transform can produce
mis-typed IR when it traverses use-def chains past tensor reshaping
operations while looking for sharing candidates. This results in Linalg
operations whose result types do not match the types of their 'outs'
arguments.

This patch filters out candidate tensor.empty operations whose types do
not match the type of the candidate input operand.
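
For contrast with the regression tests added below, the following is a minimal
hypothetical sketch (not part of this commit) of the case that remains a valid
sharing candidate: there is no reshape or cast between the tensor.empty and the
linalg.generic, so the empty tensor's type matches the candidate input operand
exactly. The function name @matching_type_candidate is illustrative only.

#map = affine_map<(d0) -> (d0)>

func.func @matching_type_candidate(%fill_value: f32) -> tensor<56xf32> {
  %init0 = tensor.empty() : tensor<56xf32>
  %init1 = tensor.empty() : tensor<56xf32>

  // No rank- or shape-changing op in between: %init1 has the same type as the
  // linalg.generic input it feeds, so the new type check does not reject it.
  %filled_tensor = linalg.fill
      ins(%fill_value : f32)
      outs(%init1 : tensor<56xf32>) -> tensor<56xf32>

  %bias = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel"]
  } ins(%filled_tensor : tensor<56xf32>)
    outs(%init0 : tensor<56xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<56xf32>

  return %bias : tensor<56xf32>
}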
sabauma authored Jan 1, 2024
1 parent f33245a commit 6b65d79
Showing 2 changed files with 90 additions and 1 deletion.
mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp (4 additions, 1 deletion)
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
     config.alwaysIncludeLeaves = false;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
         in->get(), /*condition=*/
-        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+        [&](Value val) {
+          return val.getDefiningOp<tensor::EmptyOp>() &&
+                 val.getType() == in->get().getType();
+        },
         config);
     if (emptyTensors.empty())
       continue;
@@ -42,3 +42,89 @@ module attributes {transform.with_named_sequence} {
    transform.yield
  }
}

// -----

#map = affine_map<(d0) -> (d0)>

// This test is intended to check that the produced IR does not contain any
// type errors from sharing empty tensor operations with different types.
// The verifiers are sufficient to lock down the intended behavior.

// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<56xf32>
{
  %init0 = tensor.empty() : tensor<56xf32>
  %init1 = tensor.empty() : tensor<56x1xf32>

  %filled_tensor = linalg.fill
      ins(%fill_value : f32)
      outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>

  // The collapse shape alters the tensor rank, so the %init1 tensor.empty
  // cannot be pushed into the output of the linalg.generic.
  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0, 1]]
      : tensor<56x1xf32> into tensor<56xf32>

  %bias = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel"]
  } ins(%reshaped_tensor : tensor<56xf32>)
    outs(%init0 : tensor<56xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<56xf32>

  return %bias : tensor<56xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
    transform.yield
  }
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>

// This test is intended to check that the produced IR does not contain any
// type errors from sharing empty tensor operations with different types.
// The verifiers are sufficient to lock down the intended behavior.

// CHECK-LABEL: func.func @collapse_cast_prevents_reuse(
func.func @collapse_cast_prevents_reuse(%fill_value: f32) -> tensor<56x?xf32>
{
  %c1 = arith.constant 1 : index
  %init0 = tensor.empty(%c1) : tensor<56x?xf32>
  %init1 = tensor.empty() : tensor<56x1xf32>

  %filled_tensor = linalg.fill
      ins(%fill_value : f32)
      outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>

  // The cast alters the number of dynamic dims, so the %init1 tensor.empty
  // cannot be pushed into the output of the linalg.generic.
  %cast = tensor.cast %filled_tensor : tensor<56x1xf32> to tensor<56x?xf32>

  %bias = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel", "parallel"]
  } ins(%cast : tensor<56x?xf32>)
    outs(%init0 : tensor<56x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<56x?xf32>

  return %bias : tensor<56x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
    transform.yield
  }
}
