Create some helper functions to emit constant op for a specific type (llvm#7)

* emitConstantOp with a given type

* Helper functions to create infinity constants

* Use new constant helper functions for MaxPoolSingleOut

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
doru1004 authored Mar 5, 2020
1 parent 8e1b30e commit 8a992b6
Showing 11 changed files with 255 additions and 135 deletions.
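
The helper functions themselves (emitConstantOp and the infinity variants used below) are defined in changed files not expanded in this view. As a rough sketch of the idea, a hypothetical reconstruction rather than the commit's exact code, the typed constant emitter dispatches on the element type and builds a matching attribute:

Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc,
    Type type, double value) {
  // Hypothetical sketch: pick the attribute kind matching the element type,
  // then materialize a standard-dialect ConstantOp from it.
  Attribute constantAttr;
  if (type.isa<FloatType>())
    constantAttr = rewriter.getFloatAttr(type, value);
  else if (type.isa<IntegerType>())
    constantAttr = rewriter.getIntegerAttr(type, (int64_t)value);
  else
    emitError(loc, "unsupported element type");
  return rewriter.create<ConstantOp>(loc, constantAttr);
}

This shape is what lets the call sites below collapse per-type if/else chains into a single call.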
53 changes: 25 additions & 28 deletions src/conversion/onnx_to_krnl/math/elementwise.cpp
@@ -103,8 +103,8 @@ Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto elementType = result_types[0];

- auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
- auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
+ auto two = emitConstantOp(rewriter, loc, elementType, 2);
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
@@ -127,8 +127,8 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto elementType = result_types[0];

- auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
- auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
+ auto two = emitConstantOp(rewriter, loc, elementType, 2);
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
@@ -152,8 +152,8 @@ Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
Value operand = operands[0];
auto elementType = result_types[0];

- auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
- auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
+ auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
auto negExp = rewriter.create<ExpOp>(loc, neg);
auto result = rewriter.create<DivFOp>(
@@ -184,8 +184,8 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
auto elementType = result_types[0];

- auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
- auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
+ auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);

@@ -217,8 +217,8 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,

auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
- auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
- auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
+ auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto exp = rewriter.create<ExpOp>(loc, operand);
auto lessThanZero =
@@ -246,7 +246,7 @@ Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0];
auto elementType = result_types[0];

- auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
@@ -271,7 +271,7 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,

auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
- auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto lessThanZero =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
@@ -301,7 +301,7 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
auto elementType = result_types[0];

- auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
auto exp = rewriter.create<ExpOp>(loc, operand);
@@ -328,7 +328,7 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
Value operand = operands[0];
auto elementType = result_types[0];

- auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+ auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto result = rewriter.create<DivFOp>(loc, one, operand);

return result;
@@ -347,7 +347,7 @@ Value mapToLowerScalarOp<ONNXSoftplusOp>(
auto elementType = result_types[0];

auto exp = rewriter.create<ExpOp>(loc, operand);
- auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+ auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto add = rewriter.create<AddFOp>(loc, exp, one);
auto result = rewriter.create<LogOp>(loc, add);

@@ -367,7 +367,7 @@ Value mapToLowerScalarOp<ONNXSoftsignOp>(
auto elementType = result_types[0];

auto abs = rewriter.create<AbsFOp>(loc, operand);
- auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+ auto one = emitConstantOp(rewriter, loc, elementType, 1);
auto add = rewriter.create<AddFOp>(loc, abs, one);
auto result = rewriter.create<DivFOp>(loc, operand, add);

@@ -384,19 +384,18 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,

auto loc = op->getLoc();
Value operand = operands[0];
- Type element_type = operands.front().getType();
+ Type elementType = operands.front().getType();
// TODO: unsigned int should be supported separately?
- if (element_type.isa<IntegerType>()) {
+ if (elementType.isa<IntegerType>()) {
// %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
// ConstantOp 1,
// ConstantOp -1)
// ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0),
// ConstantOp 0,
// %Y)
- auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
- auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
- auto minusOne =
- rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
+ auto one = emitConstantOp(rewriter, loc, elementType, 1);
+ auto minusOne = emitConstantOp(rewriter, loc, elementType, -1);
auto plusPredicate =
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
auto plusSelect =
@@ -406,18 +405,16 @@ Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
auto result =
rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
return result;
- } else if (element_type.isa<FloatType>()) {
+ } else if (elementType.isa<FloatType>()) {
// %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0),
// ConstantOp 1,
// ConstantOp -1)
// ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0),
// ConstantOp 0,
// %Y)
- auto zero =
- rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
- auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
- auto minusOne =
- rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
+ auto zero = emitConstantOp(rewriter, loc, elementType, 0);
+ auto one = emitConstantOp(rewriter, loc, elementType, 1);
+ auto minusOne = emitConstantOp(rewriter, loc, elementType, -1);
auto plusPredicate =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
auto plusSelect =
3 changes: 1 addition & 2 deletions src/conversion/onnx_to_krnl/math/gemm.cpp
@@ -156,8 +156,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
}

// Initialize the output of A*B
- auto zero = rewriter.create<ConstantOp>(
- loc, FloatAttr::get(memRefType.getElementType(), 0));
+ auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);
rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);

// Compute A*B
11 changes: 1 addition & 10 deletions src/conversion/onnx_to_krnl/math/matmul.cpp
@@ -37,16 +37,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
auto memRefShape = memRefType.getShape();

// A value zero
- Value zero;
- if (elementType.isa<IntegerType>()) {
- zero = rewriter.create<ConstantOp>(
- loc, IntegerAttr::get(memRefType.getElementType(), 0));
- } else if (elementType.isa<FloatType>()) {
- zero = rewriter.create<ConstantOp>(
- loc, FloatAttr::get(memRefType.getElementType(), 0));
- } else {
- emitError(loc, "unsupported element type");
- }
+ auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);

// Insert an allocation and deallocation for the result of this operation.
Value alloc;
54 changes: 14 additions & 40 deletions src/conversion/onnx_to_krnl/math/reduction.cpp
@@ -14,43 +14,27 @@ using namespace mlir;

// Identity values
template <>
- float getIdentityValue<float, ONNXReduceMaxOp>(){
- return (float)-std::numeric_limits<float>::infinity();
+ Value getIdentityValue<ONNXReduceMaxOp>(
+ ConversionPatternRewriter &rewriter, Location loc, Type type) {
+ return emitNegativeInfinityConstantOp(rewriter, loc, type);
}

template <>
- int getIdentityValue<int, ONNXReduceMaxOp>(){
- return std::numeric_limits<int>::min();
+ Value getIdentityValue<ONNXReduceMinOp>(
+ ConversionPatternRewriter &rewriter, Location loc, Type type) {
+ return emitPositiveInfinityConstantOp(rewriter, loc, type);
}

template <>
- float getIdentityValue<float, ONNXReduceMinOp>(){
- return (float)std::numeric_limits<float>::infinity();
+ Value getIdentityValue<ONNXReduceProdOp>(
+ ConversionPatternRewriter &rewriter, Location loc, Type type) {
+ return emitConstantOp(rewriter, loc, type, 1);
}

- template <>
- int getIdentityValue<int, ONNXReduceMinOp>(){
- return std::numeric_limits<int>::max();
- }
-
- template <>
- float getIdentityValue<float, ONNXReduceProdOp>(){
- return (float)1.0;
- }
-
- template <>
- int getIdentityValue<int, ONNXReduceProdOp>(){
- return 1;
- }
-
- template <>
- float getIdentityValue<float, ONNXReduceSumOp>(){
- return (float)0;
- }
-
template <>
- int getIdentityValue<int, ONNXReduceSumOp>(){
- return 0;
+ Value getIdentityValue<ONNXReduceSumOp>(
+ ConversionPatternRewriter &rewriter, Location loc, Type type) {
+ return emitConstantOp(rewriter, loc, type, 0);
}

// Scalar ops
@@ -234,18 +218,8 @@ struct ONNXReductionOpLowering : public ConversionPattern {
loopIVs.push_back(arg);
}

- Value identity;
- if (elementOutType.isa<FloatType>()) {
- identity = rewriter.create<ConstantOp>(
- loc, FloatAttr::get(elementOutType,
- getIdentityValue<float, ONNXReductionOp>()));
- } else if (elementOutType.isa<IntegerType>()) {
- identity = rewriter.create<ConstantOp>(
- loc, IntegerAttr::get(elementOutType,
- getIdentityValue<int, ONNXReductionOp>()));
- } else {
- emitError(loc, "unsupported element type");
- }
+ Value identity =
+ getIdentityValue<ONNXReductionOp>(rewriter, loc, elementOutType);
rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);

// Define a Krnl loop to do reduction.
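The infinity helpers referenced above are likewise introduced elsewhere in this commit. A plausible sketch, assuming they mirror the deleted numeric_limits-based specializations (a true infinity for float types, a min/max fallback for integer types, which have no infinity):

Value emitNegativeInfinityConstantOp(
    ConversionPatternRewriter &rewriter, Location loc, Type type) {
  // Sketch under the stated assumption; not the commit's exact code.
  if (auto floatType = type.dyn_cast<FloatType>()) {
    // Build -inf with the right semantics for the width (f16/f32/f64).
    APFloat negInf =
        APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
    return rewriter.create<ConstantOp>(loc, FloatAttr::get(floatType, negInf));
  }
  if (auto intType = type.dyn_cast<IntegerType>()) {
    // No integer infinity; use the smallest representable value, as the
    // removed getIdentityValue<int, ...> specializations did.
    int64_t minValue =
        APInt::getSignedMinValue(intType.getWidth()).getSExtValue();
    return rewriter.create<ConstantOp>(
        loc, IntegerAttr::get(intType, minValue));
  }
  emitError(loc, "unsupported element type");
  return nullptr;
}

emitPositiveInfinityConstantOp would be the symmetric case, emitting +infinity or the type's maximum value.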
3 changes: 1 addition & 2 deletions src/conversion/onnx_to_krnl/math/softmax.cpp
@@ -48,8 +48,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
- Value zero =
- rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
+ Value zero = emitConstantOp(rewriter, loc, elementType, 0);
Value negInfinity = rewriter.create<ConstantOp>(
loc,
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
3 changes: 1 addition & 2 deletions src/conversion/onnx_to_krnl/nn/conv.cpp
@@ -92,8 +92,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
int64_t kernelsPerGroup = floor(kernelShape[0] / group);
auto kernelsPerGroupValue =
rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
- auto zero = rewriter.create<ConstantOp>(
- loc, FloatAttr::get(memRefType.getElementType(), 0));
+ auto zero = emitConstantOp(rewriter, loc, memRefType.getElementType(), 0);
Value subchannels;
if (kernelShape[1] < 0) {
subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
24 changes: 5 additions & 19 deletions src/conversion/onnx_to_krnl/nn/pooling.cpp
@@ -14,13 +14,9 @@ using namespace mlir;

// Identity values
template <>
- float getIdentityValue<float, ONNXMaxPoolSingleOutOp>() {
- return (float)-std::numeric_limits<float>::infinity();
- }
-
- template <>
- int getIdentityValue<int, ONNXMaxPoolSingleOutOp>() {
- return std::numeric_limits<int>::min();
+ Value getIdentityValue<ONNXMaxPoolSingleOutOp>(
+ ConversionPatternRewriter &rewriter, Location loc, Type type) {
+ return emitNegativeInfinityConstantOp(rewriter, loc, type);
}

template <>
@@ -204,18 +200,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
resultIndices.emplace_back(outerLoops.getInductionVar(i));

// 2.1 Emit: R[n][c][r1][r2] = negative_infinity;
- Value identity;
- if (resultElementType.isa<FloatType>()) {
- identity = rewriter.create<ConstantOp>(
- loc, FloatAttr::get(resultElementType,
- getIdentityValue<float, ONNXMaxPoolSingleOutOp>()));
- } else if (resultElementType.isa<IntegerType>()) {
- identity = rewriter.create<ConstantOp>(
- loc, IntegerAttr::get(resultElementType,
- getIdentityValue<int, ONNXMaxPoolSingleOutOp>()));
- } else {
- emitError(loc, "unsupported element type");
- }
+ Value identity = getIdentityValue<ONNXMaxPoolSingleOutOp>(
+ rewriter, loc, resultElementType);
rewriter.create<StoreOp>(loc, identity, alloc, resultIndices);

// 2.2 Define inner loops.
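For reference, the primary template implied by the specializations in reduction.cpp and pooling.cpp (an assumed declaration; the shared header is not shown in this view):

// Assumed primary template: each reduction-like op specializes this to
// materialize its identity element as a constant of the given type.
template <typename Op>
Value getIdentityValue(
    ConversionPatternRewriter &rewriter, Location loc, Type type);

With this structure, supporting another reduction or pooling op only needs one more specialization that delegates to the shared constant helpers.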
