Add interpreter for AllGatherOp #1727

Merged
merged 3 commits into openxla:main from all_gather
Aug 19, 2023

Conversation

ghpvnist (Member) commented Aug 9, 2023

We have the following constraints in the spec (a small positive example is sketched after the list):

(I1) `operand`: tensor or per-tensor quantized tensor.
(I2) `all_gather_dim`: constant of type `si64`.
(I3) `replica_groups`: 2-dimensional tensor constant of type `si64`.
(I4) `channel_id`: constant of type `si64`.
(I5) `use_global_device_ids`: constant of type `i1`.
(C1) `0 <= all_gather_dim < rank(operand)`.
(C2) `is_unique(replica_groups)`.
(C3) `size(replica_groups)` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_replicas` if `cross_replica_and_partition` is used.
* `num_processes` if `flattened_ids` is used.
(C4) `0 <= replica_groups < size(replica_groups)`.
(C5) If `use_global_device_ids = true`, then `channel_id > 0`.
(C6) `type(result) = type(operand)` except:
* `dim(result, all_gather_dim) =
  dim(operand, all_gather_dim) * dim(process_groups, 1)`.
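
For reference, here is a minimal positive example of the op (a sketch in the style of the spec's `all_gather` example; the concrete shapes, replica groups, and channel handle are illustrative), showing C6's result-shape rule:

```mlir
// num_replicas: 2, num_partitions: 1, so cross_replica yields one group of 2 processes.
// C6: dim(result, 1) = dim(operand, 1) * dim(process_groups, 1) = 2 * 2 = 4.
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
```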

These constraints will be comprehensively covered by the following tests:

I1: a) `operand` is not a tensor. (Covered by ODS).
I2: a) `all_gather_dim` is not a constant of type `si64`. (Covered by ODS).
I3: a) `replica_groups` is not a 2-dimensional tensor constant.
I3: b) `element_type(replica_groups) != si64`. (Covered by ODS).
I4: a) `channel_id` is not a constant of type `si64`. (Covered by ODS).
I5: a) `use_global_device_ids` is not a constant of type `i1`. (Covered by ODS).
C1: a) `all_gather_dim < 0`.
C1: b) `all_gather_dim >= rank(operand)`.
C2: a) `is_unique(replica_groups) = false`.
C3: a) `size(replica_groups)` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_replicas` if `cross_replica_and_partition` is used.
* `num_processes` if `flattened_ids` is used. 
C4: a) `replica_groups < 0`.
C4: b) `replica_groups >= size(replica_groups)`.
C5: a) `use_global_device_ids = true` and `channel_id <= 0`.
C6: a) `type(result) != type(operand)` except:
* `dim(result, all_gather_dim) =
  dim(operand, all_gather_dim) * dim(process_groups, 1)`.

If we drop the "Covered by ODS" pieces, this will leave us with the following test cases (one of which is sketched after the list):

I3: a) `replica_groups` is not a 2-dimensional tensor constant.
C1: a) `all_gather_dim < 0`.
C1: b) `all_gather_dim >= rank(operand)`.
C2: a) `is_unique(replica_groups) = false`.
C3: a) `size(replica_groups)` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_replicas` if `cross_replica_and_partition` is used.
* `num_processes` if `flattened_ids` is used. 
C4: a) `replica_groups < 0`.
C4: b) `replica_groups >= size(replica_groups)`.
C5: a) `use_global_device_ids = true` and `channel_id <= 0`.
C6: a) `type(result) != type(operand)` except:
* `dim(result, all_gather_dim) =
  dim(operand, all_gather_dim) * dim(process_groups, 1)`.
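
As a sketch of what one of these negative tests might look like (the function name and diagnostic text below are hypothetical, not the verifier's exact wording), here is C1a expressed as an MLIR test expecting a verifier error:

```mlir
// C1a: all_gather_dim must not be negative.
func.func @all_gather_c1a(%operand : tensor<2x2xi64>) -> tensor<2x4xi64> {
  // expected-error@+1 {{all_gather_dim cannot be negative}}
  %result = "stablehlo.all_gather"(%operand) {
    all_gather_dim = -1 : i64,
    replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
    channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  } : (tensor<2x2xi64>) -> tensor<2x4xi64>
  func.return %result : tensor<2x4xi64>
}
```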

Notes:

* C3a verification is infeasible since `num_replicas` and `num_partitions` are not known statically.
* C4b verification is infeasible since `size(replica_groups)` is not known statically.

closes #1118

@ghpvnist ghpvnist added the Interpreter and Migrate to MHLO labels Aug 9, 2023
@ghpvnist ghpvnist requested a review from burmako August 9, 2023 23:53
@ghpvnist ghpvnist assigned burmako and ghpvnist and unassigned burmako Aug 9, 2023
@ghpvnist ghpvnist marked this pull request as draft August 10, 2023 00:02
@ghpvnist ghpvnist assigned burmako and unassigned ghpvnist Aug 10, 2023
@ghpvnist ghpvnist marked this pull request as ready for review August 10, 2023 00:08
@ghpvnist ghpvnist assigned ghpvnist and unassigned burmako Aug 10, 2023
@ghpvnist ghpvnist marked this pull request as draft August 10, 2023 01:18
@ghpvnist ghpvnist assigned burmako and unassigned ghpvnist Aug 10, 2023
@ghpvnist ghpvnist marked this pull request as ready for review August 10, 2023 18:12
@ghpvnist ghpvnist force-pushed the all_gather branch 5 times, most recently from 94fc83c to 16e0cec on August 12, 2023 07:05
@burmako burmako assigned ghpvnist and unassigned burmako Aug 18, 2023
@ghpvnist ghpvnist assigned burmako and unassigned ghpvnist Aug 18, 2023
@burmako burmako assigned ghpvnist and unassigned burmako Aug 18, 2023
@ghpvnist ghpvnist merged commit 2a38fde into openxla:main Aug 19, 2023
7 checks passed
@ghpvnist ghpvnist deleted the all_gather branch August 19, 2023 00:06
penagos pushed a commit to penagos/stablehlo that referenced this pull request Sep 21, 2023
penagos pushed a commit to penagos/stablehlo that referenced this pull request Sep 21, 2023
@ghpvnist ghpvnist added the Spec label Sep 29, 2023