
Conversation

@Chapaman (Contributor)

Implements Nx.Defn.Kernel.all_gather/2 to gather sharded tensor data across mesh partitions during distributed execution.
Changes

Nx
- Add all_gather/2 in defn/kernel.ex and defn/expr.ex with sharding semantics
- Add evaluator support for all_gather in defn/evaluator.ex

EXLA
- Lower all_gather to stablehlo.all_gather in defn.ex and mlir/value.ex

Tests
- EXLA.Defn.ShardingTest: "generates correct MLIR with all_gather" checks MLIR generation and shard_jit output across a 2×2 mesh along axes 0 and 1
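For context, a minimal usage sketch of the new kernel inside defn, assuming the all_gather(tensor, opts) shape discussed in the review below; the module name, function name, and replica-group values are illustrative, not part of the PR:

defmodule ShardedExample do
  import Nx.Defn

  # Gather the shards of a tensor split along axis 0 across a hypothetical
  # 2x2 mesh (four devices grouped in pairs); option names mirror the ones
  # documented in this PR.
  defn gather_rows(x) do
    all_gather(x, all_gather_dim: 0, replica_groups: [[0, 1], [2, 3]])
  end
end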

@Chapaman changed the title from "Basic gall_gather implementation" to "Basic all_gather implementation" on Jan 30, 2026
Comment on lines 1481 to 1489
Value.all_gather(
  [tensor],
  expr_to_typespec(ans),
  all_gather_dim,
  replica_groups,
  use_global_device_ids,
  Keyword.take(opts, [:channel_id])
)
|> hd()
@polvalente (Contributor), Jan 30, 2026:

Let's hard match for now instead of hd (i.e. [result] = Value...)
And then add a comment that we might want to surface all_gather as an operation that takes a container of operands instead of a single one.
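A sketch of that suggestion, reusing the arguments from the quoted hunk:

# Hard match on the single result instead of piping through hd/1.
# TODO: consider surfacing all_gather as an operation that takes a
# container of operands instead of a single one.
[result] =
  Value.all_gather(
    [tensor],
    expr_to_typespec(ans),
    all_gather_dim,
    replica_groups,
    use_global_device_ids,
    Keyword.take(opts, [:channel_id])
  )

Here result simply stands in for what the |> hd() call was producing before.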


attributes =
  if opts[:channel_id] do
    attributes ++ [channel_id: attr_i64(opts[:channel_id])]

Let's use Keyword.put instead of ++
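A sketch of the Keyword.put version, using the same variables as the quoted hunk:

attributes =
  if opts[:channel_id] do
    Keyword.put(attributes, :channel_id, attr_i64(opts[:channel_id]))
  else
    attributes
  end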

if opts[:channel_id] do
  attributes ++ [channel_id: attr_i64(opts[:channel_id])]
else
  attributes end

formatting

  end
end

def all_gather([%Value{function: func} | _] = operands, typespec, all_gather_dim, replica_groups, use_global_device_ids, opts \\ []) do

How about making channel_id a required argument and just passing the value directly?
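A sketch of the function head under that suggestion (the argument position is hypothetical and the body is elided):

def all_gather(
      [%Value{function: func} | _] = operands,
      typespec,
      all_gather_dim,
      replica_groups,
      use_global_device_ids,
      channel_id
    ) do
  # the body would build the channel_id attribute from the argument
  # directly instead of reading it from opts
end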

Comment on lines 481 to 484
if op == :all_gather and not function_exported?(mod, :all_gather, 3) do
  raise ArgumentError,
        "all_gather/3 is not supported by backend #{inspect(mod)}."
end

If we remove this, do we have a test verifying this raise? Also, I believe this is already checked elsewhere.


If it's not, it seems to me that this check should be more general
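If it does need to live here, a more general shape could look roughly like this; op, arity, and mod are assumptions about what the surrounding dispatch code has in scope:

# Hypothetical general form: raise for any callback the backend does not
# export, rather than special-casing :all_gather.
unless function_exported?(mod, op, arity) do
  raise ArgumentError,
        "#{op}/#{arity} is not supported by backend #{inspect(mod)}"
end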

Comment on lines +1172 to +1190
_all_gather_dim = opts[:all_gather_dim]
replica_groups = opts[:replica_groups]

# Calculate group size (number of replicas per group)
_group_size =
  case replica_groups do
    [first_group | _] -> length(first_group)
    [] -> 1
  end

# Calculate output shape by multiplying the gather dimension by group_size
input_shape = tensor.shape
output_shape =
  input_shape
  # |> Tuple.to_list()
  # |> List.update_at(all_gather_dim, &(&1 * group_size))
  # |> List.to_tuple()

# Create output tensor with the new shape

There are a few unused values here due to the stray comments; those should all be removed. Also, just pass tensor as out directly.
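For reference, the commented-out pipeline was computing the standard all_gather output shape, where the gathered dimension grows by the replica-group size; cleaned up, that calculation would read:

# Number of replicas contributing to each gather group.
group_size =
  case replica_groups do
    [first_group | _] -> length(first_group)
    [] -> 1
  end

# The gathered dimension becomes group_size times larger than the input's.
output_shape =
  tensor.shape
  |> Tuple.to_list()
  |> List.update_at(all_gather_dim, &(&1 * group_size))
  |> List.to_tuple()

The review suggestion above is to skip this bookkeeping in this code path and reuse the input tensor as the output directly.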


* `tensor` - The input tensor to gather
* `all_gather_dim` - The dimension along which to gather
* `replica_groups` - 2D list defining how replicas are grouped (required)

I'm not sure if this is the terminology we want to surface here. For now, let's make the function all_gather(tensor, opts) and defer the documentation of opts to the specific backend or compiler.

And in EXLA we should add a new section to the moduledoc describing Sharding.
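A sketch of the slimmer kernel surface under that suggestion; the docstring wording and the delegation into Nx.Defn.Expr are illustrative, not confirmed from this PR:

@doc """
Gathers a sharded tensor across the partitions it is split over.

The options in `opts` are compiler specific; refer to your compiler's
documentation (for EXLA, its Sharding section) for the supported keys.
"""
def all_gather(tensor, opts \\ []) do
  # hypothetical wiring; the PR's actual expression builder may differ
  Nx.Defn.Expr.all_gather(tensor, opts)
end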

@polvalente (Contributor) left a review comment:

This is looking great! I think we need more tests in both Nx and EXLA
