[mlir][sparse] force a properly sized view on pos/crd/val under codegen #91288

Merged · 3 commits · May 7, 2024
45 changes: 30 additions & 15 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1050,10 +1050,14 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested position access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
auto mem = desc.getPosMemRef(lvl);
auto size = desc.getPosMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};
@@ -1068,12 +1072,17 @@ class SparseToCoordinatesConverter
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = op.getLevel();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(
op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));

auto mem = desc.getCrdMemRefOrView(rewriter, loc, lvl);
if (lvl < getSparseTensorType(op.getTensor()).getAoSCOOStart()) {
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
mem = genSliceToSize(rewriter, loc, mem, size);
}
rewriter.replaceOp(op, mem);
return success();
}
};
@@ -1088,11 +1097,14 @@ class SparseToCoordinatesBufferConverter
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
Level lvl = getSparseTensorType(op.getTensor()).getAoSCOOStart();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getAOSMemRef());

auto mem = desc.getAOSMemRef();
auto size = desc.getCrdMemSize(rewriter, loc, lvl);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};
@@ -1106,10 +1118,13 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested values access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
// The view is restricted to the actual size to ensure clients
// of this operation truly observe size, not capacity!
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getValMemRef());
auto mem = desc.getValMemRef();
auto size = desc.getValMemSize(rewriter, loc);
rewriter.replaceOp(op, genSliceToSize(rewriter, loc, mem, size));
return success();
}
};
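All four conversions above now route the raw pos/crd/val buffer through `genSliceToSize`, pairing it with the size read back from the storage specifier (the `getPosMemSize`/`getCrdMemSize`/`getValMemSize` calls, which the tests show lowering to `sparse_tensor.storage_specifier.get ... pos_mem_sz/crd_mem_sz/val_mem_sz`), so that consumers of these ops observe the number of stored entries rather than the buffer capacity. The body of `genSliceToSize` is not part of this hunk; the following is only a plausible sketch of such a helper, written to match the `memref.subview %mem[0] [%size] [1]` views the updated tests check for (the helper name comes from the diff, the rest is illustrative).

```cpp
// Sketch only: truncate a rank-1 memref to its first `sz` elements via a
// unit-stride subview at offset zero, i.e. memref.subview %mem[0] [%sz] [1].
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                            Value sz) {
  auto memTp = cast<MemRefType>(mem.getType());
  // Only the linear (rank-1) position/coordinate/value buffers are truncated
  // in this sketch; anything else is returned unchanged.
  if (memTp.getRank() != 1)
    return mem;
  // Request an identity-layout memref<?xT> result, as seen in the tests.
  auto viewTp =
      MemRefType::get({ShapedType::kDynamic}, memTp.getElementType());
  SmallVector<OpFoldResult> offsets{builder.getIndexAttr(0)};
  SmallVector<OpFoldResult> sizes{OpFoldResult(sz)};
  SmallVector<OpFoldResult> strides{builder.getIndexAttr(1)};
  return builder
      .create<memref::SubViewOp>(loc, viewTp, mem, offsets, sizes, strides)
      .getResult();
}
```

Doing the restriction once at this lowering boundary, rather than at every use site, keeps the size-versus-capacity distinction local to codegen: anything downstream (e.g. `memref.dim` or loop bounds derived from the view) can only ever observe the stored entries.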
69 changes: 37 additions & 32 deletions mlir/test/Dialect/SparseTensor/binary_valued.mlir
@@ -26,12 +26,11 @@
//
// Make sure X += A * A => X += 1 in single loop.
//
//
// CHECK-LABEL: func.func @sum_squares(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
// CHECK-SAME: %[[VAL_3:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
@@ -40,23 +39,25 @@
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK: linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_7]] : index
// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_14]]) -> (f32) {
// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_15]] : index
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_18]]) -> (f32) {
// CHECK: %[[VAL_26:.*]] = arith.addf %[[VAL_25]], %[[VAL_4]] : f32
// CHECK: scf.yield %[[VAL_26]] : f32
// CHECK: %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_7]] : index
// CHECK: %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_16]]) -> (f32) {
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_5]] : index
// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK: %[[VAL_25:.*]] = scf.for %[[VAL_26:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] iter_args(%[[VAL_27:.*]] = %[[VAL_20]]) -> (f32) {
// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_4]] : f32
// CHECK: scf.yield %[[VAL_28]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_23]] : f32
// CHECK: scf.yield %[[VAL_25]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_16]] : f32
// CHECK: scf.yield %[[VAL_18]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
// CHECK: memref.store %[[VAL_14]], %[[VAL_10]][] : memref<f32>
// CHECK: return %[[VAL_10]] : memref<f32>
// CHECK: }
//
@@ -99,25 +100,29 @@ func.func @sum_squares(%a: tensor<2x3x8xf32, #Sparse>) -> tensor<f32> {
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK: linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
// CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_7]] : index
// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_14]]) -> (f32) {
// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_15]] : index
// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_18]]) -> (f32) {
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_13]], %[[VAL_17]], %[[VAL_26]]] : memref<2x3x8xf32>
// CHECK: %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_25]] : f32
// CHECK: scf.yield %[[VAL_28]] : f32
// CHECK: %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_13]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_7]] : index
// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_18]]) -> (f32) {
// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_19]] : index
// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_5]] : index
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_25]]] : memref<?xindex>
// CHECK: %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] iter_args(%[[VAL_29:.*]] = %[[VAL_22]]) -> (f32) {
// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_17]], %[[VAL_21]], %[[VAL_30]]] : memref<2x3x8xf32>
// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_31]], %[[VAL_29]] : f32
// CHECK: scf.yield %[[VAL_32]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_23]] : f32
// CHECK: scf.yield %[[VAL_27]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_16]] : f32
// CHECK: scf.yield %[[VAL_20]] : f32
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: memref.store %[[VAL_12]], %[[VAL_10]][] : memref<f32>
// CHECK: memref.store %[[VAL_16]], %[[VAL_10]][] : memref<f32>
// CHECK: return %[[VAL_10]] : memref<f32>
// CHECK: }
//
23 changes: 16 additions & 7 deletions mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -266,7 +266,9 @@ func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: return %[[A2]] : memref<?xi32>
// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] pos_mem_sz at 1
// CHECK: %[[V:.*]] = memref.subview %[[A2]][0] [%[[S]]] [1]
// CHECK: return %[[V]] : memref<?xi32>
func.func @sparse_positions_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
%0 = sparse_tensor.positions %arg0 { level = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi32>
return %0 : memref<?xi32>
@@ -279,7 +281,9 @@ func.func @sparse_positions_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32>
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: return %[[A3]] : memref<?xi64>
// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
// CHECK: %[[V:.*]] = memref.subview %[[A3]][0] [%[[S]]] [1]
// CHECK: return %[[V]] : memref<?xi64>
func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
%0 = sparse_tensor.coordinates %arg0 { level = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi64>
return %0 : memref<?xi64>
@@ -292,7 +296,9 @@ func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
// CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: return %[[A4]] : memref<?xf64>
// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] val_mem_sz
// CHECK: %[[V:.*]] = memref.subview %[[A4]][0] [%[[S]]] [1]
// CHECK: return %[[V]] : memref<?xf64>
func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
%0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
return %0 : memref<?xf64>
@@ -305,13 +311,14 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: return %[[A4]] : memref<?xf64>
// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] val_mem_sz
// CHECK: %[[V:.*]] = memref.subview %[[A4]][0] [%[[S]]] [1]
// CHECK: return %[[V]] : memref<?xf64>
func.func @sparse_values_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xf64> {
%0 = sparse_tensor.values %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xf64>
return %0 : memref<?xf64>
}


// CHECK-LABEL: func.func @sparse_indices_coo(
// CHECK-SAME: %[[A0:.*0]]: memref<?xindex>,
// CHECK-SAME: %[[A1:.*1]]: memref<?xindex>,
@@ -320,7 +327,7 @@ func.func @sparse_values_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xf64> {
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
// CHECK: %[[S0:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
// CHECK: %[[S2:.*]] = arith.divui %[[S0]], %[[C2]] : index
// CHECK: %[[R1:.*]] = memref.subview %[[A3]][0] {{\[}}%[[S2]]] [2] : memref<?xindex> to memref<?xindex, strided<[2]>>
// CHECK: %[[R2:.*]] = memref.cast %[[R1]] : memref<?xindex, strided<[2]>> to memref<?xindex, strided<[?], offset: ?>>
@@ -337,7 +344,9 @@ func.func @sparse_indices_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex
// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier
// CHECK: return %[[A3]] : memref<?xindex>
// CHECK: %[[S:.*]] = sparse_tensor.storage_specifier.get %[[A5]] crd_mem_sz at 1
// CHECK: %[[V:.*]] = memref.subview %[[A3]][0] [%[[S]]] [1]
// CHECK: return %[[V]] : memref<?xindex>
func.func @sparse_indices_buffer_coo(%arg0: tensor<?x?x?xf64, #ccoo>) -> memref<?xindex> {
%0 = sparse_tensor.coordinates_buffer %arg0 : tensor<?x?x?xf64, #ccoo> to memref<?xindex>
return %0 : memref<?xindex>