8
8
#include " mlir/Dialect/XeGPU/Transforms/Passes.h"
9
9
10
10
#include " mlir/Dialect/Affine/Utils.h"
11
+ #include " mlir/Dialect/Arith/IR/Arith.h"
11
12
#include " mlir/Dialect/Arith/Utils/Utils.h"
12
13
#include " mlir/Dialect/GPU/IR/GPUDialect.h"
13
14
#include " mlir/Dialect/Index/IR/IndexDialect.h"
14
15
#include " mlir/Dialect/Index/IR/IndexOps.h"
16
+ #include " mlir/Dialect/Math/IR/Math.h"
15
17
#include " mlir/Dialect/MemRef/IR/MemRef.h"
16
18
#include " mlir/Dialect/Utils/IndexingUtils.h"
17
19
#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
18
20
#include " mlir/Dialect/XeGPU/Transforms/Transforms.h"
19
21
#include " mlir/Transforms/DialectConversion.h"
22
+ #include < optional>
20
23
21
24
namespace mlir {
22
25
namespace xegpu {
@@ -314,6 +317,179 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
314
317
}
315
318
};
316
319
320
+ // This pattern matches elementwise ops (unary/binary) in math/arith dialects
321
+ // with 1D or 2D vector types
322
+ template <typename Op>
323
+ struct WgToSgElementwiseOp : public OpConversionPattern <Op> {
324
+ using OpConversionPattern<Op>::OpConversionPattern;
325
+ using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
326
+
327
+ LogicalResult
328
+ matchAndRewrite (Op op, OneToNOpAdaptor adaptor,
329
+ ConversionPatternRewriter &rewriter) const override {
330
+ // All operands/results must be 1D or 2D vectors
331
+ auto resultType = dyn_cast<VectorType>(op.getResult ().getType ());
332
+ if (!resultType || (resultType.getRank () != 1 && resultType.getRank () != 2 ))
333
+ return rewriter.notifyMatchFailure (
334
+ op, " Result type is not a 1D or 2D vector" );
335
+
336
+ ArrayRef<int64_t > shape = resultType.getShape ();
337
+ for (Value operand : op->getOperands ()) {
338
+ auto operandType = dyn_cast<VectorType>(operand.getType ());
339
+ if (!operandType || operandType.getRank () != resultType.getRank () ||
340
+ operandType.getShape () != shape) {
341
+ return rewriter.notifyMatchFailure (
342
+ op, " Operand type is not a 1D or 2D vector with the same shape as "
343
+ " result type" );
344
+ }
345
+ }
346
+
347
+ // Check for layout attribute with sgLayout
348
+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr (" layout" ));
349
+ if (!layout || !layout.getSgLayout ())
350
+ return rewriter.notifyMatchFailure (
351
+ op, " Operation does not have a valid layout attribute for subgroup "
352
+ " distribution" );
353
+
354
+ // Extract sgShape from layout
355
+ SmallVector<int64_t > sgShape;
356
+ if (auto sgDataAttr = layout.getSgData ()) {
357
+ sgShape = llvm::to_vector_of<int64_t >(sgDataAttr.asArrayRef ());
358
+ } else {
359
+ auto sgLayoutArr = layout.getSgLayout ();
360
+ sgShape.reserve (shape.size ());
361
+ for (size_t i = 0 ; i < shape.size (); ++i) {
362
+ assert (sgLayoutArr[i] != 0 && " sgLayout elements must be non-zero" );
363
+ sgShape.push_back (shape[i] / sgLayoutArr[i]);
364
+ }
365
+ }
366
+
367
+ // Each operand is a list of values
368
+ size_t numVariants = adaptor.getOperands ().empty ()
369
+ ? 0
370
+ : adaptor.getOperands ().front ().size ();
371
+ for (auto &operandVec : adaptor.getOperands ())
372
+ if (operandVec.size () != numVariants)
373
+ return rewriter.notifyMatchFailure (
374
+ op, " Operand lists have mismatched sizes" );
375
+
376
+ SmallVector<Value> newResults;
377
+
378
+ auto origResultType = dyn_cast<VectorType>(op->getResult (0 ).getType ());
379
+ VectorType newResultType =
380
+ origResultType
381
+ ? VectorType::get (sgShape, origResultType.getElementType ())
382
+ : VectorType::get (sgShape, resultType.getElementType ());
383
+
384
+ for (size_t i = 0 ; i < numVariants; ++i) {
385
+ SmallVector<Value> operands;
386
+ for (auto &operandVec : adaptor.getOperands ())
387
+ operands.push_back (operandVec[i]);
388
+
389
+ auto newOp = rewriter.create <Op>(op.getLoc (), newResultType, operands);
390
+
391
+ // Copy all attributes except "layout", and add "layout_result_0" with
392
+ // sgLayout/data dropped
393
+ for (auto attr : op->getAttrs ()) {
394
+ if (attr.getName () != " layout" )
395
+ newOp->setAttr (attr.getName (), attr.getValue ());
396
+ }
397
+ newOp->setAttr (" layout_result_0" , layout.dropSgLayoutAndData ());
398
+
399
+ newResults.push_back (newOp.getResult ());
400
+ }
401
+
402
+ rewriter.replaceOpWithMultiple (op, {newResults});
403
+ return success ();
404
+ }
405
+ };
406
+
407
+ // ---- ARITH ops ----
408
+ using WgToSgAddFOp = WgToSgElementwiseOp<arith::AddFOp>;
409
+ using WgToSgSubFOp = WgToSgElementwiseOp<arith::SubFOp>;
410
+ using WgToSgNegFOp = WgToSgElementwiseOp<arith::NegFOp>;
411
+ using WgToSgAddIOp = WgToSgElementwiseOp<arith::AddIOp>;
412
+ using WgToSgSubIOp = WgToSgElementwiseOp<arith::SubIOp>;
413
+ using WgToSgMulFOp = WgToSgElementwiseOp<arith::MulFOp>;
414
+ using WgToSgMulIOp = WgToSgElementwiseOp<arith::MulIOp>;
415
+ using WgToSgShLIOp = WgToSgElementwiseOp<arith::ShLIOp>;
416
+ using WgToSgShRSIOp = WgToSgElementwiseOp<arith::ShRSIOp>;
417
+ using WgToSgShRUIOp = WgToSgElementwiseOp<arith::ShRUIOp>;
418
+ using WgToSgDivFOp = WgToSgElementwiseOp<arith::DivFOp>;
419
+ using WgToSgDivSIOp = WgToSgElementwiseOp<arith::DivSIOp>;
420
+ using WgToSgDivUIOp = WgToSgElementwiseOp<arith::DivUIOp>;
421
+ using WgToSgMaximumFOp = WgToSgElementwiseOp<arith::MaximumFOp>;
422
+ using WgToSgMinimumFOp = WgToSgElementwiseOp<arith::MinimumFOp>;
423
+ using WgToSgRemSIOp = WgToSgElementwiseOp<arith::RemSIOp>;
424
+ using WgToSgRemUIOp = WgToSgElementwiseOp<arith::RemUIOp>;
425
+ using WgToSgTruncFOp = WgToSgElementwiseOp<arith::TruncFOp>;
426
+ using WgToSgTruncIOp = WgToSgElementwiseOp<arith::TruncIOp>;
427
+ using WgToSgExtFOp = WgToSgElementwiseOp<arith::ExtFOp>;
428
+ using WgToSgExtSIOp = WgToSgElementwiseOp<arith::ExtSIOp>;
429
+ using WgToSgExtUIOp = WgToSgElementwiseOp<arith::ExtUIOp>;
430
+ using WgToSgSIToFPOp = WgToSgElementwiseOp<arith::SIToFPOp>;
431
+ using WgToSgUIToFPOp = WgToSgElementwiseOp<arith::UIToFPOp>;
432
+ using WgToSgFPToSIOp = WgToSgElementwiseOp<arith::FPToSIOp>;
433
+ using WgToSgFPToUIOp = WgToSgElementwiseOp<arith::FPToUIOp>;
434
+ using WgToSgIndexCastUIOp = WgToSgElementwiseOp<arith::IndexCastUIOp>;
435
+ using WgToSgIndexCastOp = WgToSgElementwiseOp<arith::IndexCastOp>;
436
+ using WgToSgBitcastOp = WgToSgElementwiseOp<arith::BitcastOp>;
437
+ using WgToSgCmpIOp = WgToSgElementwiseOp<arith::CmpIOp>;
438
+ using WgToSgCmpFOp = WgToSgElementwiseOp<arith::CmpFOp>;
439
+ using WgToSgAndIOp = WgToSgElementwiseOp<arith::AndIOp>;
440
+ using WgToSgCeilDivSIOp = WgToSgElementwiseOp<arith::CeilDivSIOp>;
441
+ using WgToSgCeilDivUIOp = WgToSgElementwiseOp<arith::CeilDivUIOp>;
442
+ using WgToSgFloorDivSIOp = WgToSgElementwiseOp<arith::FloorDivSIOp>;
443
+ using WgToSgMaxNumFOp = WgToSgElementwiseOp<arith::MaxNumFOp>;
444
+ using WgToSgMaxSIOp = WgToSgElementwiseOp<arith::MaxSIOp>;
445
+ using WgToSgMaxUIOp = WgToSgElementwiseOp<arith::MaxUIOp>;
446
+ using WgToSgMinNumFOp = WgToSgElementwiseOp<arith::MinNumFOp>;
447
+ using WgToSgMinSIOp = WgToSgElementwiseOp<arith::MinSIOp>;
448
+ using WgToSgMinUIOp = WgToSgElementwiseOp<arith::MinUIOp>;
449
+ using WgToSgOrIOp = WgToSgElementwiseOp<arith::OrIOp>;
450
+ using WgToSgRemFOp = WgToSgElementwiseOp<arith::RemFOp>;
451
+ using WgToSgSelectOp = WgToSgElementwiseOp<arith::SelectOp>;
452
+ using WgToSgXOrIOp = WgToSgElementwiseOp<arith::XOrIOp>;
453
+
454
+ // ---- MATH ops ----
455
+ using WgToSgExpOp = WgToSgElementwiseOp<math::ExpOp>;
456
+ using WgToSgSqrtOp = WgToSgElementwiseOp<math::SqrtOp>;
457
+ using WgToSgAbsFOp = WgToSgElementwiseOp<math::AbsFOp>;
458
+ using WgToSgCosOp = WgToSgElementwiseOp<math::CosOp>;
459
+ using WgToSgCoshOp = WgToSgElementwiseOp<math::CoshOp>;
460
+ using WgToSgAcosOp = WgToSgElementwiseOp<math::AcosOp>;
461
+ using WgToSgAcoshOp = WgToSgElementwiseOp<math::AcoshOp>;
462
+ using WgToSgSinOp = WgToSgElementwiseOp<math::SinOp>;
463
+ using WgToSgSinhOp = WgToSgElementwiseOp<math::SinhOp>;
464
+ using WgToSgAsinOp = WgToSgElementwiseOp<math::AsinOp>;
465
+ using WgToSgAsinhOp = WgToSgElementwiseOp<math::AsinhOp>;
466
+ using WgToSgTanOp = WgToSgElementwiseOp<math::TanOp>;
467
+ using WgToSgTanhOp = WgToSgElementwiseOp<math::TanhOp>;
468
+ using WgToSgAtanOp = WgToSgElementwiseOp<math::AtanOp>;
469
+ using WgToSgAtan2Op = WgToSgElementwiseOp<math::Atan2Op>;
470
+ using WgToSgAtanhOp = WgToSgElementwiseOp<math::AtanhOp>;
471
+ using WgToSgErfOp = WgToSgElementwiseOp<math::ErfOp>;
472
+ using WgToSgLogOp = WgToSgElementwiseOp<math::LogOp>;
473
+ using WgToSgLog2Op = WgToSgElementwiseOp<math::Log2Op>;
474
+ using WgToSgFloorOp = WgToSgElementwiseOp<math::FloorOp>;
475
+ using WgToSgCeilOp = WgToSgElementwiseOp<math::CeilOp>;
476
+ using WgToSgPowFOp = WgToSgElementwiseOp<math::PowFOp>;
477
+ using WgToSgRsqrtOp = WgToSgElementwiseOp<math::RsqrtOp>;
478
+ using WgToSgAbsIOp = WgToSgElementwiseOp<math::AbsIOp>;
479
+ using WgToSgCbrtOp = WgToSgElementwiseOp<math::CbrtOp>;
480
+ using WgToSgCopySignOp = WgToSgElementwiseOp<math::CopySignOp>;
481
+ using WgToSgCtPopOp = WgToSgElementwiseOp<math::CtPopOp>;
482
+ using WgToSgErfcOp = WgToSgElementwiseOp<math::ErfcOp>;
483
+ using WgToSgExp2Op = WgToSgElementwiseOp<math::Exp2Op>;
484
+ using WgToSgExpM1Op = WgToSgElementwiseOp<math::ExpM1Op>;
485
+ using WgToSgFPowIOp = WgToSgElementwiseOp<math::FPowIOp>;
486
+ using WgToSgIPowIOp = WgToSgElementwiseOp<math::IPowIOp>;
487
+ using WgToSgLog10Op = WgToSgElementwiseOp<math::Log10Op>;
488
+ using WgToSgLog1pOp = WgToSgElementwiseOp<math::Log1pOp>;
489
+ using WgToSgRoundOp = WgToSgElementwiseOp<math::RoundOp>;
490
+ using WgToSgRoundEvenOp = WgToSgElementwiseOp<math::RoundEvenOp>;
491
+ using WgToSgTruncOp = WgToSgElementwiseOp<math::TruncOp>;
492
+
317
493
} // namespace
318
494
319
495
namespace mlir {
@@ -322,6 +498,27 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
322
498
patterns.add <WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
323
499
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
324
500
patterns.getContext ());
501
+ // Add elementwise operations that can be distributed to subgroups
502
+ patterns.add <
503
+ WgToSgAddFOp, WgToSgSubFOp, WgToSgExpOp, WgToSgSqrtOp, WgToSgAbsFOp,
504
+ WgToSgCosOp, WgToSgCoshOp, WgToSgAcosOp, WgToSgAcoshOp, WgToSgSinOp,
505
+ WgToSgSinhOp, WgToSgAsinOp, WgToSgAsinhOp, WgToSgTanOp, WgToSgTanhOp,
506
+ WgToSgAtanOp, WgToSgAtan2Op, WgToSgAtanhOp, WgToSgErfOp, WgToSgLogOp,
507
+ WgToSgLog2Op, WgToSgFloorOp, WgToSgCeilOp, WgToSgPowFOp, WgToSgRsqrtOp,
508
+ WgToSgNegFOp, WgToSgAddIOp, WgToSgSubIOp, WgToSgMulFOp, WgToSgMulIOp,
509
+ WgToSgShLIOp, WgToSgShRSIOp, WgToSgShRUIOp, WgToSgDivFOp, WgToSgDivSIOp,
510
+ WgToSgDivUIOp, WgToSgMaximumFOp, WgToSgMinimumFOp, WgToSgRemSIOp,
511
+ WgToSgRemUIOp, WgToSgTruncFOp, WgToSgTruncIOp, WgToSgExtFOp,
512
+ WgToSgExtSIOp, WgToSgExtUIOp, WgToSgSIToFPOp, WgToSgUIToFPOp,
513
+ WgToSgFPToSIOp, WgToSgFPToUIOp, WgToSgIndexCastUIOp, WgToSgIndexCastOp,
514
+ WgToSgBitcastOp, WgToSgCmpIOp, WgToSgCmpFOp, WgToSgAndIOp,
515
+ WgToSgCeilDivSIOp, WgToSgCeilDivUIOp, WgToSgFloorDivSIOp, WgToSgMaxNumFOp,
516
+ WgToSgMaxSIOp, WgToSgMaxUIOp, WgToSgMinNumFOp, WgToSgMinSIOp,
517
+ WgToSgMinUIOp, WgToSgOrIOp, WgToSgRemFOp, WgToSgSelectOp, WgToSgXOrIOp,
518
+ WgToSgAbsIOp, WgToSgCbrtOp, WgToSgCopySignOp, WgToSgCtPopOp, WgToSgErfcOp,
519
+ WgToSgExp2Op, WgToSgExpM1Op, WgToSgFPowIOp, WgToSgIPowIOp, WgToSgLog10Op,
520
+ WgToSgLog1pOp, WgToSgRoundOp, WgToSgRoundEvenOp, WgToSgTruncOp>(
521
+ patterns.getContext ());
325
522
}
326
523
} // namespace xegpu
327
524
} // namespace mlir
@@ -368,6 +565,32 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
368
565
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr (" layout" ));
369
566
return isLegal (layout);
370
567
});
568
+ target.addDynamicallyLegalDialect <math::MathDialect, arith::ArithDialect>(
569
+ [=](Operation *op) -> std::optional<bool > {
570
+ // Handle unary and binary operations
571
+ if (op->getNumOperands () < 1 || op->getNumOperands () > 2 )
572
+ return true ;
573
+
574
+ // check if input and output are vectors
575
+ VectorType resultType =
576
+ dyn_cast<VectorType>(op->getResult (0 ).getType ());
577
+ if (!resultType || resultType.getRank () != 2 )
578
+ return true ;
579
+
580
+ // Check if all operands are vectors
581
+ for (Value operand : op->getOperands ()) {
582
+ VectorType operandType = dyn_cast<VectorType>(operand.getType ());
583
+ if (!operandType || operandType.getRank () != 2 ||
584
+ operandType.getShape () != resultType.getShape ()) {
585
+ return true ;
586
+ }
587
+ }
588
+
589
+ // check layout attribute
590
+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
591
+ op->getAttrOfType <xegpu::LayoutAttr>(" layout" ));
592
+ return isLegal (layout);
593
+ });
371
594
372
595
target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
373
596
0 commit comments