[mlir][mesh, mpi] More on MeshToMPI #129048


Merged: 7 commits into llvm:main on Mar 6, 2025
Conversation

fschlimb (Contributor) commented on Feb 27, 2025

  • do not create MPI operations if no halo exchange is needed on a given boundary
  • allow returning sharding information by returning `!mesh.sharding` (it gets converted into a tuple of tensors)
  • lower `mesh.shard_shape`, including fixes to the operation itself
  • replace the global symbol `static_mpi_rank` with a DLTI attribute (now aligned with MPIToLLVM); a sketch of the new mechanism follows below
  • smaller fixes and some minor cleanup

This is the outcome of my downstream project, which can now run (not just compile to LLVM) the pipeline end-to-end, going mesh->MPI->LLVM. I grouped the changes into commits that make reviewing easier when going commit by commit.
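For illustration, here is a minimal sketch of the new DLTI mechanism. The module structure and the exact `#dlti` entry syntax are my assumptions (the pass tests in this patch are authoritative); the key name "MPI:comm_world-rank" is taken from the updated pass description.

```mlir
// Hedged sketch, not verbatim from the patch: pinning the MPI rank on the
// module via DLTI. With this attribute present, -convert-mesh-to-mpi folds
// mesh.process_linear_index into a constant instead of emitting mpi.comm_rank,
// which enables constant shape propagation and fusion downstream.
module attributes {dlti.map = #dlti.map<"MPI:comm_world-rank" = 1 : i32>} {
  mesh.mesh @mesh4(shape = 4)
  func.func @my_rank() -> index {
    // Expected to lower to `arith.constant 1 : index` under this pass.
    %r = mesh.process_linear_index on @mesh4 : index
    return %r : index
  }
}
```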

llvmbot (Member) commented on Feb 27, 2025

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes
  • do not create MPI operations if no halo exchange is needed on a given boundary
  • allow returning sharding information by returning `!mesh.sharding` (it gets converted into a tuple of tensors)
  • lower `mesh.shard_shape`, including fixes to the operation itself
  • replace the global symbol `static_mpi_rank` with a DLTI attribute (now aligned with MPIToLLVM)
  • smaller fixes and some minor cleanup

Patch is 79.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129048.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+4-4)
  • (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+15-7)
  • (modified) mlir/lib/Conversion/MeshToMPI/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+510-102)
  • (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+11-4)
  • (modified) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+10-9)
  • (modified) mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir (+211-116)
  • (added) mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir (+75)
  • (modified) mlir/test/Dialect/Mesh/ops.mlir (+6-4)
  • (modified) mlir/test/Dialect/Tensor/mesh-spmdization.mlir (+7-4)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..6074e0e8d822c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
   let description = [{
     This pass converts communication operations from the Mesh dialect to the
     MPI dialect.
-    If it finds a global named "static_mpi_rank" it will use that splat value
-    instead of calling MPI_Comm_rank. This allows optimizations like constant
-    shape propagation and fusion because shard/partition sizes depend on the
-    rank.
+    If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
+    use that integer value instead of calling MPI_Comm_rank. This allows
+    optimizations like constant shape propagation and fusion because
+    shard/partition sizes depend on the rank.
   }];
   let dependentDialects = [
     "memref::MemRefDialect",
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 031e6f63bcb42..f59c4c4c67517 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
   }];
 }
 
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
-  let summary = "Get the shard shape of a given process/device.";
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+    Pure, AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Get the shard shape for a given process/device.";
   let description = [{
-    The device/process id is a linearized id of the device/process in the mesh.
+    The device/process id is a multi-index of the device/process in the mesh.
     This operation might be used during spmdization when the shard shape depends
     on (non-constant) values used in `mesh.sharding`.
   }];
   let arguments = (ins
-    DenseI64ArrayAttr:$shape,
+    DenseI64ArrayAttr:$dims,
+    Variadic<Index>:$dims_dynamic,
     Mesh_Sharding:$sharding,
-    Index:$device
+    DenseI64ArrayAttr:$device,
+    Variadic<Index>:$device_dynamic
   );
   let results = (outs Variadic<Index>:$result);
   let assemblyFormat = [{
-      custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
+      `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
+      `sharding` `=` $sharding
+      `device` `=` custom<DynamicIndexList>($device_dynamic, $device)
+      attr-dict `:` type(results)
   }];
   let builders = [
-    OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+    OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
   ];
 }
 
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
index 95815a683f6d6..15560aa61e145 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
   Core
 
   LINK_LIBS PUBLIC
+  MLIRDLTIDialect
   MLIRFuncDialect
   MLIRIR
   MLIRLinalgTransforms
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 48b3764d520c2..84db6d456711c 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,8 +14,13 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,6 +30,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "mesh-to-mpi"
@@ -36,10 +42,32 @@ namespace mlir {
 } // namespace mlir
 
 using namespace mlir;
-using namespace mlir::mesh;
+using namespace mesh;
 
 namespace {
-// Create operations converting a linear index to a multi-dimensional index
+/// Convert vec of OpFoldResults (ints) into vector of Values.
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+                                           llvm::ArrayRef<int64_t> statics,
+                                           ValueRange dynamics,
+                                           Type type = Type()) {
+  SmallVector<Value> values;
+  auto dyn = dynamics.begin();
+  Type i64 = b.getI64Type();
+  if (!type)
+    type = i64;
+  assert(i64 == type || b.getIndexType() == type);
+  for (auto s : statics) {
+    values.emplace_back(
+        ShapedType::isDynamic(s)
+            ? *(dyn++)
+            : b.create<arith::ConstantOp>(loc, type,
+                                          i64 == type ? b.getI64IntegerAttr(s)
+                                                      : b.getIndexAttr(s)));
+  }
+  return values;
+};
+
+/// Create operations converting a linear index to a multi-dimensional index.
 static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                              Value linearIndex,
                                              ValueRange dimensions) {
@@ -48,23 +76,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
 
   for (int i = n - 1; i >= 0; --i) {
     multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
-    if (i > 0) {
+    if (i > 0)
       linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
-    }
   }
 
   return multiIndex;
 }
 
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
 Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                          ValueRange dimensions) {
 
-  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
-  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
 
   for (int i = multiIndex.size() - 1; i >= 0; --i) {
-    auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
     linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
     stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
   }
@@ -72,35 +99,179 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
   return linearIndex;
 }
 
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+    if (!shardOp)
+      return failure();
+    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+    if (!shardingOp)
+      return failure();
+
+    rewriter.replaceOp(op, shardingOp.getResult());
+    return success();
+  }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+///   (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto splitAxes = op.getSplitAxes().getAxes();
+    int64_t maxNAxes = 0;
+    for (auto axes : splitAxes) {
+      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+    }
+
+    // To hold the split axes, create empty 2d tensor with shape
+    // {splitAxes.size(), max-size-of-split-groups}.
+    // Set trailing elements for smaller split-groups to -1.
+    Location loc = op.getLoc();
+    auto i16 = rewriter.getI16Type();
+    auto i64 = rewriter.getI64Type();
+    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+    auto attr = IntegerAttr::get(i16, 0xffff);
+    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+                       .getResult(0);
+
+    // explicitly write values into tensor row by row
+    int64_t strides[] = {1, 1};
+    int64_t nSplits = 0;
+    ValueRange empty = {};
+    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+      int64_t size = axes.size();
+      if (size > 0)
+        ++nSplits;
+      int64_t offs[] = {(int64_t)i, 0};
+      int64_t sizes[] = {1, size};
+      auto tensorType = RankedTensorType::get({size}, i16);
+      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+    }
+
+    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
+    // Store the halo sizes in the tensor.
+    auto haloSizes =
+        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+                         adaptor.getDynamicHaloSizes());
+    auto type = RankedTensorType::get({nSplits, 2}, i64);
+    Value resHaloSizes =
+        haloSizes.empty()
+            ? rewriter
+                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+                                           i64)
+                  .getResult()
+            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+                  .getResult();
+
+    // To hold sharded dims offsets, create Tensor with shape {nSplits,
+    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+    // elements for smaller split-groups to -1. Computing the max size of the
+    // split groups needs using collectiveProcessGroupSize (which needs the
+    // MeshOp)
+    Value resOffsets;
+    if (adaptor.getStaticShardedDimsOffsets().empty()) {
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{0, 0}, i64);
+    } else {
+      SymbolTableCollection symbolTableCollection;
+      auto meshOp = getMesh(op, symbolTableCollection);
+      auto maxSplitSize = 0;
+      for (auto axes : splitAxes) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic);
+        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+      }
+      assert(maxSplitSize);
+      ++maxSplitSize; // add one for the total size
+
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+      Value zero = rewriter.create<arith::ConstantOp>(
+          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+      resOffsets =
+          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+      auto offsets =
+          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+                           adaptor.getDynamicShardedDimsOffsets());
+      int64_t curr = 0;
+      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+        ++splitSize; // add one for the total size
+        ArrayRef<Value> values(&offsets[curr], splitSize);
+        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+        int64_t offs[] = {(int64_t)i, 0};
+        int64_t sizes[] = {1, splitSize};
+        resOffsets = rewriter.create<tensor::InsertSliceOp>(
+            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+        curr += splitSize;
+      }
+    }
+
+    // return a tuple of tensors as defined by type converter
+    SmallVector<Type> resTypes;
+    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+                                               resTypes)))
+      return failure();
+
+    resSplitAxes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+    resHaloSizes =
+        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
+    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TupleType::get(op.getContext(), resTypes),
+        ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+    return success();
+  }
+};
+
 struct ConvertProcessMultiIndexOp
-    : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public OpConversionPattern<ProcessMultiIndexOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Currently converts its linear index to a multi-dimensional index.
 
     SymbolTableCollection symbolTableCollection;
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     auto meshOp = getMesh(op, symbolTableCollection);
     // For now we only support static mesh shapes
-    if (ShapedType::isDynamicShape(meshOp.getShape())) {
-      return mlir::failure();
-    }
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return failure();
 
     SmallVector<Value> dims;
     llvm::transform(
         meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
-    auto rank =
-        rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
     auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
 
     // optionally extract subset of mesh axes
-    auto axes = op.getAxes();
+    auto axes = adaptor.getAxes();
     if (!axes.empty()) {
       SmallVector<Value> subIndex;
       for (auto axis : axes) {
@@ -110,32 +281,33 @@ struct ConvertProcessMultiIndexOp
     }
 
     rewriter.replaceOp(op, mIdx);
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertProcessLinearIndexOp
-    : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-
-    // Finds a global named "static_mpi_rank" it will use that splat value.
-    // Otherwise it defaults to mpi.comm_rank.
-
-    auto loc = op.getLoc();
-    auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
-    if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
-            op, rankOpName)) {
-      if (auto initTnsr = globalOp.getInitialValueAttr()) {
-        auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
-        rewriter.replaceOp(op,
-                           rewriter.create<arith::ConstantIndexOp>(loc, val));
-        return mlir::success();
-      }
+class ConvertProcessLinearIndexOp
+    : public OpConversionPattern<ProcessLinearIndexOp> {
+  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  // Constructor accepting worldRank
+  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
+                              MLIRContext *context, int64_t worldRank_ = -1)
+      : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+
+  LogicalResult
+  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
+      return success();
     }
+
+    // Otherwise call create mpi::CommRankOp
     auto rank =
         rewriter
             .create<mpi::CommRankOp>(
@@ -144,44 +316,43 @@ struct ConvertProcessLinearIndexOp
             .getRank();
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                     rank);
-    return mlir::success();
+    return success();
   }
 };
 
 struct ConvertNeighborsLinearIndicesOp
-    : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public OpConversionPattern<NeighborsLinearIndicesOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
     // Computes the neighbors indices along a split axis by simply
     // adding/subtracting 1 to the current index in that dimension.
     // Assigns -1 if neighbor is out of bounds.
 
-    auto axes = op.getSplitAxes();
+    auto axes = adaptor.getSplitAxes();
     // For now only single axis sharding is supported
-    if (axes.size() != 1) {
-      return mlir::failure();
-    }
+    if (axes.size() != 1)
+      return failure();
 
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     SymbolTableCollection symbolTableCollection;
     auto meshOp = getMesh(op, symbolTableCollection);
-    auto mIdx = op.getDevice();
+    auto mIdx = adaptor.getDevice();
     auto orgIdx = mIdx[axes[0]];
     SmallVector<Value> dims;
     llvm::transform(
         meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
           return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
         });
-    auto dimSz = dims[axes[0]];
-    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
-    auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
-    auto atBorder = rewriter.create<arith::CmpIOp>(
+    Value dimSz = dims[axes[0]];
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
+    Value atBorder = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::sle, orgIdx,
-        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
     auto down = rewriter.create<scf::IfOp>(
         loc, atBorder,
         [&](OpBuilder &builder, Location loc) {
@@ -206,23 +377,161 @@ struct ConvertNeighborsLinearIndicesOp
         [&](OpBuilder &builder, Location loc) {
           SmallVector<Value> tmp = mIdx;
           tmp[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
-                  .getResult();
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
           builder.create<scf::YieldOp>(
               loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
-    return mlir::success();
+    return success();
   }
 };
 
-...
[truncated]
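To make the `Mesh_ShardShapeOp` changes above concrete, here is a hedged sketch of the new assembly format (the mesh, sharding setup, and operand names are my assumptions, not taken from the truncated patch). Note that per `ConvertShardingOp`, a `!mesh.sharding` value that escapes (e.g. is returned) gets converted into a tuple of three tensors: split axes, halo sizes, and sharded-dims offsets.

```mlir
// Hedged sketch of the revised mesh.shard_shape syntax (surrounding ops assumed).
mesh.mesh @mesh4x4(shape = 4x4)
func.func @shard_shp(%rank0: index, %rank1: index) -> (index, index) {
  %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
  // `dims` and `device` are now mixed static/dynamic index lists: static
  // entries print inline, a dynamic one would appear as an SSA value
  // (e.g. dims = [%d, 1024]). `device` is a multi-index, no longer linear.
  %0:2 = mesh.shard_shape dims = [1024, 1024] sharding = %sharding
           device = [%rank0, %rank1] : index, index
  return %0#0, %0#1 : index, index
}
```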

fschlimb changed the title from "More on MeshToMPI" to "[mlir][mesh, mpi] More on MeshToMPI" on Feb 27, 2025
fschlimb (Contributor, Author) commented:

@tkarna FYI

Dinistro (Contributor) left a comment:

Dropped a bunch of nit comments, but I'm not qualified to perform an in-depth functional review of this.

github-actions bot commented on Feb 28, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

fschlimb force-pushed the meshtompi branch 2 times, most recently from db651b9 to 0696d7b, on February 28, 2025
fschlimb (Contributor, Author) replied:

> Dropped a bunch of nit comments, but I'm not qualified to perform an in-depth functional review of this.

Thanks! I followed all your suggestions (except one).

tkarna (Contributor) left a comment:

Looks good to me. Only two minor remarks.

Co-authored-by: Christian Ulmann <christianulmann@gmail.com>
fschlimb (Contributor, Author) commented:

thanks @tkarna for your review!

AntonLydike (Contributor) left a comment:

Really cool work! Thanks @fschlimb for the contribution! Looks good from an MPI side imho!

fschlimb merged commit 6e57326 into llvm:main on Mar 6, 2025
11 checks passed
vitalybuka (Collaborator) commented on Mar 6, 2025

There is a heap-use-after-free https://lab.llvm.org/buildbot/#/builders/169/builds/9129/steps/11/logs/stdio

My mistake, it's probably from #124713

jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
- do not create MPI operations if no halo exchange is needed
- allow returning sharding information through `!mesh.sharding`
  (gets converted into a tuple of tensors)
- lowering `mesh.shard_shape` including fixes to the operation itself
- global symbol `static_mpi_rank` replaced by a DLTI attribute
  (now aligned with MPIToLLVM)
- smaller fixes and some minor cleanup

---------

Co-authored-by: Christian Ulmann <christianulmann@gmail.com>