[mlir][mesh, mpi] More on MeshToMPI #129048
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

Changes

Patch is 79.43 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/129048.diff

10 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..6074e0e8d822c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
let description = [{
This pass converts communication operations from the Mesh dialect to the
MPI dialect.
- If it finds a global named "static_mpi_rank" it will use that splat value
- instead of calling MPI_Comm_rank. This allows optimizations like constant
- shape propagation and fusion because shard/partition sizes depend on the
- rank.
+ If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
+ use that integer value instead of calling MPI_Comm_rank. This allows
+ optimizations like constant shape propagation and fusion because
+ shard/partition sizes depend on the rank.
}];
let dependentDialects = [
"memref::MemRefDialect",
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 031e6f63bcb42..f59c4c4c67517 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
}];
}
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
- let summary = "Get the shard shape of a given process/device.";
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+ Pure, AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Get the shard shape for a given process/device.";
let description = [{
- The device/process id is a linearized id of the device/process in the mesh.
+ The device/process id is a multi-index of the device/process in the mesh.
This operation might be used during spmdization when the shard shape depends
on (non-constant) values used in `mesh.sharding`.
}];
let arguments = (ins
- DenseI64ArrayAttr:$shape,
+ DenseI64ArrayAttr:$dims,
+ Variadic<Index>:$dims_dynamic,
Mesh_Sharding:$sharding,
- Index:$device
+ DenseI64ArrayAttr:$device,
+ Variadic<Index>:$device_dynamic
);
let results = (outs Variadic<Index>:$result);
let assemblyFormat = [{
- custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
+ `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
+ `sharding` `=` $sharding
+ `device` `=` custom<DynamicIndexList>($device_dynamic, $device)
+ attr-dict `:` type(results)
}];
let builders = [
- OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+ OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
];
}
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
index 95815a683f6d6..15560aa61e145 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
Core
LINK_LIBS PUBLIC
+ MLIRDLTIDialect
MLIRFuncDialect
MLIRIR
MLIRLinalgTransforms
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 48b3764d520c2..84db6d456711c 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,8 +14,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,6 +30,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "mesh-to-mpi"
@@ -36,10 +42,32 @@ namespace mlir {
} // namespace mlir
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mesh;
namespace {
-// Create operations converting a linear index to a multi-dimensional index
+/// Convert vec of OpFoldResults (ints) into vector of Values.
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+ llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics,
+ Type type = Type()) {
+ SmallVector<Value> values;
+ auto dyn = dynamics.begin();
+ Type i64 = b.getI64Type();
+ if (!type)
+ type = i64;
+ assert(i64 == type || b.getIndexType() == type);
+ for (auto s : statics) {
+ values.emplace_back(
+ ShapedType::isDynamic(s)
+ ? *(dyn++)
+ : b.create<arith::ConstantOp>(loc, type,
+ i64 == type ? b.getI64IntegerAttr(s)
+ : b.getIndexAttr(s)));
+ }
+ return values;
+};
+
+/// Create operations converting a linear index to a multi-dimensional index.
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
Value linearIndex,
ValueRange dimensions) {
@@ -48,23 +76,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
for (int i = n - 1; i >= 0; --i) {
multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
- if (i > 0) {
+ if (i > 0)
linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
- }
}
return multiIndex;
}
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
ValueRange dimensions) {
- auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
- auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+ Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
for (int i = multiIndex.size() - 1; i >= 0; --i) {
- auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+ Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
}
@@ -72,35 +99,179 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
return linearIndex;
}
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+ if (!shardOp)
+ return failure();
+ auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+ if (!shardingOp)
+ return failure();
+
+ rewriter.replaceOp(op, shardingOp.getResult());
+ return success();
+ }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto splitAxes = op.getSplitAxes().getAxes();
+ int64_t maxNAxes = 0;
+ for (auto axes : splitAxes) {
+ maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+ }
+
+ // To hold the split axes, create empty 2d tensor with shape
+ // {splitAxes.size(), max-size-of-split-groups}.
+ // Set trailing elements for smaller split-groups to -1.
+ Location loc = op.getLoc();
+ auto i16 = rewriter.getI16Type();
+ auto i64 = rewriter.getI64Type();
+ int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+ Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+ auto attr = IntegerAttr::get(i16, 0xffff);
+ Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+ resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+ .getResult(0);
+
+ // explicitly write values into tensor row by row
+ int64_t strides[] = {1, 1};
+ int64_t nSplits = 0;
+ ValueRange empty = {};
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t size = axes.size();
+ if (size > 0)
+ ++nSplits;
+ int64_t offs[] = {(int64_t)i, 0};
+ int64_t sizes[] = {1, size};
+ auto tensorType = RankedTensorType::get({size}, i16);
+ auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+ auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+ resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+ loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+ }
+
+ // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
+ // Store the halo sizes in the tensor.
+ auto haloSizes =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+ adaptor.getDynamicHaloSizes());
+ auto type = RankedTensorType::get({nSplits, 2}, i64);
+ Value resHaloSizes =
+ haloSizes.empty()
+ ? rewriter
+ .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+ i64)
+ .getResult()
+ : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+ .getResult();
+
+ // To hold sharded dims offsets, create Tensor with shape {nSplits,
+ // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+ // elements for smaller split-groups to -1. Computing the max size of the
+ // split groups needs using collectiveProcessGroupSize (which needs the
+ // MeshOp)
+ Value resOffsets;
+ if (adaptor.getStaticShardedDimsOffsets().empty()) {
+ resOffsets = rewriter.create<tensor::EmptyOp>(
+ loc, std::array<int64_t, 2>{0, 0}, i64);
+ } else {
+ SymbolTableCollection symbolTableCollection;
+ auto meshOp = getMesh(op, symbolTableCollection);
+ auto maxSplitSize = 0;
+ for (auto axes : splitAxes) {
+ int64_t splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ assert(splitSize != ShapedType::kDynamic);
+ maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+ }
+ assert(maxSplitSize);
+ ++maxSplitSize; // add one for the total size
+
+ resOffsets = rewriter.create<tensor::EmptyOp>(
+ loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+ resOffsets =
+ rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+ auto offsets =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+ adaptor.getDynamicShardedDimsOffsets());
+ int64_t curr = 0;
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+ ++splitSize; // add one for the total size
+ ArrayRef<Value> values(&offsets[curr], splitSize);
+ Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+ int64_t offs[] = {(int64_t)i, 0};
+ int64_t sizes[] = {1, splitSize};
+ resOffsets = rewriter.create<tensor::InsertSliceOp>(
+ loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+ curr += splitSize;
+ }
+ }
+
+ // return a tuple of tensors as defined by type converter
+ SmallVector<Type> resTypes;
+ if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+ resTypes)))
+ return failure();
+
+ resSplitAxes =
+ rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+ resHaloSizes =
+ rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
+ resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, TupleType::get(op.getContext(), resTypes),
+ ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+ return success();
+ }
+};
+
struct ConvertProcessMultiIndexOp
- : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public OpConversionPattern<ProcessMultiIndexOp> {
+ using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
- mlir::PatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// Currently converts its linear index to a multi-dimensional index.
SymbolTableCollection symbolTableCollection;
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
auto meshOp = getMesh(op, symbolTableCollection);
// For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape())) {
- return mlir::failure();
- }
+ if (ShapedType::isDynamicShape(meshOp.getShape()))
+ return failure();
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
});
- auto rank =
- rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+ Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
// optionally extract subset of mesh axes
- auto axes = op.getAxes();
+ auto axes = adaptor.getAxes();
if (!axes.empty()) {
SmallVector<Value> subIndex;
for (auto axis : axes) {
@@ -110,32 +281,33 @@ struct ConvertProcessMultiIndexOp
}
rewriter.replaceOp(op, mIdx);
- return mlir::success();
+ return success();
}
};
-struct ConvertProcessLinearIndexOp
- : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
- using OpRewritePattern::OpRewritePattern;
-
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
- mlir::PatternRewriter &rewriter) const override {
-
- // Finds a global named "static_mpi_rank" it will use that splat value.
- // Otherwise it defaults to mpi.comm_rank.
-
- auto loc = op.getLoc();
- auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
- if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
- op, rankOpName)) {
- if (auto initTnsr = globalOp.getInitialValueAttr()) {
- auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
- rewriter.replaceOp(op,
- rewriter.create<arith::ConstantIndexOp>(loc, val));
- return mlir::success();
- }
+class ConvertProcessLinearIndexOp
+ : public OpConversionPattern<ProcessLinearIndexOp> {
+ int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
+
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ // Constructor accepting worldRank
+ ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
+ MLIRContext *context, int64_t worldRank_ = -1)
+ : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+
+ LogicalResult
+ matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
+ return success();
}
+
+ // Otherwise call create mpi::CommRankOp
auto rank =
rewriter
.create<mpi::CommRankOp>(
@@ -144,44 +316,43 @@ struct ConvertProcessLinearIndexOp
.getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
- return mlir::success();
+ return success();
}
};
struct ConvertNeighborsLinearIndicesOp
- : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public OpConversionPattern<NeighborsLinearIndicesOp> {
+ using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
- mlir::PatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// Computes the neighbors indices along a split axis by simply
// adding/subtracting 1 to the current index in that dimension.
// Assigns -1 if neighbor is out of bounds.
- auto axes = op.getSplitAxes();
+ auto axes = adaptor.getSplitAxes();
// For now only single axis sharding is supported
- if (axes.size() != 1) {
- return mlir::failure();
- }
+ if (axes.size() != 1)
+ return failure();
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(op, symbolTableCollection);
- auto mIdx = op.getDevice();
+ auto mIdx = adaptor.getDevice();
auto orgIdx = mIdx[axes[0]];
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
});
- auto dimSz = dims[axes[0]];
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
- auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
- auto atBorder = rewriter.create<arith::CmpIOp>(
+ Value dimSz = dims[axes[0]];
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
+ Value atBorder = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, orgIdx,
- rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
auto down = rewriter.create<scf::IfOp>(
loc, atBorder,
[&](OpBuilder &builder, Location loc) {
@@ -206,23 +377,161 @@ struct ConvertNeighborsLinearIndicesOp
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
- .getResult();
+ rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
builder.create<scf::YieldOp>(
loc, multiToLinearIndex(loc, rewriter, tmp, dims));
});
rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
- return mlir::success();
+ return success();
}
};
-...
[truncated]
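For reference, a rough sketch of how the updated `mesh.shard_shape` assembly reads with the new mixed static/dynamic `dims` and `device` operand lists introduced in the `assemblyFormat` above. The mesh name, shapes, and values are illustrative only and not taken from the patch or its tests.

```mlir
mesh.mesh @mesh0(shape = 4)

func.func @shard_shape_example() -> (index, index) {
  // Sharding that splits tensor dimension 0 across mesh axis 0.
  %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
  // Multi-index of this process in the (1-d) mesh.
  %dev = mesh.process_multi_index on @mesh0 : index
  // New form: `dims` and `device` take mixed static/dynamic index lists.
  %shp:2 = mesh.shard_shape dims = [16, 16] sharding = %sharding device = [%dev] : index, index
  return %shp#0, %shp#1 : index, index
}
```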
@tkarna FYI
Dropped a bunch of nit comments, but I'm not qualified to perform an in-depth functional review of this.
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from db651b9 to 0696d7b.
Thanks! I followed all your suggestions (except one).
Looks good to me. Only two minor remarks.
Co-authored-by: Christian Ulmann <christianulmann@gmail.com>
Thanks @tkarna for your review!
Really cool work! Thanks @fschlimb for the contribution! Looks good from an MPI side imho!
My mistake, it's probably from #124713
- do not create MPI operations if no halo exchange is needed
- allow returning sharding information through `!mesh.sharding` (gets converted into a tuple of tensors)
- lowering of `mesh.shard_shape`, including fixes to the operation itself
- global symbol `static_mpi_rank` replaced by a DLTI attribute (now aligned with MPIToLLVM); see the sketch below
- smaller fixes and some minor cleanup

This is the outcome of my downstream project, which can now run (not just compile to LLVM!) the pipeline end-to-end, going mesh->MPI->LLVM. I grouped the changes into commits that make reviewing easier (when going commit by commit).

Co-authored-by: Christian Ulmann <christianulmann@gmail.com>
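For illustration, the static rank might be supplied as a module-level DLTI entry roughly as sketched below. This is not taken from the patch or its tests: the `dlti.map` attribute name, the integer value, and the mesh setup are assumptions; the key spelling follows the pass description above.

```mlir
module attributes {dlti.map = #dlti.map<"MPI:comm_world-rank" = 1 : i64>} {
  mesh.mesh @mesh0(shape = 4)
  func.func @rank() -> index {
    // With the DLTI entry present, the conversion is expected to fold this
    // to a constant instead of emitting mpi.comm_rank.
    %r = mesh.process_linear_index on @mesh0 : index
    return %r : index
  }
}
```

Running `mlir-opt --convert-mesh-to-mpi` over such a module should then leave the rank as a constant, enabling the constant shape propagation and fusion mentioned in the pass description.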