Basic all_gather implementation #1663
Conversation
exla/lib/exla/defn.ex
Outdated
```elixir
Value.all_gather(
  [tensor],
  expr_to_typespec(ans),
  all_gather_dim,
  replica_groups,
  use_global_device_ids,
  Keyword.take(opts, [:channel_id])
)
|> hd()
```
Let's hard-match for now instead of using `hd` (i.e. `[result] = Value...`).

And then add a comment that we might want to surface `all_gather` as an operation that takes a container of operands instead of a single one.
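A minimal sketch of both suggestions, reusing the bindings from the hunk above:

```elixir
# NOTE: we might want to surface all_gather as an operation that takes
# a container of operands instead of a single one.
[result] =
  Value.all_gather(
    [tensor],
    expr_to_typespec(ans),
    all_gather_dim,
    replica_groups,
    use_global_device_ids,
    Keyword.take(opts, [:channel_id])
  )
```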
exla/lib/exla/mlir/value.ex
Outdated
```elixir
attributes =
  if opts[:channel_id] do
    attributes ++ [channel_id: attr_i64(opts[:channel_id])]
```
Let's use Keyword.put instead of ++
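Applied to the hunk above, that would read roughly as follows (same bindings as in the diff; `Keyword.put/3` also replaces any existing `:channel_id` entry instead of appending a duplicate):

```elixir
attributes =
  if opts[:channel_id] do
    Keyword.put(attributes, :channel_id, attr_i64(opts[:channel_id]))
  else
    attributes
  end
```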
exla/lib/exla/mlir/value.ex
Outdated
```elixir
if opts[:channel_id] do
  attributes ++ [channel_id: attr_i64(opts[:channel_id])]
else
  attributes end
```
formatting
exla/lib/exla/mlir/value.ex
Outdated
```elixir
  end
end

def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, opts \\ []) do
```
How about making `channel_id` a required argument and just passing the value directly?
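One way that signature could look, keeping the other parameters from the diff (illustrative only, not the final API):

```elixir
def all_gather(
      [%Value{function: func} | _] = operands,
      typespec,
      all_gather_dim,
      replica_groups,
      use_global_device_ids,
      channel_id
    ) do
  # channel_id is required now, so the opts/Keyword handling goes away
  # and attr_i64(channel_id) can be used directly in the attributes.
end
```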
nx/lib/nx/defn/evaluator.ex
Outdated
```elixir
if op == :all_gather and not function_exported?(mod, :all_gather, 3) do
  raise ArgumentError,
        "all_gather/3 is not supported by backend #{inspect(mod)}."
end
```
If we remove this, do we have a test verifying this raise? Also, I believe this is already checked elsewhere.
If it's not, it seems to me that this check should be more general
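For illustration, a more general version of that guard might check whatever optional callback is being dispatched rather than hard-coding `:all_gather`; the `op` binding and the arity 3 below are carried over from the removed check, not from an actual Nx callback table:

```elixir
# Hypothetical generalization of the hunk above.
if not function_exported?(mod, op, 3) do
  raise ArgumentError,
        "#{op}/3 is not supported by backend #{inspect(mod)}"
end
```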
```elixir
_all_gather_dim = opts[:all_gather_dim]
replica_groups = opts[:replica_groups]

# Calculate group size (number of replicas per group)
_group_size =
  case replica_groups do
    [first_group | _] -> length(first_group)
    [] -> 1
  end

# Calculate output shape by multiplying the gather dimension by group_size
input_shape = tensor.shape
output_shape =
  input_shape
  # |> Tuple.to_list()
  # |> List.update_at(all_gather_dim, &(&1 * group_size))
  # |> List.to_tuple()

# Create output tensor with the new shape
```
There are a few unused values here due to the stray comments, and those should all be removed. Also, just pass `tensor` as the output directly.
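Put together, the cleaned-up hunk might read as follows; `out` is a hypothetical binding for whatever the surrounding evaluator code expects:

```elixir
replica_groups = opts[:replica_groups]

# Per the review: skip the shape math entirely and pass the input
# tensor through as the output directly.
out = tensor
```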
nx/lib/nx/defn/kernel.ex
Outdated
```
* `tensor` - The input tensor to gather
* `all_gather_dim` - The dimension along which to gather
* `replica_groups` - 2D list defining how replicas are grouped (required)
```
I'm not sure this is the terminology we want to surface here. For now, let's make the function `all_gather(tensor, opts)` and defer the documentation of `opts` to the specific backend or compiler.

And in EXLA we should add a new section to the EXLA moduledoc describing sharding.
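As a rough illustration of that shape, usage in `defn` would then pass backend-specific options straight through (the `replica_groups` option below comes from this diff; treating a 2×2 grouping as valid here is an assumption):

```elixir
defmodule MyShardedOps do
  import Nx.Defn

  defn gather_all(t) do
    # opts are opaque to Nx; their meaning is documented by the
    # backend/compiler that lowers all_gather
    all_gather(t, replica_groups: [[0, 1], [2, 3]])
  end
end
```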
polvalente left a comment:
This is looking great! I think we need more tests in both Nx and EXLA.
Implements Nx.Defn.Kernel.all_gather/2 to gather sharded tensor data across mesh partitions during distributed execution.
Changes

Nx
- Add `all_gather/2` in `defn/kernel.ex` and `defn/expr.ex` with sharding semantics
- Add evaluator support for `all_gather` in `defn/evaluator.ex`

EXLA
- Lower `all_gather` to `stablehlo.all_gather` in `defn.ex` and `mlir/value.ex`

Tests
- `EXLA.Defn.ShardingTest`: "generates correct MLIR with all_gather" checks MLIR generation and `shard_jit` output across a 2×2 mesh along axes 0 and 1