feat(TosaToNamedLinalg): add FillOp conversion for Bias #18

Merged
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -243,6 +243,10 @@ bool getConstShapeValues(Operation *op,
// returns a small vector of int64_t values that attr contains
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
                                        const int rank);

// Returns the attribute that stores the constant value of a ConstantLike
// operation, or a null attribute if `op` is null or not ConstantLike.
Attribute getConstantAttribute(Operation *op);
} // namespace tosa
} // namespace mlir

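As a usage sketch (not part of the diff; `def` is an assumed `Operation *` that already matched `m_Constant()`), the new helper is meant to be queried like this before attempting a splat-to-scalar conversion:

```cpp
// Hedged sketch: fetch the payload of a ConstantLike op, then confirm it is
// a splat before lowering it to a scalar constant.
if (Attribute attr = mlir::tosa::getConstantAttribute(def)) {
  if (auto splat = dyn_cast<SplatElementsAttr>(attr)) {
    // `element` is a FloatAttr or IntegerAttr that can seed an arith.constant.
    Attribute element = splat.getSplatValue<Attribute>();
    (void)element;
  }
}
```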
79 changes: 78 additions & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -23,10 +23,11 @@
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>
#include <type_traits>
@@ -118,6 +119,71 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
/*symbolCount=*/0, sourceDims, rewriter.getContext());
}

static mlir::Value createScalarConstantFromTensor(PatternRewriter &rewriter,
                                                  Operation *source,
                                                  Value result) {
  // Get the constant value as an attribute from the constant operation. The
  // attribute may be null, so guard the cast below.
  Attribute value = tosa::getConstantAttribute(source);
  auto attr = dyn_cast_if_present<SplatElementsAttr>(value);

  // Ensure the constant is a splat so it can be converted to a scalar.
  if (!attr) {
    return Value();
  }

  // Only handle constants whose results are ranked tensors.
  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
  if (!resultTy) {
    return Value();
  }

  // Create a scalar constant with the same element type as the result
  // tensor. Per the TOSA spec, the result type may be an accumulator type
  // with a bitwidth equal to or larger than that of the splat constant.
  Value scalarValue =
      llvm::TypeSwitch<Attribute, Value>(attr.getSplatValue<Attribute>())
          .Case([&](FloatAttr attr) {
            // Create a float constant with the result tensor's element type.
            // Convert the value through double because APFloat checks
            // bitwidths, so building the attribute directly from the splat's
            // APFloat would fail when the input and output element types
            // differ.
            return rewriter
                .create<arith::ConstantOp>(
                    source->getLoc(),
                    FloatAttr::get(resultTy.getElementType(),
                                   attr.getValue().convertToDouble()))
                .getResult();
          })
          .Case([&](IntegerAttr attr) -> Value {
            // All current TOSA profiles are signed, so bail out on the
            // unsigned case should it ever appear.
            if (resultTy.getElementType().isUnsignedInteger()) {
              return Value();
            }
            // Create a scalar that follows the result type. For example, an
            // i8 constant may feed an i32 accumulator, so the widening is
            // performed at compile time via sign extension.
            return rewriter
                .create<arith::ConstantOp>(
                    source->getLoc(),
                    IntegerAttr::get(resultTy.getElementType(),
                                     attr.getValue().getSExtValue()))
                .getResult();
          })
          .Default([](Attribute) { return Value(); });

  // Could not create a scalar constant due to an unsupported element type.
  if (!scalarValue) {
    return Value();
  }

  return rewriter
      .create<linalg::FillOp>(source->getLoc(), ValueRange{scalarValue},
                              ValueRange{result})
      .getResult(0);
}
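To make the integer widening concrete, here is a minimal sketch (illustrative values; `ctx` is an assumed `MLIRContext *`) of what the IntegerAttr case above computes for an i8 splat of 42 feeding an i32 accumulator:

```cpp
// Sign-extend the 8-bit splat to a plain int64_t, then rebuild the attribute
// with the wider accumulator element type (42 : i8 becomes 42 : i32).
llvm::APInt splat(/*numBits=*/8, /*val=*/42, /*isSigned=*/true);
int64_t widened = splat.getSExtValue(); // sign is preserved for negatives
IntegerAttr attr = IntegerAttr::get(IntegerType::get(ctx, 32), widened);
```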

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi or arith.extf
// operation as appropriate.
@@ -126,6 +192,17 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                                              Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();

  // Attempt to create a FillOp in linalg if the constant is a splat value.
  if (source.getDefiningOp() &&
      matchPattern(source.getDefiningOp(), m_Constant())) {
    auto scalar = createScalarConstantFromTensor(
        rewriter, source.getDefiningOp(), result);
    if (scalar) {
      return scalar;
    }
  }

  // Creating maps for the input and output of the broadcast-like generic op.
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
20 changes: 20 additions & 0 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -213,3 +213,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
  }
  return {};
}

Attribute mlir::tosa::getConstantAttribute(Operation *op) {
  if (!op || !op->hasTrait<OpTrait::ConstantLike>())
    return Attribute();

  if (auto constOp = dyn_cast<ConstOp>(op)) {
    return constOp.getValues();
  }

  // Other ConstantLike operations store their payload under differing
  // attribute names, e.g. "value" for arith.constant and "values" for
  // tosa.const. Search for both and return the first match.
  const SmallVector<const char *> possibleAttributes = {"value", "values"};
  for (llvm::StringRef name : possibleAttributes) {
    if (op->hasAttr(name)) {
      return op->getAttr(name);
    }
  }
  return Attribute();
}
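A hedged sketch of the fallback path (`builder` and `loc` are assumed): an arith.constant is ConstantLike but not a tosa::ConstOp, so it is the name search below the fast path that returns its payload:

```cpp
// arith.constant stores its payload under "value", so the generic loop finds
// it even though the ConstOp fast path does not apply.
auto cst = builder.create<arith::ConstantOp>(loc, builder.getF32FloatAttr(1.0f));
Attribute payload = mlir::tosa::getConstantAttribute(cst);
assert(payload == cst.getValue());
```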
57 changes: 57 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -672,6 +672,63 @@ func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<2

// -----

// CHECK-LABEL: @conv2d_bias_broadcast_f32
func.func @conv2d_bias_broadcast_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>) -> () {
%bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
// CHECK-DAG: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
// CHECK: %[[BIAS:.+]] = linalg.fill
// CHECK-SAME: ins(%[[CST]]
// CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xf32>
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
// CHECK-SAME: outs(%[[BIAS]]
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
return
}

// -----

// CHECK-LABEL: @conv2d_dynamic_batch_bias_broadcast_f32
// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x49x42x27xf32>
func.func @conv2d_dynamic_batch_bias_broadcast_f32(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>) -> () {
%bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x49x42x27xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x45x40x28xf32>
// CHECK: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
// CHECK: %[[BIAS:.+]] = linalg.fill
// CHECK-SAME: ins(%[[CST]]
// CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<?x45x40x28xf32>
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
// CHECK-SAME: outs(%[[BIAS]]
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x45x40x28xf32>
return
}

// -----

// CHECK-LABEL: @conv2d_bias_broadcast_i8_acc_i32
func.func @conv2d_bias_broadcast_i8_acc_i32(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x3x3x27xi8>) -> () {
%bias = "tosa.const"() <{values = dense<42> : tensor<28xi8>}> : () -> tensor<28xi8>
// CHECK-DAG: %[[CST:.+]] = arith.constant 42 : i32
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
// CHECK: %[[BIAS:.+]] = linalg.fill
// CHECK-SAME: ins(%[[CST]]
// CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xi32>
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
// CHECK-SAME: outs(%[[BIAS]]
%input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xi8>, tensor<28x3x3x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
return
}

// -----

// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
