|  | 
|  | 1 | +# RUN: %PYTHON %s | FileCheck %s | 
|  | 2 | + | 
|  | 3 | +from mlir.ir import * | 
|  | 4 | +from mlir.dialects import shard | 
|  | 5 | +from mlir.dialects import func | 
|  | 6 | + | 
|  | 7 | + | 
|  | 8 | +def constructAndPrintInModule(f): | 
|  | 9 | +    print("\nTEST:", f.__name__) | 
|  | 10 | +    with Context(), Location.unknown(): | 
|  | 11 | +        module = Module.create() | 
|  | 12 | +        with InsertionPoint(module.body): | 
|  | 13 | +            f() | 
|  | 14 | +        print(module) | 
|  | 15 | +    module.operation.verify() | 
|  | 16 | +    return f | 
|  | 17 | + | 
|  | 18 | + | 
|  | 19 | +# CHECK-LABEL: TEST: testShardGrid | 
|  | 20 | +@constructAndPrintInModule | 
|  | 21 | +def testShardGrid(): | 
|  | 22 | +    # Test creating shard grids with different shapes | 
|  | 23 | +    grid2d = shard.GridOp("grid_2d", [2, 2]) | 
|  | 24 | +    grid1d = shard.GridOp("grid_1d", [4]) | 
|  | 25 | + | 
|  | 26 | +    # CHECK: shard.grid @grid_2d(shape = 2x2) | 
|  | 27 | +    # CHECK: shard.grid @grid_1d(shape = 4) | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +# CHECK-LABEL: TEST: testCollectiveOperations | 
|  | 31 | +@constructAndPrintInModule | 
|  | 32 | +def testCollectiveOperations(): | 
|  | 33 | +    # Create grid and types | 
|  | 34 | +    grid_op = shard.GridOp("grid_2x2", [2, 2]) | 
|  | 35 | +    i32 = IntegerType.get_signless(32) | 
|  | 36 | +    index_type = IndexType.get() | 
|  | 37 | +    input_type = RankedTensorType.get([4, 2], i32) | 
|  | 38 | +    gather_result_type = RankedTensorType.get([4, 4], i32) | 
|  | 39 | + | 
|  | 40 | +    # Create a function to hold the operations | 
|  | 41 | +    func_type = FunctionType.get([input_type], [input_type]) | 
|  | 42 | +    test_func = func.FuncOp("test_collectives", func_type) | 
|  | 43 | + | 
|  | 44 | +    with InsertionPoint(test_func.add_entry_block()): | 
|  | 45 | +        arg = test_func.entry_block.arguments[0] | 
|  | 46 | + | 
|  | 47 | +        gather_op = shard.AllGatherOp( | 
|  | 48 | +            input=arg, | 
|  | 49 | +            grid=FlatSymbolRefAttr.get("grid_2x2"), | 
|  | 50 | +            grid_axes=DenseI16ArrayAttr.get([1]), | 
|  | 51 | +            gather_axis=IntegerAttr.get(index_type, 1), | 
|  | 52 | +            result=gather_result_type, | 
|  | 53 | +        ) | 
|  | 54 | + | 
|  | 55 | +        reduce_op = shard.AllReduceOp( | 
|  | 56 | +            input=arg, | 
|  | 57 | +            grid=FlatSymbolRefAttr.get("grid_2x2"), | 
|  | 58 | +            reduction=shard.ReductionKind.Sum, | 
|  | 59 | +            result=input_type, | 
|  | 60 | +        ) | 
|  | 61 | + | 
|  | 62 | +        func.ReturnOp([reduce_op]) | 
|  | 63 | + | 
|  | 64 | +    # CHECK: shard.grid @grid_2x2(shape = 2x2) | 
|  | 65 | +    # CHECK: func.func @test_collectives(%arg0: tensor<4x2xi32>) -> tensor<4x2xi32> | 
|  | 66 | +    # CHECK: %all_gather = shard.all_gather %arg0 on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32> | 
|  | 67 | +    # CHECK: %all_reduce = shard.all_reduce %arg0 on @grid_2x2 : tensor<4x2xi32> -> tensor<4x2xi32> | 
0 commit comments