Skip to content

Commit 5129b37

Browse files
authored
[MLIR][Python] Add shard Dialect Python Bindings (#162578)
Add Python bindings for `shard` dialect. Provide means for creating constructs in this dialect in Python.
1 parent 073335d commit 5129b37

File tree

5 files changed

+128
-0
lines changed

5 files changed

+128
-0
lines changed

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,15 @@ declare_mlir_dialect_python_bindings(
346346
dialects/memref.py
347347
DIALECT_NAME memref)
348348

349+
declare_mlir_dialect_python_bindings(
350+
ADD_TO_PARENT MLIRPythonSources.Dialects
351+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
352+
TD_FILE dialects/ShardOps.td
353+
SOURCES
354+
dialects/shard.py
355+
DIALECT_NAME shard
356+
GEN_ENUM_BINDINGS)
357+
349358
declare_mlir_dialect_python_bindings(
350359
ADD_TO_PARENT MLIRPythonSources.Dialects
351360
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- ShardOps.td - Entry point for ShardOps bindings ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_SHARD_OPS
10+
#define PYTHON_BINDINGS_SHARD_OPS
11+
12+
include "mlir/Dialect/Shard/IR/ShardOps.td"
13+
14+
#endif

mlir/python/mlir/dialects/shard.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._shard_ops_gen import *
6+
from ._shard_enum_gen import *

mlir/test/python/dialects/shard.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import shard
5+
from mlir.dialects import func
6+
7+
8+
def constructAndPrintInModule(f):
9+
print("\nTEST:", f.__name__)
10+
with Context(), Location.unknown():
11+
module = Module.create()
12+
with InsertionPoint(module.body):
13+
f()
14+
print(module)
15+
module.operation.verify()
16+
return f
17+
18+
19+
# CHECK-LABEL: TEST: testShardGrid
20+
@constructAndPrintInModule
21+
def testShardGrid():
22+
# Test creating shard grids with different shapes
23+
grid2d = shard.GridOp("grid_2d", [2, 2])
24+
grid1d = shard.GridOp("grid_1d", [4])
25+
26+
# CHECK: shard.grid @grid_2d(shape = 2x2)
27+
# CHECK: shard.grid @grid_1d(shape = 4)
28+
29+
30+
# CHECK-LABEL: TEST: testCollectiveOperations
31+
@constructAndPrintInModule
32+
def testCollectiveOperations():
33+
# Create grid and types
34+
grid_op = shard.GridOp("grid_2x2", [2, 2])
35+
i32 = IntegerType.get_signless(32)
36+
index_type = IndexType.get()
37+
input_type = RankedTensorType.get([4, 2], i32)
38+
gather_result_type = RankedTensorType.get([4, 4], i32)
39+
40+
# Create a function to hold the operations
41+
func_type = FunctionType.get([input_type], [input_type])
42+
test_func = func.FuncOp("test_collectives", func_type)
43+
44+
with InsertionPoint(test_func.add_entry_block()):
45+
arg = test_func.entry_block.arguments[0]
46+
47+
gather_op = shard.AllGatherOp(
48+
input=arg,
49+
grid=FlatSymbolRefAttr.get("grid_2x2"),
50+
grid_axes=DenseI16ArrayAttr.get([1]),
51+
gather_axis=IntegerAttr.get(index_type, 1),
52+
result=gather_result_type,
53+
)
54+
55+
reduce_op = shard.AllReduceOp(
56+
input=arg,
57+
grid=FlatSymbolRefAttr.get("grid_2x2"),
58+
reduction=shard.ReductionKind.Sum,
59+
result=input_type,
60+
)
61+
62+
func.ReturnOp([reduce_op])
63+
64+
# CHECK: shard.grid @grid_2x2(shape = 2x2)
65+
# CHECK: func.func @test_collectives(%arg0: tensor<4x2xi32>) -> tensor<4x2xi32>
66+
# CHECK: %all_gather = shard.all_gather %arg0 on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
67+
# CHECK: %all_reduce = shard.all_reduce %arg0 on @grid_2x2 : tensor<4x2xi32> -> tensor<4x2xi32>

utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,38 @@ filegroup(
10091009
],
10101010
)
10111011

1012+
##---------------------------------------------------------------------------##
1013+
# Shard dialect.
1014+
##---------------------------------------------------------------------------##
1015+
1016+
gentbl_filegroup(
1017+
name = "ShardOpsPyGen",
1018+
tbl_outs = {
1019+
"mlir/dialects/_shard_enum_gen.py": [
1020+
"-gen-python-enum-bindings",
1021+
"-bind-dialect=shard",
1022+
],
1023+
"mlir/dialects/_shard_ops_gen.py": [
1024+
"-gen-python-op-bindings",
1025+
"-bind-dialect=shard",
1026+
],
1027+
},
1028+
tblgen = "//mlir:mlir-tblgen",
1029+
td_file = "mlir/dialects/ShardOps.td",
1030+
deps = [
1031+
"//mlir:OpBaseTdFiles",
1032+
"//mlir:ShardTdFiles",
1033+
],
1034+
)
1035+
1036+
filegroup(
1037+
name = "ShardOpsPyFiles",
1038+
srcs = [
1039+
"mlir/dialects/shard.py",
1040+
":ShardOpsPyGen",
1041+
],
1042+
)
1043+
10121044
##---------------------------------------------------------------------------##
10131045
# Shape dialect.
10141046
##---------------------------------------------------------------------------##

0 commit comments

Comments
 (0)