
Commit c6472f5

[mlir][sparse] More allocate -> empty tensor migration (#66720)
This also allows tensor.empty in the "conversion" path of the sparse compiler, further paving the way to deprecating the bufferization.alloc_tensor op.
1 parent bdb5c9c commit c6472f5
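In practice the change is a one-for-one op swap on tensors that carry a sparse encoding; a minimal before/after sketch in MLIR, reusing the #SpVec encoding from the tests below:

  // Before: the op slated for deprecation.
  %0 = bufferization.alloc_tensor() : tensor<77xi1, #SpVec>
  // After: the standard tensor-dialect op, now accepted on the conversion path.
  %0 = tensor.empty() : tensor<77xi1, #SpVec>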

File tree: 9 files changed (+110, −79 lines)

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 45 additions & 13 deletions
@@ -830,6 +830,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 };
 
 /// Sparse conversion rule for the alloc operator.
+/// TODO(springerm): remove when bufferization.alloc_tensor is gone
 class SparseTensorAllocConverter
     : public OpConversionPattern<bufferization::AllocTensorOp> {
 public:
@@ -864,6 +865,37 @@ class SparseTensorAllocConverter
   }
 };
 
+/// Sparse conversion rule for the empty tensor.
+class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    const auto stt = getSparseTensorType(op);
+    if (!stt.hasEncoding())
+      return failure();
+    // Gather all dimension sizes as SSA values.
+    const Dimension dimRank = stt.getDimRank();
+    SmallVector<Value> dimSizes;
+    dimSizes.reserve(dimRank);
+    auto shape = op.getType().getShape();
+    unsigned operandCtr = 0;
+    for (Dimension d = 0; d < dimRank; ++d) {
+      dimSizes.push_back(stt.isDynamicDim(d)
+                             ? adaptor.getOperands()[operandCtr++]
+                             : constantIndex(rewriter, loc, shape[d]));
+    }
+    // Generate the call to construct empty tensor. The sizes are
+    // explicitly defined by the arguments to the alloc operator.
+    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
+                               .genBuffers(stt, dimSizes)
+                               .genNewCall(Action::kEmpty));
+    return success();
+  }
+};
+
 /// Sparse conversion rule for the convert operator.
 class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
 public:
@@ -1503,19 +1535,19 @@ mlir::SparseTensorTypeToPtrConverter::SparseTensorTypeToPtrConverter() {
 void mlir::populateSparseTensorConversionPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     const SparseTensorConversionOptions &options) {
-  patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
-               SparseCastConverter, SparseTensorNewConverter,
-               SparseReshapeConverter<tensor::ExpandShapeOp>,
-               SparseReshapeConverter<tensor::CollapseShapeOp>,
-               SparseTensorConcatConverter, SparseTensorAllocConverter,
-               SparseTensorDeallocConverter, SparseTensorToPositionsConverter,
-               SparseTensorToCoordinatesConverter,
-               SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
-               SparseTensorLoadConverter, SparseTensorInsertConverter,
-               SparseTensorExpandConverter, SparseTensorCompressConverter,
-               SparseTensorOutConverter, SparseTensorPackConverter>(
-      typeConverter, patterns.getContext());
-
+  patterns
+      .add<SparseReturnConverter, SparseTensorToDimSizeConverter,
+           SparseCastConverter, SparseTensorNewConverter,
+           SparseReshapeConverter<tensor::ExpandShapeOp>,
+           SparseReshapeConverter<tensor::CollapseShapeOp>,
+           SparseTensorConcatConverter, SparseTensorAllocConverter,
+           SparseTensorEmptyConverter, SparseTensorDeallocConverter,
+           SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
+           SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
+           SparseTensorLoadConverter, SparseTensorInsertConverter,
+           SparseTensorExpandConverter, SparseTensorCompressConverter,
+           SparseTensorOutConverter, SparseTensorPackConverter>(
+          typeConverter, patterns.getContext());
   patterns.add<SparseTensorConvertConverter>(typeConverter,
                                              patterns.getContext(), options);
 }
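Note that the operand-counting loop in SparseTensorEmptyConverter consumes operands only for dynamic dimensions: tensor.empty takes one index operand per ? in its result shape, while static sizes are folded to constants via constantIndex. A short sketch of both cases, reusing the 1-D #SV encoding from the sparse_expand.mlir test below (%n is assumed to be a previously computed index value):

  // Dynamic size: one index operand, paired with the single dynamic dimension.
  %v = tensor.empty(%n) : tensor<?xf64, #SV>
  // Static size: no operands; the converter materializes the constant 32 itself.
  %w = tensor.empty() : tensor<32xf64, #SV>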

mlir/test/Dialect/SparseTensor/constant_index_map.mlir

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 77 : index
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<77xi1, #{{.*}}>
+// CHECK-DAG: %[[VAL_5:.*]] = tensor.empty() : tensor<77xi1, #{{.*}}>
 // CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<1x77xi1>
 // CHECK-DAG: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<1x77xi1>
 // CHECK: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_5]]) -> (tensor<77xi1, #{{.*}}>) {
@@ -27,7 +27,7 @@
 // CHECK: return %[[VAL_15]] : tensor<77xi1, #{{.*}}>
 // CHECK: }
 func.func @main(%arg0: tensor<1x77xi1>, %arg1: tensor<1x77xi1>) -> tensor<77xi1, #SpVec> {
-  %0 = bufferization.alloc_tensor() : tensor<77xi1, #SpVec>
+  %0 = tensor.empty() : tensor<77xi1, #SpVec>
   %1 = linalg.generic {
     indexing_maps = [#map1, #map1, #map2],
     iterator_types = ["parallel"]}

mlir/test/Dialect/SparseTensor/sparse_affine.mlir

Lines changed: 31 additions & 32 deletions
@@ -1,4 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // RUN: mlir-opt %s -sparsification | FileCheck %s
 
 #SpVec = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
@@ -17,9 +16,9 @@
 }
 
 // CHECK-LABEL: func @mul_inv_dense1d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<4xf32>,
-// CHECK-SAME:  %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xf32>) -> tensor<32xf32> {
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 3 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
@@ -57,13 +56,13 @@ func.func @mul_inv_dense1d(%arga: tensor<32xf32, #SpVec>,
 }
 
 // CHECK-LABEL: func.func @mul_inv_sparse1d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>>)
 // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_4:.*]] = arith.constant 3 : index
 // CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_6:.*]] = bufferization.alloc_tensor() : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
 // CHECK: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -95,7 +94,7 @@ func.func @mul_inv_dense1d(%arga: tensor<32xf32, #SpVec>,
 // CHECK: return %[[VAL_32]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 func.func @mul_inv_sparse1d(%arga: tensor<32xf32, #SpVec>,
                             %argb: tensor<4xf32, #SpVec>) -> tensor<32xf32, #SpVec> {
-  %argx = bufferization.alloc_tensor() : tensor<32xf32, #SpVec>
+  %argx = tensor.empty() : tensor<32xf32, #SpVec>
   %0 = linalg.generic #trait1
      ins(%arga, %argb: tensor<32xf32, #SpVec>, tensor<4xf32, #SpVec>)
     outs(%argx: tensor<32xf32, #SpVec>) {
@@ -109,13 +108,13 @@ func.func @mul_inv_sparse1d(%arga: tensor<32xf32, #SpVec>,
 
 
 // CHECK-LABEL: func.func @mul_inv_enc_dense1d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK: %[[VAL_2:.*]] = arith.constant 32 : index
 // CHECK: %[[VAL_3:.*]] = arith.constant 3 : index
 // CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_6:.*]] = bufferization.alloc_tensor() : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
 // CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<4xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
 // CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_6]] : tensor<32xf32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf32>
@@ -132,7 +131,7 @@ func.func @mul_inv_sparse1d(%arga: tensor<32xf32, #SpVec>,
 // CHECK: }
 func.func @mul_inv_enc_dense1d(%arga: tensor<32xf32, #EncDenseVec>,
                                %argb: tensor<4xf32, #EncDenseVec>) -> tensor<32xf32, #EncDenseVec> {
-  %argx = bufferization.alloc_tensor() : tensor<32xf32, #EncDenseVec>
+  %argx = tensor.empty() : tensor<32xf32, #EncDenseVec>
   %0 = linalg.generic #trait1
      ins(%arga, %argb: tensor<32xf32, #EncDenseVec>, tensor<4xf32, #EncDenseVec>)
     outs(%argx: tensor<32xf32, #EncDenseVec>) {
@@ -155,9 +154,9 @@ func.func @mul_inv_enc_dense1d(%arga: tensor<32xf32, #EncDenseVec>,
 }
 
 // CHECK-LABEL: func @and_affine_dense1d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<34xi32>,
-// CHECK-SAME:  %[[VAL_2:.*]]: tensor<32xi32>) -> tensor<32xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<34xi32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xi32>) -> tensor<32xi32> {
 // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
@@ -195,12 +194,12 @@ func.func @and_affine_dense1d(%arga: tensor<32xi32, #SpVec>,
 }
 
 // CHECK-LABEL: func.func @and_affine_sparse1d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<34xi32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<34xi32, #sparse_tensor.encoding<{{{.*}}}>>)
 // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi32>
@@ -234,7 +233,7 @@ func.func @and_affine_dense1d(%arga: tensor<32xi32, #SpVec>,
 // CHECK: return %[[VAL_33]] : tensor<32xi32, #sparse_tensor.encoding<{{{.*}}}>>
 func.func @and_affine_sparse1d(%arga: tensor<32xi32, #SpVec>,
                                %argb: tensor<34xi32, #SpVec>) -> tensor<32xi32, #SpVec> {
-  %argx = bufferization.alloc_tensor() : tensor<32xi32, #SpVec>
+  %argx = tensor.empty() : tensor<32xi32, #SpVec>
   %0 = linalg.generic #trait2
      ins(%arga, %argb: tensor<32xi32, #SpVec>, tensor<34xi32, #SpVec>)
     outs(%argx: tensor<32xi32, #SpVec>) {
@@ -256,9 +255,9 @@ func.func @and_affine_sparse1d(%arga: tensor<32xi32, #SpVec>,
 }
 
 // CHECK-LABEL: func @mul_affine_dense2d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<34x19xf64>,
-// CHECK-SAME:  %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<34x19xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 32 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
@@ -304,8 +303,8 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 
 
 // CHECK-LABEL: func.func @mul_affine_sparse2d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<34x19xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<34x19xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 32 : index
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
@@ -314,7 +313,7 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : index
 // CHECK-DAG: %[[VAL_TRUE:.*]] = arith.constant true
 // CHECK-DAG: %[[VAL_FALSE:.*]] = arith.constant false
-// CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
@@ -360,7 +359,7 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK: return %[[VAL_45]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 func.func @mul_affine_sparse2d(%arga: tensor<32x16xf64, #CSR>,
                                %argb: tensor<34x19xf64, #CSR>) -> tensor<32x16xf64, #CSR> {
-  %argx = bufferization.alloc_tensor() : tensor<32x16xf64, #CSR>
+  %argx = tensor.empty() : tensor<32x16xf64, #CSR>
   %0 = linalg.generic #trait3
      ins(%arga, %argb: tensor<32x16xf64, #CSR>, tensor<34x19xf64, #CSR>)
     outs(%argx: tensor<32x16xf64, #CSR>) {
@@ -383,9 +382,9 @@ func.func @mul_affine_sparse2d(%arga: tensor<32x16xf64, #CSR>,
 }
 
 // CHECK-LABEL: func.func @mul_affine_dense_dim_2d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<34x16xf64, #sparse_tensor.encoding
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<34x16xf64, #sparse_tensor.encoding
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 19 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
@@ -447,9 +446,9 @@ func.func @mul_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>,
 }
 
 // CHECK-LABEL: func.func @mul_const_affine_dense_dim_2d(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<34x16xf64,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
-// CHECK-SAME:  %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<34x16xf64,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> {
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 19 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index

mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 // CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
 // CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
-// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor()
+// CHECK: %[[TMP_0:.*]] = tensor.empty()
 // CHECK: %[[TMP_1:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 0 : index}
 // CHECK: %[[TMP_2:.*]] = sparse_tensor.coordinates %[[TMP_arg0]] {level = 0 : index}
 // CHECK: %[[TMP_3:.*]] = sparse_tensor.positions %[[TMP_arg0]] {level = 1 : index}
@@ -44,7 +44,7 @@
 // CHECK: return %[[TMP_8]]
 module @func_sparse {
   func.func public @main(%arg0: tensor<4x5xi32, #DCSR>) -> tensor<4x3x5xi32, #SparseTensor> {
-    %0 = bufferization.alloc_tensor() : tensor<4x3x5xi32, #SparseTensor>
+    %0 = tensor.empty() : tensor<4x3x5xi32, #SparseTensor>
     %1 = linalg.generic #trait
       ins(%arg0 : tensor<4x5xi32, #DCSR>) outs(%0 : tensor<4x3x5xi32, #SparseTensor>) {
       ^bb0(%in: i32, %out: i32):

mlir/test/Dialect/SparseTensor/sparse_expand.mlir

Lines changed: 3 additions & 3 deletions
@@ -67,7 +67,7 @@
 func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
   %c0 = arith.constant 0 : index
   %n = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSC>
-  %v = bufferization.alloc_tensor(%n) : tensor<?xf64, #SV>
+  %v = tensor.empty(%n) : tensor<?xf64, #SV>
   %0 = linalg.generic #rowsum
     ins(%arga: tensor<?x?xf64, #DCSC>)
     outs(%v: tensor<?xf64, #SV>) {
@@ -119,7 +119,7 @@ func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
 //
 func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
                    %B: tensor<2x4xf64, #CSR>) -> tensor<8x4xf64, #CSR> {
-  %C = bufferization.alloc_tensor() : tensor<8x4xf64, #CSR>
+  %C = tensor.empty() : tensor<8x4xf64, #CSR>
   %D = linalg.matmul
     ins(%A, %B: tensor<8x2xf64, #CSR>, tensor<2x4xf64, #CSR>)
     outs(%C: tensor<8x4xf64, #CSR>) -> tensor<8x4xf64, #CSR>
@@ -167,7 +167,7 @@ func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
 //
 func.func @matmul2(%A: tensor<8x2xf64, #CSC>,
                    %B: tensor<2x4xf64, #CSC>) -> tensor<8x4xf64, #CSC> {
-  %C = bufferization.alloc_tensor() : tensor<8x4xf64, #CSC>
+  %C = tensor.empty() : tensor<8x4xf64, #CSC>
   %D = linalg.matmul
     ins(%A, %B: tensor<8x2xf64, #CSC>, tensor<2x4xf64, #CSC>)
     outs(%C: tensor<8x4xf64, #CSC>) -> tensor<8x4xf64, #CSC>
