Skip to content

Commit 0b58685

Browse files
committed
Add support for elementwise ops in Wg to Sg distribute pass
1 parent 77c8d21 commit 0b58685

File tree

2 files changed

+1194
-0
lines changed

2 files changed

+1194
-0
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,18 @@
88
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
99

1010
#include "mlir/Dialect/Affine/Utils.h"
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
1112
#include "mlir/Dialect/Arith/Utils/Utils.h"
1213
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1314
#include "mlir/Dialect/Index/IR/IndexDialect.h"
1415
#include "mlir/Dialect/Index/IR/IndexOps.h"
16+
#include "mlir/Dialect/Math/IR/Math.h"
1517
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1618
#include "mlir/Dialect/Utils/IndexingUtils.h"
1719
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1820
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
1921
#include "mlir/Transforms/DialectConversion.h"
22+
#include <optional>
2023

2124
namespace mlir {
2225
namespace xegpu {
@@ -314,6 +317,179 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
314317
}
315318
};
316319

320+
// This pattern matches elementwise ops (unary/binary) in math/arith dialects
321+
// with 1D or 2D vector types
322+
template <typename Op>
323+
struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
324+
using OpConversionPattern<Op>::OpConversionPattern;
325+
using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
326+
327+
LogicalResult
328+
matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
329+
ConversionPatternRewriter &rewriter) const override {
330+
// All operands/results must be 1D or 2D vectors
331+
auto resultType = dyn_cast<VectorType>(op.getResult().getType());
332+
if (!resultType || (resultType.getRank() != 1 && resultType.getRank() != 2))
333+
return rewriter.notifyMatchFailure(
334+
op, "Result type is not a 1D or 2D vector");
335+
336+
ArrayRef<int64_t> shape = resultType.getShape();
337+
for (Value operand : op->getOperands()) {
338+
auto operandType = dyn_cast<VectorType>(operand.getType());
339+
if (!operandType || operandType.getRank() != resultType.getRank() ||
340+
operandType.getShape() != shape) {
341+
return rewriter.notifyMatchFailure(
342+
op, "Operand type is not a 1D or 2D vector with the same shape as "
343+
"result type");
344+
}
345+
}
346+
347+
// Check for layout attribute with sgLayout
348+
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
349+
if (!layout || !layout.getSgLayout())
350+
return rewriter.notifyMatchFailure(
351+
op, "Operation does not have a valid layout attribute for subgroup "
352+
"distribution");
353+
354+
// Extract sgShape from layout
355+
SmallVector<int64_t> sgShape;
356+
if (auto sgDataAttr = layout.getSgData()) {
357+
sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
358+
} else {
359+
auto sgLayoutArr = layout.getSgLayout();
360+
sgShape.reserve(shape.size());
361+
for (size_t i = 0; i < shape.size(); ++i) {
362+
assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
363+
sgShape.push_back(shape[i] / sgLayoutArr[i]);
364+
}
365+
}
366+
367+
// Each operand is a list of values
368+
size_t numVariants = adaptor.getOperands().empty()
369+
? 0
370+
: adaptor.getOperands().front().size();
371+
for (auto &operandVec : adaptor.getOperands())
372+
if (operandVec.size() != numVariants)
373+
return rewriter.notifyMatchFailure(
374+
op, "Operand lists have mismatched sizes");
375+
376+
SmallVector<Value> newResults;
377+
378+
auto origResultType = dyn_cast<VectorType>(op->getResult(0).getType());
379+
VectorType newResultType =
380+
origResultType
381+
? VectorType::get(sgShape, origResultType.getElementType())
382+
: VectorType::get(sgShape, resultType.getElementType());
383+
384+
for (size_t i = 0; i < numVariants; ++i) {
385+
SmallVector<Value> operands;
386+
for (auto &operandVec : adaptor.getOperands())
387+
operands.push_back(operandVec[i]);
388+
389+
auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);
390+
391+
// Copy all attributes except "layout", and add "layout_result_0" with
392+
// sgLayout/data dropped
393+
for (auto attr : op->getAttrs()) {
394+
if (attr.getName() != "layout")
395+
newOp->setAttr(attr.getName(), attr.getValue());
396+
}
397+
newOp->setAttr("layout_result_0", layout.dropSgLayoutAndData());
398+
399+
newResults.push_back(newOp.getResult());
400+
}
401+
402+
rewriter.replaceOpWithMultiple(op, {newResults});
403+
return success();
404+
}
405+
};
406+
407+
// ---- ARITH ops ----
408+
using WgToSgAddFOp = WgToSgElementwiseOp<arith::AddFOp>;
409+
using WgToSgSubFOp = WgToSgElementwiseOp<arith::SubFOp>;
410+
using WgToSgNegFOp = WgToSgElementwiseOp<arith::NegFOp>;
411+
using WgToSgAddIOp = WgToSgElementwiseOp<arith::AddIOp>;
412+
using WgToSgSubIOp = WgToSgElementwiseOp<arith::SubIOp>;
413+
using WgToSgMulFOp = WgToSgElementwiseOp<arith::MulFOp>;
414+
using WgToSgMulIOp = WgToSgElementwiseOp<arith::MulIOp>;
415+
using WgToSgShLIOp = WgToSgElementwiseOp<arith::ShLIOp>;
416+
using WgToSgShRSIOp = WgToSgElementwiseOp<arith::ShRSIOp>;
417+
using WgToSgShRUIOp = WgToSgElementwiseOp<arith::ShRUIOp>;
418+
using WgToSgDivFOp = WgToSgElementwiseOp<arith::DivFOp>;
419+
using WgToSgDivSIOp = WgToSgElementwiseOp<arith::DivSIOp>;
420+
using WgToSgDivUIOp = WgToSgElementwiseOp<arith::DivUIOp>;
421+
using WgToSgMaximumFOp = WgToSgElementwiseOp<arith::MaximumFOp>;
422+
using WgToSgMinimumFOp = WgToSgElementwiseOp<arith::MinimumFOp>;
423+
using WgToSgRemSIOp = WgToSgElementwiseOp<arith::RemSIOp>;
424+
using WgToSgRemUIOp = WgToSgElementwiseOp<arith::RemUIOp>;
425+
using WgToSgTruncFOp = WgToSgElementwiseOp<arith::TruncFOp>;
426+
using WgToSgTruncIOp = WgToSgElementwiseOp<arith::TruncIOp>;
427+
using WgToSgExtFOp = WgToSgElementwiseOp<arith::ExtFOp>;
428+
using WgToSgExtSIOp = WgToSgElementwiseOp<arith::ExtSIOp>;
429+
using WgToSgExtUIOp = WgToSgElementwiseOp<arith::ExtUIOp>;
430+
using WgToSgSIToFPOp = WgToSgElementwiseOp<arith::SIToFPOp>;
431+
using WgToSgUIToFPOp = WgToSgElementwiseOp<arith::UIToFPOp>;
432+
using WgToSgFPToSIOp = WgToSgElementwiseOp<arith::FPToSIOp>;
433+
using WgToSgFPToUIOp = WgToSgElementwiseOp<arith::FPToUIOp>;
434+
using WgToSgIndexCastUIOp = WgToSgElementwiseOp<arith::IndexCastUIOp>;
435+
using WgToSgIndexCastOp = WgToSgElementwiseOp<arith::IndexCastOp>;
436+
using WgToSgBitcastOp = WgToSgElementwiseOp<arith::BitcastOp>;
437+
using WgToSgCmpIOp = WgToSgElementwiseOp<arith::CmpIOp>;
438+
using WgToSgCmpFOp = WgToSgElementwiseOp<arith::CmpFOp>;
439+
using WgToSgAndIOp = WgToSgElementwiseOp<arith::AndIOp>;
440+
using WgToSgCeilDivSIOp = WgToSgElementwiseOp<arith::CeilDivSIOp>;
441+
using WgToSgCeilDivUIOp = WgToSgElementwiseOp<arith::CeilDivUIOp>;
442+
using WgToSgFloorDivSIOp = WgToSgElementwiseOp<arith::FloorDivSIOp>;
443+
using WgToSgMaxNumFOp = WgToSgElementwiseOp<arith::MaxNumFOp>;
444+
using WgToSgMaxSIOp = WgToSgElementwiseOp<arith::MaxSIOp>;
445+
using WgToSgMaxUIOp = WgToSgElementwiseOp<arith::MaxUIOp>;
446+
using WgToSgMinNumFOp = WgToSgElementwiseOp<arith::MinNumFOp>;
447+
using WgToSgMinSIOp = WgToSgElementwiseOp<arith::MinSIOp>;
448+
using WgToSgMinUIOp = WgToSgElementwiseOp<arith::MinUIOp>;
449+
using WgToSgOrIOp = WgToSgElementwiseOp<arith::OrIOp>;
450+
using WgToSgRemFOp = WgToSgElementwiseOp<arith::RemFOp>;
451+
using WgToSgSelectOp = WgToSgElementwiseOp<arith::SelectOp>;
452+
using WgToSgXOrIOp = WgToSgElementwiseOp<arith::XOrIOp>;
453+
454+
// ---- MATH ops ----
455+
using WgToSgExpOp = WgToSgElementwiseOp<math::ExpOp>;
456+
using WgToSgSqrtOp = WgToSgElementwiseOp<math::SqrtOp>;
457+
using WgToSgAbsFOp = WgToSgElementwiseOp<math::AbsFOp>;
458+
using WgToSgCosOp = WgToSgElementwiseOp<math::CosOp>;
459+
using WgToSgCoshOp = WgToSgElementwiseOp<math::CoshOp>;
460+
using WgToSgAcosOp = WgToSgElementwiseOp<math::AcosOp>;
461+
using WgToSgAcoshOp = WgToSgElementwiseOp<math::AcoshOp>;
462+
using WgToSgSinOp = WgToSgElementwiseOp<math::SinOp>;
463+
using WgToSgSinhOp = WgToSgElementwiseOp<math::SinhOp>;
464+
using WgToSgAsinOp = WgToSgElementwiseOp<math::AsinOp>;
465+
using WgToSgAsinhOp = WgToSgElementwiseOp<math::AsinhOp>;
466+
using WgToSgTanOp = WgToSgElementwiseOp<math::TanOp>;
467+
using WgToSgTanhOp = WgToSgElementwiseOp<math::TanhOp>;
468+
using WgToSgAtanOp = WgToSgElementwiseOp<math::AtanOp>;
469+
using WgToSgAtan2Op = WgToSgElementwiseOp<math::Atan2Op>;
470+
using WgToSgAtanhOp = WgToSgElementwiseOp<math::AtanhOp>;
471+
using WgToSgErfOp = WgToSgElementwiseOp<math::ErfOp>;
472+
using WgToSgLogOp = WgToSgElementwiseOp<math::LogOp>;
473+
using WgToSgLog2Op = WgToSgElementwiseOp<math::Log2Op>;
474+
using WgToSgFloorOp = WgToSgElementwiseOp<math::FloorOp>;
475+
using WgToSgCeilOp = WgToSgElementwiseOp<math::CeilOp>;
476+
using WgToSgPowFOp = WgToSgElementwiseOp<math::PowFOp>;
477+
using WgToSgRsqrtOp = WgToSgElementwiseOp<math::RsqrtOp>;
478+
using WgToSgAbsIOp = WgToSgElementwiseOp<math::AbsIOp>;
479+
using WgToSgCbrtOp = WgToSgElementwiseOp<math::CbrtOp>;
480+
using WgToSgCopySignOp = WgToSgElementwiseOp<math::CopySignOp>;
481+
using WgToSgCtPopOp = WgToSgElementwiseOp<math::CtPopOp>;
482+
using WgToSgErfcOp = WgToSgElementwiseOp<math::ErfcOp>;
483+
using WgToSgExp2Op = WgToSgElementwiseOp<math::Exp2Op>;
484+
using WgToSgExpM1Op = WgToSgElementwiseOp<math::ExpM1Op>;
485+
using WgToSgFPowIOp = WgToSgElementwiseOp<math::FPowIOp>;
486+
using WgToSgIPowIOp = WgToSgElementwiseOp<math::IPowIOp>;
487+
using WgToSgLog10Op = WgToSgElementwiseOp<math::Log10Op>;
488+
using WgToSgLog1pOp = WgToSgElementwiseOp<math::Log1pOp>;
489+
using WgToSgRoundOp = WgToSgElementwiseOp<math::RoundOp>;
490+
using WgToSgRoundEvenOp = WgToSgElementwiseOp<math::RoundEvenOp>;
491+
using WgToSgTruncOp = WgToSgElementwiseOp<math::TruncOp>;
492+
317493
} // namespace
318494

319495
namespace mlir {
@@ -322,6 +498,27 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
322498
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
323499
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
324500
patterns.getContext());
501+
// Add elementwise operations that can be distributed to subgroups
502+
patterns.add<
503+
WgToSgAddFOp, WgToSgSubFOp, WgToSgExpOp, WgToSgSqrtOp, WgToSgAbsFOp,
504+
WgToSgCosOp, WgToSgCoshOp, WgToSgAcosOp, WgToSgAcoshOp, WgToSgSinOp,
505+
WgToSgSinhOp, WgToSgAsinOp, WgToSgAsinhOp, WgToSgTanOp, WgToSgTanhOp,
506+
WgToSgAtanOp, WgToSgAtan2Op, WgToSgAtanhOp, WgToSgErfOp, WgToSgLogOp,
507+
WgToSgLog2Op, WgToSgFloorOp, WgToSgCeilOp, WgToSgPowFOp, WgToSgRsqrtOp,
508+
WgToSgNegFOp, WgToSgAddIOp, WgToSgSubIOp, WgToSgMulFOp, WgToSgMulIOp,
509+
WgToSgShLIOp, WgToSgShRSIOp, WgToSgShRUIOp, WgToSgDivFOp, WgToSgDivSIOp,
510+
WgToSgDivUIOp, WgToSgMaximumFOp, WgToSgMinimumFOp, WgToSgRemSIOp,
511+
WgToSgRemUIOp, WgToSgTruncFOp, WgToSgTruncIOp, WgToSgExtFOp,
512+
WgToSgExtSIOp, WgToSgExtUIOp, WgToSgSIToFPOp, WgToSgUIToFPOp,
513+
WgToSgFPToSIOp, WgToSgFPToUIOp, WgToSgIndexCastUIOp, WgToSgIndexCastOp,
514+
WgToSgBitcastOp, WgToSgCmpIOp, WgToSgCmpFOp, WgToSgAndIOp,
515+
WgToSgCeilDivSIOp, WgToSgCeilDivUIOp, WgToSgFloorDivSIOp, WgToSgMaxNumFOp,
516+
WgToSgMaxSIOp, WgToSgMaxUIOp, WgToSgMinNumFOp, WgToSgMinSIOp,
517+
WgToSgMinUIOp, WgToSgOrIOp, WgToSgRemFOp, WgToSgSelectOp, WgToSgXOrIOp,
518+
WgToSgAbsIOp, WgToSgCbrtOp, WgToSgCopySignOp, WgToSgCtPopOp, WgToSgErfcOp,
519+
WgToSgExp2Op, WgToSgExpM1Op, WgToSgFPowIOp, WgToSgIPowIOp, WgToSgLog10Op,
520+
WgToSgLog1pOp, WgToSgRoundOp, WgToSgRoundEvenOp, WgToSgTruncOp>(
521+
patterns.getContext());
325522
}
326523
} // namespace xegpu
327524
} // namespace mlir
@@ -368,6 +565,32 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
368565
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
369566
return isLegal(layout);
370567
});
568+
// Arith/math ops are candidates for subgroup distribution only when they are
// elementwise over 1D or 2D vectors; every other op in these dialects is
// reported as legal and left untouched.
target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
    [=](Operation *op) -> std::optional<bool> {
      // Ops without operands or results cannot match the elementwise
      // patterns; this also keeps getResult(0) below in bounds. No upper bound
      // is placed on the operand count so that ternary ops such as
      // arith.select reach their registered pattern.
      if (op->getNumOperands() == 0 || op->getNumResults() == 0)
        return true;

      // WgToSgElementwiseOp accepts both 1D and 2D vector results, so both
      // ranks must be subject to the layout check.
      auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
      if (!resultType ||
          (resultType.getRank() != 1 && resultType.getRank() != 2))
        return true;

      // All operands must be vectors with the result's shape; comparing the
      // shapes also compares the ranks.
      for (Value operand : op->getOperands()) {
        auto operandType = dyn_cast<VectorType>(operand.getType());
        if (!operandType || operandType.getShape() != resultType.getShape())
          return true;
      }

      // Legality is decided by the op's "layout" attribute (null when absent).
      // NOTE(review): isLegal is defined outside this hunk; it is assumed to
      // accept a null layout, as the neighbouring legality callback passes it
      // the result of dyn_cast_or_null.
      return isLegal(op->getAttrOfType<xegpu::LayoutAttr>("layout"));
    });
371594

372595
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
373596

0 commit comments

Comments
 (0)