Skip to content

Commit

Permalink
Merge pull request #89 from Xilinx/enadjara.xten_nn_reduce_mean_to_torch
Browse files Browse the repository at this point in the history
feat: legalization from xten_nn.reduce_mean to aten.mean.dim
  • Loading branch information
ehsan-toosi authored Sep 20, 2024
2 parents 44eee5c + 6f7009d commit 9c4ed9d
Show file tree
Hide file tree
Showing 4 changed files with 360 additions and 2 deletions.
7 changes: 6 additions & 1 deletion include/xten/Dialect/XTenNN/IR/XTenNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -562,20 +562,25 @@ def XtenNN_ConvTransposeOp: XTenNN_Op<"ConvTranspose",[Pure, TosaExtension]> {
let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
}

def XtenNN_ReduceMeanOp: XTenNN_Op<"reduce_mean", [Pure, TosaExtension]> {
// Reduce-mean over a set of axes, mirroring onnx.ReduceMean.
// InferTensorTypeAdaptor lets the op compute its result shape from the
// input type plus the `axes`/`keepdims` attributes (see
// ReduceMeanOp::inferReturnTypeComponents in XTenNNOps.cpp).
def XtenNN_ReduceMeanOp: XTenNN_Op<"reduce_mean", [
Pure, TosaExtension,
InferTensorTypeAdaptor]> {
let summary = "Reduce Mean operation";
let description = [{
This operation is equivalent to `onnx.ReduceMean` and computes the mean of
the input tensor's elements along the provided axes.
}];

let arguments = (ins
AnyRankedTensor:$input,
// Axes to reduce; per the ONNX spec each value may be negative
// (counted from the back), i.e. in [-rank, rank).
DenseI64ArrayAttr:$axes,
// Nonzero keeps reduced dimensions with extent 1; zero drops them.
I64Attr:$keepdims
);

let results = (outs
AnyRankedTensor:$output
);

let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
}

Expand Down
16 changes: 16 additions & 0 deletions lib/Conversion/XTenNNToTorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,21 @@ convTranspose2dToTorch(ConvTransposeOp op, ConvTransposeOp::Adaptor adaptor,
->getResults();
}

/// Lower xten_nn.reduce_mean to torch.aten.mean.dim.
///
/// Materializes the torch-side operands — a `none` constant (passed where
/// AtenMeanDimOp takes its optional trailing operand), the keepdims attribute
/// as a torch bool, and the axes attribute as a torch int list — then emits a
/// single AtenMeanDimOp with the already-converted result type.
std::optional<ValueRange>
reduceMeanToTorch(ReduceMeanOp op, ReduceMeanOp::Adaptor adaptor,
                  ArrayRef<Type> types, ValueRange values,
                  ConversionPatternRewriter &rewriter) {
  const auto location = op->getLoc();
  // Operand constants are created first so they dominate the mean op.
  auto noneValue = rewriter.create<Torch::ConstantNoneOp>(location);
  auto keepDimsFlag =
      rewriter.create<Torch::ConstantBoolOp>(location, adaptor.getKeepdims());
  auto dimList =
      Torch::toTorchList(location, rewriter, adaptor.getAxes().vec());
  auto meanOp = rewriter.create<Torch::AtenMeanDimOp>(
      location, types[0], values[0], dimList, keepDimsFlag, noneValue);
  return meanOp->getResults();
}

std::optional<ValueRange> resizeToTorch(ResizeOp op, ResizeOp::Adaptor adaptor,
ArrayRef<Type> types, ValueRange values,
ConversionPatternRewriter &rewriter) {
Expand Down Expand Up @@ -439,6 +454,7 @@ struct ConvertXTenNNToTorch
patterns.add<ApplyXTenNNToTorch<ResizeOp, resizeToTorch>>(context);
patterns.add<ApplyXTenNNToTorch<ConvTransposeOp, convTranspose2dToTorch>>(
context);
patterns.add<ApplyXTenNNToTorch<ReduceMeanOp, reduceMeanToTorch>>(context);
if (failed(applyPartialConversion(funcOp, target, std::move(patterns))))
signalPassFailure();
}
Expand Down
51 changes: 50 additions & 1 deletion lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
Expand All @@ -26,6 +27,7 @@
#include "xten/Dialect/XTenNN/IR/XTenNNBase.h"
#include "xten/Dialect/XTenNN/IR/XTenNNOps.h"
#include "xten/Dialect/XTenNN/Interfaces/EnclaveOpInterfaces.h"
#include <cstdint>

using namespace mlir;
using namespace amd::xten_nn;
Expand Down Expand Up @@ -264,7 +266,9 @@ ParseResult SubgraphOp::parse(OpAsmParser &p, OperationState &result) {
return parseEnclaveOp(p, result);
}

void SubgraphOp::print(OpAsmPrinter &p) { printEnclaveOp(p, *this); }
// Delegates to the shared enclave-op printer (counterpart of parse above,
// which uses parseEnclaveOp).
void SubgraphOp::print(OpAsmPrinter &p) {
  printEnclaveOp(p, *this);
}

LogicalResult SubgraphOp::verify() {
Block *optBody = this->getOptionalEnclaveBody();
Expand Down Expand Up @@ -593,3 +597,48 @@ bool TopK::isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r) {
getElementTypeOrSelf(l[1]) == getElementTypeOrSelf(r[1]);
return sameElementType && succeeded(verifyCompatibleShapes(l, r));
}

/// Infer the result shape of xten_nn.reduce_mean from the input type and the
/// `axes`/`keepdims` attributes (InferTensorTypeAdaptor hook).
///
/// Axes follow the ONNX ReduceMean convention: each axis lies in
/// [-rank, rank) and negative values count from the back. Reduced dimensions
/// are kept with extent 1 when `keepdims` is nonzero, otherwise dropped.
/// Returns failure (with an error at `location`) for out-of-range axes.
LogicalResult ReduceMeanOp::inferReturnTypeComponents(
    MLIRContext * /*context*/, std::optional<Location> location,
    ReduceMeanOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  auto inTy = cast<RankedTensorType>(adaptor.getInput().getType());
  auto keepDims = adaptor.getKeepdims();
  auto axes = adaptor.getAxes();
  const int64_t rank = inTy.getRank();

  // Validate each axis and normalize it into [0, rank).
  llvm::SmallVector<int64_t> newAxes;
  newAxes.reserve(axes.size());
  for (auto axis : axes) {
    // onnx spec: axis: [-r, r-1]
    if (axis < -rank || axis >= rank) {
      return emitOptionalError(location,
                               "expected axis to be within [-rank,rank) (where "
                               "rank is the rank of the input)");
    }

    // normalize axis: [0, r)
    if (axis < 0) {
      axis += rank;
    }

    assert((axis >= 0 && axis < rank) && "axis has invalid value");
    newAxes.push_back(axis);
  }

  SmallVector<int64_t, 4> outputShape;
  auto inputShape = inTy.getShape();
  for (auto [idx, dim] : llvm::enumerate(inputShape)) {
    // Fix: test membership against the NORMALIZED axes. The previous code
    // checked the raw `axes` attribute, so a negative axis (e.g. -1) passed
    // validation but never matched here and its dimension was not reduced.
    if (llvm::is_contained(newAxes, static_cast<int64_t>(idx))) {
      if (keepDims) {
        outputShape.push_back(1);
      }
      // Dimension dropped entirely when keepdims is false.
    } else {
      outputShape.push_back(dim);
    }
  }

  inferredReturnShapes.push_back(
      ShapedTypeComponents(outputShape, inTy.getElementType()));
  return success();
}
Loading

0 comments on commit 9c4ed9d

Please sign in to comment.