Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Stablehlo] fix crashing on AtenEmbeddingBagSumExample_basic #3389

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ namespace hlo {

using mlir::ConversionPatternRewriter;

// Create chlo::ConstantLikeOp
template <typename T>
Value getConstantLike(OpBuilder &rewriter, Location loc, T constant, Value val);

// Create a 32-bit float constant operator from a float
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val);
Expand Down
66 changes: 18 additions & 48 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,6 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_to_stablehlo;

namespace {

template <typename T>
static Value getConstantLike(OpBuilder &b, Location loc, T constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (isa<mlir::IntegerType>(ty))
return b.getIntegerAttr(ty, constant);
if (isa<mlir::FloatType>(ty))
return b.getFloatAttr(ty, constant);
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
return complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type");
};
return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
val);
}

Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
return b.create<mlir::chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
val);
}

} // namespace

LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
mlir::Value &self, mlir::Value &other,
size_t dimSizeIndexBits) {
Expand Down Expand Up @@ -928,7 +900,8 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
"for AtenReciprocalOp");
}

Value oneTensor = getConstantLike(rewriter, op->getLoc(), 1, input);
Value oneTensor =
hlo::getConstantLike<int64_t>(rewriter, op->getLoc(), 1, input);
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(op, outTy, oneTensor, input);
return success();
}
Expand Down Expand Up @@ -1070,12 +1043,8 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
return op->emitError("only float tensor in relu op is supported");
}

Value zeroTensor;
zeroTensor = getConstantLike(
rewriter, op->getLoc(),
APFloat::getZero(cast<mlir::FloatType>(lhsElemTy).getFloatSemantics(),
false),
lhs);
Value zeroTensor =
hlo::getConstantLike<int64_t>(rewriter, op->getLoc(), 0, lhs);
rewriter.replaceOpWithNewOp<stablehlo::MaxOp>(op, lhs, zeroTensor);
return success();
}
Expand All @@ -1102,13 +1071,13 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
return op.emitError("unsupported approximate: ") << approximate;
}

Value one = getConstantLike(rewriter, loc, 1.0, input);
Value two = getConstantLike(rewriter, loc, 2.0, input);
Value three = getConstantLike(rewriter, loc, 3.0, input);
Value half = getConstantLike(rewriter, loc, 0.5, input);
Value one = hlo::getConstantLike(rewriter, loc, 1.0, input);
Value two = hlo::getConstantLike(rewriter, loc, 2.0, input);
Value three = hlo::getConstantLike(rewriter, loc, 3.0, input);
Value half = hlo::getConstantLike(rewriter, loc, 0.5, input);
// 2/pi
Value twoDivPi = getConstantLike(rewriter, loc, M_2_PI, input);
Value t = getConstantLike(rewriter, loc, 0.044715, input);
Value twoDivPi = hlo::getConstantLike(rewriter, loc, M_2_PI, input);
Value t = hlo::getConstantLike(rewriter, loc, 0.044715, input);

// x * 0.5
auto inputMulHalf = rewriter.create<stablehlo::MulOp>(loc, input, half);
Expand Down Expand Up @@ -1147,7 +1116,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);

auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input);
auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);

Expand All @@ -1169,7 +1138,7 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);

auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input);
auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
auto logInputOp = rewriter.create<stablehlo::LogOp>(op.getLoc(), input);

Expand Down Expand Up @@ -1764,12 +1733,13 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "Unsupported value of approximate");
}
// Create constant value
Value kAlpha = getConstantLike(rewriter, loc, 0.70710678118654752440, input);
Value kAlpha =
hlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input);
Value cstAlpha0 =
getConstantLike(rewriter, loc, 1.12837916709551257390, input);
Value half = getConstantLike(rewriter, loc, .5, input);
Value one = getConstantLike(rewriter, loc, 1.0, input);
Value negHalf = getConstantLike(rewriter, loc, -0.5, input);
hlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input);
Value half = hlo::getConstantLike(rewriter, loc, .5, input);
Value one = hlo::getConstantLike(rewriter, loc, 1.0, input);
Value negHalf = hlo::getConstantLike(rewriter, loc, -0.5, input);

// Compute
Value kBeta0 =
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ namespace {
static Value createInitialValueForGatherScatterOp(Operation *op,
RankedTensorType constType,
PatternRewriter &rewriter) {
if (!constType.hasStaticShape()) {
return nullptr;
}
auto elementTy = constType.getElementType();
if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
if (isa<mlir::FloatType>(elementTy)) {
Expand Down
27 changes: 27 additions & 0 deletions lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
Expand All @@ -24,6 +26,31 @@ using namespace mlir::torch::Torch;
namespace mlir {
namespace hlo {

// Create chlo::ConstantLikeOp
template <typename T>
Value getConstantLike(OpBuilder &rewriter, Location loc, T constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (isa<mlir::IntegerType>(ty))
return rewriter.getIntegerAttr(ty, constant);
if (isa<mlir::FloatType>(ty))
return rewriter.getFloatAttr(ty, constant);
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
return mlir::complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type");
};
return rewriter.create<mlir::chlo::ConstantLikeOp>(
loc, cast<TypedAttr>(getAttr()), val);
}

// Template instantiation
template Value getConstantLike<int64_t>(OpBuilder &rewriter, Location loc,
int64_t constant, Value val);

template Value getConstantLike<double>(OpBuilder &rewriter, Location loc,
double constant, Value val);

// Create a 32-bit float constant operator from a float
Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
float val) {
Expand Down
4 changes: 1 addition & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1442,9 +1442,7 @@
"ElementwiseSoftshrinkStaticModule_basic",
}

STABLEHLO_CRASHING_SET = {
"AtenEmbeddingBagSumExample_basic",
}
STABLEHLO_CRASHING_SET = set()

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
Expand Down
Loading