Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interpreter for AllGatherOp #1727

Merged
merged 3 commits into from
Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -742,19 +742,21 @@ Afterwards, within each `process_group`:
```mlir
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1.0, 2.0], [3.0, 4.0]]
// %operand@(1, 0): [[5.0, 6.0], [7.0, 8.0]]
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xf32>) -> tensor<2x4xf32>
// %result@(0, 0): [[1.0, 2.0, 5.0, 6.0], [3.0, 4.0, 7.0, 8.0]]
// %result@(1, 0): [[1.0, 2.0, 5.0, 6.0], [3.0, 4.0, 7.0, 8.0]]
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
```

&nbsp;[More Examples](../stablehlo/tests/interpret_all_gather.mlir)

### all_reduce

#### Semantics
Expand Down
2 changes: 1 addition & 1 deletion docs/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ one of the following tracking labels.
| abs | yes | yes | yes | yes | yes |
| add | yes | yes | yes | yes | yes |
| after_all | yes | yes | yes | yes | yes |
| all_gather | yes | revisit | no | no | no |
| all_gather | yes | revisit | no | no | yes |
| all_reduce | yes | revisit | yes | no | yes |
| all_to_all | yes | revisit | yes | no | no |
| and | yes | yes | yes | yes | yes |
Expand Down
8 changes: 6 additions & 2 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -877,9 +877,13 @@ void AllToAllOp::build(OpBuilder& odsBuilder, OperationState& odsState,
//===----------------------------------------------------------------------===//

LogicalResult AllGatherOp::verify() {
int64_t channelId = 0;
ghpvnist marked this conversation as resolved.
Show resolved Hide resolved
if (auto channelHandleAttr = getChannelHandleAttr())
channelId = channelHandleAttr.getHandle();

return hlo::verifyAllGatherOp(getLoc(), getOperand(), getAllGatherDim(),
getReplicaGroups(), getUseGlobalDeviceIds(),
getResult());
getReplicaGroups(), channelId,
getUseGlobalDeviceIds(), getResult());
}

//===----------------------------------------------------------------------===//
Expand Down
21 changes: 10 additions & 11 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1297,7 +1297,8 @@ def StableHLO_WhileOp: StableHLO_Op<"while", [
let hasCustomAssemblyFormat = 1;
}

def StableHLO_AllGatherOp : StableHLO_Op<"all_gather", [SameOperandsAndResultElementType]> {
def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
[SameOperandsAndResultElementType] /*all_gather_c6*/> {
string summary = "AllGather operation";
string description = [{
Within each process group in the process grid, concatenates the values of the
Expand All @@ -1311,20 +1312,18 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather", [SameOperandsAndResultEle
```mlir
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>,
// use_global_device_ids = false
} : (tensor<2x2xf32>) -> tensor<2x4xf32>
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
```
}];

let arguments = (ins
HLO_Tensor:$operand,
I64Attr:$all_gather_dim,
I64ElementsAttr:$replica_groups,
OptionalAttr<StableHLO_ChannelHandle>:$channel_handle,
UnitAttr:$use_global_device_ids
HLO_Tensor:$operand, /*all_gather_i1*/
I64Attr:$all_gather_dim, /*all_gather_i2*/
I64ElementsAttr:$replica_groups, /*all_gather_i3*/
OptionalAttr<StableHLO_ChannelHandle>:$channel_handle, /*all_gather_i4*/
UnitAttr:$use_global_device_ids /*all_gather_i5*/
);
let results = (outs HLO_Tensor);
let hasVerifier = 1;
Expand Down
36 changes: 26 additions & 10 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ LogicalResult verifyReplicaGroups(std::optional<Location> location,
std::optional<size_t> expectedGroupSize) {
auto replicaGroupType = replicaGroups.getType().cast<RankedTensorType>();

// all_gather_i3
if (replicaGroupType.getRank() != 2)
return emitOptionalError(location,
"replica groups should be a rank 2 tensor");
Expand All @@ -435,18 +436,19 @@ LogicalResult verifyReplicaGroups(std::optional<Location> location,
for (int64_t replicaId : replicaIds) {
// Replica groups are stored in a 2D tensor. If the op supports non-uniform
// groups, null replica IDs are stored as -1.
// all_gather_c4
if (replicaId == -1) {
if (!allGroupsMustHaveSameSize) continue;
return emitOptionalError(location, "Invalid replica id -1");
}

// all_reduce_c1
// all_gather_c2, all_reduce_c1
if (!replicaIdsSeen.insert(replicaId).second)
return emitOptionalError(location, "replica id #", replicaId,
" seen more than once");
}

// all_reduce_c3
// all_gather_c4, all_reduce_c3
for (size_t id = 0; id < replicaIdsSeen.size(); id++)
if (!replicaIdsSeen.contains(id))
return emitOptionalError(location, "replica id #", id,
Expand Down Expand Up @@ -3040,37 +3042,50 @@ LogicalResult inferWhileOp(std::optional<Location>, ValueRange operand,
LogicalResult verifyAllGatherOp(std::optional<Location> location, Value operand,
int64_t allGatherDim,
DenseIntElementsAttr replicaGroups,
bool useGlobalDeviceIds, Value result) {
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/true,
useGlobalDeviceIds,
/*expectedGroupSize=*/std::nullopt)))
return failure();

int64_t channelId, bool useGlobalDeviceIds,
Value result) {
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
auto resultType = result.getType().dyn_cast<RankedTensorType>();

// all_gather_c1
if (allGatherDim < 0)
return emitOptionalError(location, "all_gather_dim cannot be negative");

if (operandType) {
// all_gather_c1
if (allGatherDim >= operandType.getRank())
return emitOptionalError(
location, "all_gather_dim must be a valid index of operand");

// TODO(#1745): Sync verification of AllGather with HLO.
if (operandType.getDimSize(allGatherDim) == 0)
return emitOptionalError(
location,
"dimension size of operand at 'all_gather_dim' cannot be zero");
}

// all_gather_i3, all_gather_c2, all_gather_c4
if (failed(verifyReplicaGroups(location, replicaGroups,
/*allGroupsMustHaveSameSize=*/true,
useGlobalDeviceIds,
/*expectedGroupSize=*/std::nullopt)))
return failure();

// all_gather_c5
if (useGlobalDeviceIds && channelId < 0)
return emitOptionalError(
location,
"channel_id cannot be negative when useGlobalDeviceIds is set");

// all_gather_c6
if (operandType && resultType) {
if (resultType.getRank() != operandType.getRank())
return emitOptionalError(location,
"operand and return must have the same rank");
"operand and result must have the same rank");

for (int64_t i = 0; i < operandType.getRank(); i++) {
if (i == allGatherDim) continue;
// all_gather_c6
if (!verifyCompatibleDims(resultType.getDimSize(i),
operandType.getDimSize(i)))
return emitOptionalError(
Expand All @@ -3083,6 +3098,7 @@ LogicalResult verifyAllGatherOp(std::optional<Location> location, Value operand,
resultType.isDynamicDim(allGatherDim))
return success();

// all_gather_c6
if ((resultType.getDimSize(allGatherDim) %
operandType.getDimSize(allGatherDim)) != 0)
return emitOptionalError(
Expand Down
3 changes: 2 additions & 1 deletion stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ LogicalResult inferWhileOp(std::optional<Location> location, ValueRange operand,
LogicalResult verifyAllGatherOp(std::optional<Location> location, Value operand,
int64_t allGatherDim,
DenseIntElementsAttr replicaGroups,
bool useGlobalDeviceIds, Value result);
int64_t channelId, bool useGlobalDeviceIds,
Value result);

LogicalResult verifyAllReduceOp(std::optional<Location> location, Value operand,
DenseIntElementsAttr replicaGroups,
Expand Down
47 changes: 47 additions & 0 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,25 @@ SmallVector<InterpreterValue> eval(
auto inputs = scope.findTokens(afterAllOp.getInputs());
auto result = evalAfterAllOp(inputs, afterAllOp->getContext());
scope.add(afterAllOp.getResult(), result);
} else if (auto allGatherOp = dyn_cast<AllGatherOp>(op)) {
auto operand = scope.findTensor(allGatherOp.getOperand());

auto replicaGroupsAttr = allGatherOp.getReplicaGroups();
auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape();
SmallVector<SmallVector<uint32_t>> replicaGroups(replicaGroupsShape[0]);
auto replicaGroupsIt = replicaGroupsAttr.getValues<int64_t>().begin();
for (auto &replicaGroup : replicaGroups)
for (auto i = 0; i < replicaGroupsShape[1]; ++i, ++replicaGroupsIt)
replicaGroup.push_back(*replicaGroupsIt);

ChannelId channelId = 0;
if (auto channelHandle = allGatherOp.getChannelHandleAttr())
channelId = channelHandle.getHandle();

auto result = evalAllGatherOp(
operand, allGatherOp.getAllGatherDim(), replicaGroups, channelId,
allGatherOp.getUseGlobalDeviceIds(), process, allGatherOp.getType());
scope.add(allGatherOp.getResult(), result);
} else if (auto allReduceOp = dyn_cast<AllReduceOp>(op)) {
auto operand = scope.findTensor(allReduceOp.getOperand());

Expand Down Expand Up @@ -745,6 +764,34 @@ Token evalAfterAllOp(ArrayRef<Token> inputs, MLIRContext *context) {
return Token(context);
}

Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim,
SmallVector<SmallVector<uint32_t>> replicaGroups,
ChannelId channelId, bool useGlobalDeviceIds,
Process *process, ShapedType resultType) {
if (!process)
llvm::report_fatal_error(
"all_gather is only supported when run via interpreter.run_parallel");

ProcessGroups processGroups;
if (channelId <= 0 && !useGlobalDeviceIds)
processGroups = process->crossReplica(replicaGroups);
if (channelId > 0 && !useGlobalDeviceIds)
processGroups = process->crossReplicaAndPartition(replicaGroups);
if (channelId > 0 && useGlobalDeviceIds)
processGroups = process->flattenedIds(replicaGroups);

auto processGroup = processGroups.findGroup(process->getId());
if (!processGroup)
llvm::report_fatal_error(invalidArgument(
"Failed to find process group with process_id: (%d, %d)",
process->getId().replicaId, process->getId().partitionId));

auto groupOperands =
process->rendezvous(*processGroup, channelId, operand).getSortedTensors();

return evalConcatenateOp(groupOperands, allGatherDim, resultType);
}

Tensor evalAllReduceOp(const Tensor &operand,
SmallVector<SmallVector<uint32_t>> replicaGroups,
ChannelId channelId, bool useGlobalDeviceIds,
Expand Down
4 changes: 4 additions & 0 deletions stablehlo/reference/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ namespace stablehlo {
Tensor evalAbsOp(const Tensor &operand, ShapedType resultType);
Tensor evalAddOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType);
Token evalAfterAllOp(ArrayRef<Token> inputs, MLIRContext *context);
Tensor evalAllGatherOp(const Tensor &operand, int64_t allGatherDim,
SmallVector<SmallVector<uint32_t>> replicaGroups,
ChannelId channelId, bool useGlobalDeviceIds,
Process *process, ShapedType resultType);
Tensor evalAllReduceOp(const Tensor &operand,
SmallVector<SmallVector<uint32_t>> replicaGroups,
ChannelId channelId, bool useGlobalDeviceIds,
Expand Down
74 changes: 74 additions & 0 deletions stablehlo/tests/interpret_all_gather.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
%result = "stablehlo.all_gather"(%arg0) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
return %result : tensor<2x4xi64>
}
func.func public @main() {
%0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%results:2 = "interpreter.run_parallel"(%0, %1) {
programs=[["all_gather"], ["all_gather"]]
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
check.expect_eq_const %results#0, dense<[[1, 2, 5, 6],
[3, 4, 7, 8]]> : tensor<2x4xi64>
check.expect_eq_const %results#1, dense<[[1, 2, 5, 6],
[3, 4, 7, 8]]> : tensor<2x4xi64>
func.return
}
}

// -----

module @cross_replica_and_partition {
func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
%result = "stablehlo.all_gather"(%arg0) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
return %result : tensor<2x4xi64>
}
func.func public @main() {
%0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%results:2 = "interpreter.run_parallel"(%0, %1) {
programs=[["all_gather"], ["all_gather"]]
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
check.expect_eq_const %results#0, dense<[[1, 2, 5, 6],
[3, 4, 7, 8]]> : tensor<2x4xi64>
check.expect_eq_const %results#1, dense<[[1, 2, 5, 6],
[3, 4, 7, 8]]> : tensor<2x4xi64>
func.return
}
}

// -----

module @flattened_ids {
func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
%result = "stablehlo.all_gather"(%arg0) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
use_global_device_ids
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
return %result : tensor<2x4xi64>
}
func.func public @main() {
%0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%results:2 = "interpreter.run_parallel"(%0, %1) {
programs=[["all_gather"], ["all_gather"]]
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
check.expect_eq_const %results#0, dense<[[1, 2, 5, 6],
[3, 4, 7, 8]]> : tensor<2x4xi64>
check.expect_eq_const %results#1, dense<[[1, 2, 5, 6],
[3, 4, 7, 8]]> : tensor<2x4xi64>
func.return
}
}
Loading
Loading