
Commit 6e1352a

bugfixes for S2T PyTorch example (alibaba#951)
* bugfixes for S2T PyTorch example

* fix clang format

* add aten::div.Tensor_mode shape infer

* add more fixes
Tanyo Kwok authored Jan 11, 2023
1 parent e42b61f commit 6e1352a
Showing 6 changed files with 28 additions and 2 deletions.
@@ -1129,6 +1129,13 @@ class ShapePropagator : public PropertyPropBase {
     if (auto maybe_tensor_types = gatherTensorTypes(node)) {
       AT_ASSERT(maybe_tensor_types->size() >= 2);
       auto dtype = getPromotedTypeForArithmeticOp(node);
+#if PYTORCH_VERSION_GE(1, 9)
+      if (node->matches(
+              "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor")) {
+        return {broadcast(*maybe_tensor_types, dtype)};
+      }
+#endif // PYTORCH_VERSION_GE(1, 9)
+
       if ((node->kind() == aten::div || node->kind() == aten::div_) &&
           dtype.has_value() &&
           c10::isIntegralType(dtype.value(), false)) {
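The new branch matters because aten::div.Tensor_mode does not follow the integer-to-float promotion that plain aten::div applies. A minimal sketch in plain PyTorch (not part of this patch) of the dtype behavior the rule encodes:

    import torch

    a = torch.arange(6, dtype=torch.long).reshape(2, 3)
    b = torch.full((2, 3), 2, dtype=torch.long)

    # True division promotes integer inputs to floating point ...
    print(torch.div(a, b).dtype)                         # torch.float32
    # ... but with a rounding_mode the result stays integral, so the
    # rule simply broadcasts the inputs with the promoted dtype as-is.
    print(torch.div(a, b, rounding_mode="floor").dtype)  # torch.int64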
@@ -61,10 +61,12 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
       "aten::conv2d",
       "aten::cos",
       "aten::div",
+      "aten::detach",
       "aten::einsum",
       "aten::embedding",
       "aten::empty",
       "aten::eq",
+      "aten::flip",
       "aten::gt",
       "aten::ge",
       "aten::lt",
@@ -114,6 +116,7 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {
       "aten::reshape",
       "aten::roll",
       "aten::rsqrt",
+      "aten::rsub",
       "aten::select",
       "aten::selu",
       "aten::selu_",
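For illustration, a hypothetical module (not from this commit) that exercises the three newly whitelisted ops, which the clustering pass can now pick up:

    import torch

    class Demo(torch.nn.Module):
        def forward(self, x):
            y = x.detach()           # aten::detach
            y = torch.flip(y, [-1])  # aten::flip
            return 1.0 - y           # scripts to aten::rsub

    print(torch.jit.script(Demo()).graph)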
@@ -842,7 +842,8 @@ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
   if (!matchPattern(op.dims(), m_TorchConstantIntList(dimListInt)))
     return rewriter.notifyMatchFailure(
         op, "Only constant dims are currently supported");
-
+  auto dims = mhlo::toPositiveDims(dimListInt, selfTy.getRank());
+  std::copy(dims.begin(), dims.end(), dimListInt.begin());
   rewriter.replaceOpWithNewOp<mlir::mhlo::ReverseOp>(
       op,
       getTypeConverter()->convertType(op.getType()),
@@ -1457,6 +1458,7 @@ class DiscConvertTorchToMhlo
   INSERT_UNARY_CONVERT_PATTERN(AtenToDtypeLayoutOp);
   INSERT_UNARY_CONVERT_PATTERN(AtenToPrimDeviceOp);
   INSERT_UNARY_CONVERT_PATTERN(AtenTypeAsOp);
+  INSERT_UNARY_CONVERT_PATTERN(AtenDetachOp);
 #undef INSERT_UNARY_CONVERT_PATTERN
 
 #define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \
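The AtenFlipOp fix normalizes negative dims before emitting mhlo.reverse, which only accepts non-negative dimension indices. A Python stand-in for the wrapping that mhlo::toPositiveDims performs (assumed semantics, illustration only):

    def to_positive_dims(dims, rank):
        # A negative dim d refers to axis rank + d, e.g. -1 is the last axis.
        return [d % rank for d in dims]

    assert to_positive_dims([-1, 0], rank=3) == [2, 0]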
2 changes: 1 addition & 1 deletion pytorch_blade/tests/torchscript/basics.graph
@@ -126,4 +126,4 @@ graph(%p1 : Float(*, *, *, device=cuda:0),
       %p6 : Float(*, *, *, device=cuda:0)):
   // CHECK: Float(*, *, *, device=cuda:0) = aten::gru_cell(%p1, %p2, %p3, %p4, %p5, %p6)
   %1 : Float(32, 32, 10, device=cuda:0) = aten::gru_cell(%p1, %p2, %p3, %p4, %p5, %p6)
-return (%1)
\ No newline at end of file
+return (%1)
10 changes: 10 additions & 0 deletions pytorch_blade/tests/torchscript/since_1_10.graph
@@ -62,6 +62,16 @@ graph(%p1 : Float(*, *, *, device=cuda:0)):
   %3 : Tensor = aten::to(%p1, %1, %2, %2)
   return (%3)
 
+// aten::div.Tensor_mode
+// CHECK-LABEL: graph
+graph(%p1 : Long(*, *, *, requires_grad=0),
+      %p2 : Long(*, *, *, requires_grad=0),
+      %p3 : str):
+  // CHECK: Long(*, *, *) = aten::div
+  %1 : Tensor = aten::div(%p1, %p2, %p3)
+  return (%1)
+
+
 // aten::to.prim_Device
 // CHECK-LABEL: graph
 graph(%p1 : Float(device=cuda:0)):
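The TorchScript that produces a graph like the new test case can be sketched as follows (assumes PyTorch >= 1.10, where rounding_mode is an optional str argument):

    import torch

    @torch.jit.script
    def div_mode(a: torch.Tensor, b: torch.Tensor, mode: str):
        return torch.div(a, b, rounding_mode=mode)

    # The printed graph contains a three-input aten::div, matching the
    # "%1 : Tensor = aten::div(%p1, %p2, %p3)" line checked above.
    print(div_mode.graph)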
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -126,6 +127,7 @@ void configureGpuToROCDLConversionLegality(ConversionTarget& target) {
   target.addLegalDialect<ROCDL::ROCDLDialect>();
   target.addIllegalDialect<gpu::GPUDialect>();
   target.addIllegalDialect<cf::ControlFlowDialect>();
+  target.addIllegalDialect<arith::ArithDialect, math::MathDialect>();
   target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::FAbsOp, LLVM::FCeilOp,
                       LLVM::FFloorOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                       LLVM::PowOp, LLVM::SinOp, LLVM::SqrtOp>();
@@ -162,6 +164,8 @@ void populateGpuToROCDLConversionPatterns(LLVMTypeConverter& converter,
                                           "__ocml_ceil_f64");
   patterns.add<OpToFuncCallLowering<math::CosOp>>(converter, "__ocml_cos_f32",
                                                   "__ocml_cos_f64");
+  patterns.add<OpToFuncCallLowering<math::CopySignOp>>(
+      converter, "__ocml_copysign_f32", "__ocml_copysign_f64");
   patterns.add<OpToFuncCallLowering<math::ExpOp>>(converter, "__ocml_exp_f32",
                                                   "__ocml_exp_f64");
   patterns.add<OpToFuncCallLowering<math::ExpM1Op>>(
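Why copysign gets its own __ocml lowering: it transfers the sign bit of the second argument (including negative zero) rather than comparing values. A plain-Python illustration of the semantics (not the ROCm code path):

    import math

    # The sign bit of the second argument decides, so even -0.0 flips it.
    print(math.copysign(3.0, -0.0))  # -3.0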
