[transform] add primitive to decompose vector ops to smaller ones (al…
wyzero authored Jan 17, 2023
1 parent ac9ef8c commit 6b067d7
Showing 4 changed files with 508 additions and 7 deletions.
@@ -36,6 +36,8 @@
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/include/mlir/Dialect/Utils/IndexingUtils.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/utils.h"

@@ -1204,11 +1206,316 @@ DiagnosedSilenceableFailure InlineReductionInitializerOp::apply(
  return DiagnosedSilenceableFailure(success());
}

//===---------------------------------------------------------------------===//
// DecomposeVectorsOp
//===---------------------------------------------------------------------===//

namespace {

Optional<SmallVector<int64_t>> computeShapeRatio(ArrayRef<int64_t> superShape,
                                                 ArrayRef<int64_t> subShape) {
  if (superShape.size() < subShape.size()) {
    return Optional<SmallVector<int64_t>>();
  }

  // Starting from the end, compute the integer divisors.
  std::vector<int64_t> result;
  result.reserve(superShape.size());
  for (auto [superSize, subSize] :
       llvm::zip(llvm::reverse(superShape), llvm::reverse(subShape))) {
    assert(superSize > 0 && "superSize must be > 0");
    assert(subSize > 0 && "subSize must be > 0");

    // If integral division does not occur, return and let the caller decide.
    if (superSize % subSize != 0) return llvm::None;
    result.push_back(superSize / subSize);
  }

  // At this point we have computed the ratio (in reverse) for the common
  // suffix of the two shapes. Fill in the remaining entries from the
  // super-vector shape (still in reverse).
  int commonSize = subShape.size();
  std::copy(superShape.rbegin() + commonSize, superShape.rend(),
            std::back_inserter(result));

  assert(result.size() == superShape.size() &&
         "super to sub shape ratio is not of the same size as the super rank");

  // Reverse again to get it back in the proper order and return.
  return SmallVector<int64_t>{result.rbegin(), result.rend()};
}
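
// Worked example (values chosen for illustration): with superShape = {8, 12}
// and subShape = {1, 4}, the trailing dimensions divide evenly (12 / 4 == 3
// and 8 / 1 == 8), so the result is {8, 3}. With subShape = {5}, the function
// returns llvm::None since 12 % 5 != 0.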

SmallVector<int64_t> computeStrides(ArrayRef<int64_t> shape,
                                    ArrayRef<int64_t> sizes) {
  int64_t rank = shape.size();
  // Compute the slice count for each dimension.
  SmallVector<int64_t> sliceDimCounts(rank);
  for (int64_t r = 0; r < rank; ++r)
    sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
  // Use that to compute the slice stride for each dimension.
  SmallVector<int64_t> sliceStrides(rank);
  sliceStrides[rank - 1] = 1;
  for (int64_t r = rank - 2; r >= 0; --r)
    sliceStrides[r] = sliceStrides[r + 1] * sliceDimCounts[r + 1];
  return sliceStrides;
}
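
// Worked example (values chosen for illustration): computeStrides({8, 12},
// {1, 4}) yields sliceDimCounts = {8, 3} and hence row-major slice strides
// {3, 1}: advancing one step in dim 0 skips over the 3 slices of dim 1.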

SmallVector<int64_t> computeElementOffsetsFromVectorSliceOffsets(
    ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
  SmallVector<int64_t> result;
  for (auto it : llvm::zip(vectorOffsets, sizes))
    result.push_back(std::get<0>(it) * std::get<1>(it));
  return result;
}

/// During unrolling from `originalShape` to `targetShape`, return the offset
/// for the slice `index`.
static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> originalShape,
                                            ArrayRef<int64_t> targetShape,
                                            int64_t index) {
  SmallVector<int64_t> dstSliceStrides =
      computeStrides(originalShape, targetShape);
  SmallVector<int64_t> vectorOffsets = delinearize(dstSliceStrides, index);
  SmallVector<int64_t> elementOffsets =
      computeElementOffsetsFromVectorSliceOffsets(targetShape, vectorOffsets);
  return elementOffsets;
}
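
// Worked example (values chosen for illustration): unrolling an 8x12 vector
// into 1x4 slices gives slice strides {3, 1}; for index = 5, delinearize
// returns vectorOffsets = {1, 2} (5 == 1 * 3 + 2), and scaling by the target
// shape gives elementOffsets = {1, 8}, i.e. slice 5 starts at element (1, 8).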

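// Decomposes the vector-typed iter args/results of an scf.for into slices of
// the target shape: each candidate init arg is split with
// vector.extract_strided_slice before the loop, re-assembled inside the body
// (so existing users still see the original type), split again at the yield,
// and merged back after the loop.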
struct DecomposeVectorInOutsOfForOp : public OpRewritePattern<scf::ForOp> {
  explicit DecomposeVectorInOutsOfForOp(
      MLIRContext* context, const vector::UnrollVectorOptions& options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<scf::ForOp>(context, benefit),
        options_(options) {}

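  // Bookkeeping for one decomposed iter arg:
  //  - dstVecType: the original (large) vector type to restore;
  //  - startIdx: index of the first sub-vector among the new init args;
  //  - numSubVectors: number of target-shape slices the value is split into;
  //  - strides/offsets: per-slice strides and element offsets fed to the
  //    extract/insert_strided_slice ops.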
  struct VectorDecomposeInfo {
    VectorType dstVecType;
    int64_t startIdx;
    int64_t numSubVectors;
    SmallVector<SmallVector<int64_t>> strides;
    SmallVector<SmallVector<int64_t>> offsets;
  };

  LogicalResult matchAndRewrite(scf::ForOp forOp,
                                PatternRewriter& rewriter) const final {
    Location loc = forOp.getLoc();
    if (failed(options_.filterConstraint(forOp))) return failure();
    auto maybeTargetShape = options_.nativeShape(forOp);
    if (!maybeTargetShape) return failure();
    auto& targetShape = *maybeTargetShape;
    int numNewInitArgs = 0;
    VectorType targetVectorType;
    DenseMap<int, VectorDecomposeInfo> candidateValueMap;
    for (const auto& en : llvm::enumerate(forOp.getInitArgs())) {
      auto ty = en.value().getType().dyn_cast<VectorType>();
      Optional<SmallVector<int64_t>> maybeShapeRatio;
      if (!ty ||
          !(maybeShapeRatio = computeShapeRatio(ty.getShape(), targetShape)) ||
          llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
        ++numNewInitArgs;
        continue;
      }
      // TODO(wyzero): support multiple iter args with different vector types.
      if (targetVectorType && targetVectorType != ty) return failure();
      targetVectorType = ty;
      auto& item = candidateValueMap[en.index()];
      item.dstVecType = ty;
      item.numSubVectors = computeMaxLinearIndex(*maybeShapeRatio);
      item.startIdx = numNewInitArgs;
      SmallVector<int64_t> strides(targetShape.size(), 1);
      for (int i = 0; i < item.numSubVectors; ++i) {
        auto offsets = getVectorOffset(ty.getShape(), targetShape, i);
        item.strides.push_back(strides);
        item.offsets.push_back(offsets);
        ++numNewInitArgs;
      }
    }
    if (candidateValueMap.empty()) return failure();

    SmallVector<Value> newIterArgs;
    for (const auto& en : llvm::enumerate(forOp.getInitArgs())) {
      auto it = candidateValueMap.find(en.index());
      if (it == candidateValueMap.end()) {
        newIterArgs.push_back(en.value());
      } else {
        auto& item = it->second;
        for (int i = 0; i < item.numSubVectors; ++i) {
          newIterArgs.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
              loc, en.value(), item.offsets[i], targetShape, item.strides[i]));
        }
      }
    }

    scf::ForOp newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newIterArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block& newBlock = newForOp.getRegion().front();

    // 1. Merge block args to restore the old vector type.
    rewriter.setInsertionPointToStart(&newBlock);
    SmallVector<Value> newBlockArgs;
    newBlockArgs.push_back(newForOp.getInductionVar());
    // Skip the first block arg: the loop induction variable.
    size_t blockArgIdx = 1;
    for (int i = 0; i < forOp->getNumResults(); ++i) {
      auto it = candidateValueMap.find(i);
      if (it == candidateValueMap.end()) {
        newBlockArgs.push_back(newBlock.getArgument(blockArgIdx++));
        continue;
      }
      auto& item = it->second;
      Value mergedVec = rewriter.create<arith::ConstantOp>(
          loc, item.dstVecType, rewriter.getZeroAttr(item.dstVecType));
      for (int subIdx = 0; subIdx < item.numSubVectors; ++subIdx) {
        mergedVec = rewriter.create<vector::InsertStridedSliceOp>(
            loc, newBlock.getArgument(blockArgIdx++), mergedVec,
            item.offsets[subIdx], item.strides[subIdx]);
      }
      newBlockArgs.push_back(mergedVec);
    }

    // 2. Clone the ops inside the region of the old for op.
    Block& oldBlock = forOp.getRegion().front();
    rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockArgs);
    auto oldYieldOp = &newBlock.back();
    rewriter.setInsertionPointAfter(oldYieldOp);

    // 3. Split the yield results of the old for op.
    SmallVector<Value> newYieldValues;
    for (const auto& en : llvm::enumerate(oldYieldOp->getOperands())) {
      auto it = candidateValueMap.find(en.index());
      if (it == candidateValueMap.end()) {
        newYieldValues.push_back(en.value());
        continue;
      }
      auto& item = it->second;
      for (int subIdx = 0; subIdx < item.numSubVectors; ++subIdx) {
        newYieldValues.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, en.value(), item.offsets[subIdx], targetShape,
            item.strides[subIdx]));
      }
    }
    rewriter.create<scf::YieldOp>(loc, newYieldValues);
    rewriter.eraseOp(oldYieldOp);

    // 4. Merge the results of the new for op.
    rewriter.setInsertionPointAfter(newForOp);
    size_t resultIdx = 0;
    SmallVector<Value> newResults;
    for (const auto& en : llvm::enumerate(forOp->getResults())) {
      auto it = candidateValueMap.find(en.index());
      if (it == candidateValueMap.end()) {
        newResults.push_back(newForOp->getResult(resultIdx++));
        continue;
      }
      auto& item = it->second;
      Value mergedVec = rewriter.create<arith::ConstantOp>(
          loc, item.dstVecType, rewriter.getZeroAttr(item.dstVecType));
      for (int subIdx = 0; subIdx < item.numSubVectors; ++subIdx) {
        mergedVec = rewriter.create<vector::InsertStridedSliceOp>(
            loc, newForOp->getResult(resultIdx++), mergedVec,
            item.offsets[subIdx], item.strides[subIdx]);
      }
      newResults.push_back(mergedVec);
    }
    rewriter.replaceOp(forOp, newResults);
    return success();
  }

 private:
  vector::UnrollVectorOptions options_;
};
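
// A sketch of the rewrite (illustrative IR, assuming a target shape of [4]
// and a single vector<8xf32> iter arg):
//
//   %i0 = vector.extract_strided_slice %init ...  // offsets = [0], sizes = [4]
//   %i1 = vector.extract_strided_slice %init ...  // offsets = [4], sizes = [4]
//   %r:2 = scf.for ... iter_args(%a0 = %i0, %a1 = %i1) -> (...) {
//     // The body first re-assembles %a0/%a1 into a vector<8xf32> for the
//     // old users, then splits the value to yield back into two slices.
//   }
//   // After the loop, %r#0 and %r#1 are merged back into a vector<8xf32>
//   // by inserting them into a zero constant.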

} // namespace

void DecomposeVectorsOp::build(OpBuilder& builder, OperationState& result,
                               Value target, int64_t vectorSize) {
  MLIRContext* ctx = builder.getContext();
  result.addOperands(target);
  result.addAttribute(
      DecomposeVectorsOp::getVectorSizeAttrName(result.name),
      builder.getIntegerAttr(builder.getIntegerType(64), vectorSize));
  result.addTypes({});
}

DiagnosedSilenceableFailure DecomposeVectorsOp::applyToOne(
    Operation* target, SmallVectorImpl<Operation*>& results,
    transform::TransformState& state) {
  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
    return mlir::emitDefiniteFailure(target,
                                     "applies only to isolated-from-above "
                                     "targets because it needs to apply "
                                     "patterns greedily");
  }

  MLIRContext* ctx = getContext();
  RewritePatternSet patterns(ctx);
  // Decompose vector.outerproduct into broadcast + fma ops.
  vector::populateVectorContractLoweringPatterns(patterns);

  vector::UnrollVectorOptions options;
  auto isTargetType = [&](Type ty) {
    auto castedTy = ty.dyn_cast<VectorType>();
    return castedTy && castedTy.getRank() > 0 &&
           castedTy.getShape()[castedTy.getRank() - 1] %
                   this->getVectorSize() ==
               0;
  };
  auto getVectorTypeOfForOp = [&](Operation* op) -> VectorType {
    if (!isa<scf::ForOp>(op)) return nullptr;
    VectorType vectorTy;
    for (auto ty : op->getResultTypes()) {
      if (!isTargetType(ty)) continue;
      if (vectorTy && vectorTy != ty.dyn_cast<VectorType>()) return nullptr;
      vectorTy = ty.dyn_cast<VectorType>();
    }
    return vectorTy;
  };
  options.setFilterConstraint([&](Operation* op) {
    if (getVectorTypeOfForOp(op)) return success();
    if (isa<vector::TransferReadOp, vector::TransferWriteOp>(op))
      return success();
    if (op->getNumResults() != 1) return failure();
    if (op->getDialect()->getTypeID() != TypeID::get<vector::VectorDialect>() &&
        op->getDialect()->getTypeID() != TypeID::get<arith::ArithDialect>())
      return failure();
    return success(isTargetType(op->getResult(0).getType()));
  });
  vector::UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
      [&](Operation* op) -> Optional<SmallVector<int64_t, 4>> {
    VectorType targetVectorTy;
    if (auto vectorTy = getVectorTypeOfForOp(op)) {
      targetVectorTy = vectorTy;
    } else if (isa<vector::TransferWriteOp>(op)) {
      targetVectorTy = op->getOperand(0).getType().cast<VectorType>();
    } else {
      targetVectorTy = op->getResult(0).getType().cast<VectorType>();
    }
    SmallVector<int64_t, 4> nativeShape(targetVectorTy.getRank(), 1);
    nativeShape[targetVectorTy.getRank() - 1] = this->getVectorSize();
    return nativeShape;
  };
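  // For illustration: with vector_size = 4, an op producing vector<8x12xf32>
  // is assigned the native shape [1, 4], so the unroll patterns split it
  // into 8 * 3 = 24 slices of type vector<1x4xf32>.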
  options.setNativeShapeFn(nativeShapeFn);
  vector::populateVectorUnrollPatterns(patterns, options);
  patterns.insert<DecomposeVectorInOutsOfForOp>(ctx, options);

  // Some cleanup patterns.
  vector::populateVectorToVectorCanonicalizationPatterns(patterns);

  // Apply everything.
  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
    return DiagnosedSilenceableFailure::definiteFailure();

  results.assign({target});
  return DiagnosedSilenceableFailure(success());
}

} // namespace transform_dialect

void registerTransformDialectCommonExtension(DialectRegistry& registry) {
  registry
      .addExtensions<::mlir::disc_ral::transform_dialect::CommonExtensions>();
}

} // namespace disc_ral
@@ -313,4 +313,58 @@ def InlineReductionInitializerOp : Op<Transform_Dialect, "disc.inline_reduction_
  let cppNamespace = "::mlir::disc_ral::transform_dialect";
}

def DecomposeVectorsOp : Op<Transform_Dialect, "disc.decompose_vectors",
    [FunctionalStyleTransformOpTrait,
     MemoryEffectsOpInterface,
     TransformEachOpTrait,
     TransformOpInterface]> {
  let description = [{
    Decomposes vector ops and the related vector.transfer_read/transfer_write
    ops into smaller, fine-grained vector ops.

    Without this decomposition, the LLVM backend may mis-allocate vector
    registers for vector ops with certain shapes (e.g. an 8x12xf32 vector
    fails to map to hardware registers). This op is a workaround for such
    cases.

    Example: convert
    ```
    %0 = vector.transfer_read %arg0[..., %c0] : vector<8xf32>
    %1 = vector.transfer_read %arg1[..., %c0] : vector<8xf32>
    %2 = vector.transfer_read %arg2[..., %c0] : vector<8xf32>
    %3 = vector.fma %0, %1, %2 : vector<8xf32>
    vector.transfer_write %3, %arg2[..., %c0] : vector<8xf32>
    ```
    to:
    ```
    %0_0 = vector.transfer_read %arg0[..., %c0] : vector<4xf32>
    %0_1 = vector.transfer_read %arg0[..., %c4] : vector<4xf32>
    %1_0 = vector.transfer_read %arg1[..., %c0] : vector<4xf32>
    %1_1 = vector.transfer_read %arg1[..., %c4] : vector<4xf32>
    %2_0 = vector.transfer_read %arg2[..., %c0] : vector<4xf32>
    %2_1 = vector.transfer_read %arg2[..., %c4] : vector<4xf32>
    %3_0 = vector.fma %0_0, %1_0, %2_0 : vector<4xf32>
    %3_1 = vector.fma %0_1, %1_1, %2_1 : vector<4xf32>
    vector.transfer_write %3_0, %arg2[..., %c0] : vector<4xf32>
    vector.transfer_write %3_1, %arg2[..., %c4] : vector<4xf32>
    ```
  }];

  let arguments = (ins PDL_Operation:$target,
                       I64Attr:$vector_size);
  let results = (outs PDL_Operation:$result);

  let assemblyFormat = "$target attr-dict";
  let cppNamespace = "::mlir::disc_ral::transform_dialect";

  let builders = [
    OpBuilder<(ins "Value":$target, "int64_t":$vector_size)>
  ];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure applyToOne(
        ::mlir::Operation *target,
        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
        ::mlir::transform::TransformState &state);
  }];
}
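
// A hedged usage sketch (the exact mnemonic prefix depends on how the
// extension is registered, which is not shown in this diff):
//
//   %1 = transform.disc.decompose_vectors %0 {vector_size = 4 : i64}
//
// Per the assembly format above, the op takes the target handle plus the
// vector_size attribute, and returns a new handle to the transformed op.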

#endif // DISC_TRANSFORM_OPS_EXT
