15 changes: 15 additions & 0 deletions exla/lib/exla.ex
@@ -215,6 +215,21 @@ defmodule EXLA do
The metadata is:

* `:key` - the compilation key for debugging

## Sharding

EXLA supports sharding, a way to partition a computation across multiple devices.
Sharded computations can use a number of collective operations, documented below
together with a usage example.

### [`all_gather`](https://openxla.org/stablehlo/spec#all_gather)

#### Options

* `:all_gather_dim` - the dimension along which to gather
* `:replica_groups` - a 2D list defining how replicas are grouped
* `:use_global_device_ids` - whether to use global device IDs (default: `false`)
* `:channel_id` - the channel ID for communication (optional)
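
#### Example

A minimal sketch of gathering a sharded result back onto every device, mirroring
the sharding tests. The `Mesh` alias, mesh shape, shardings, and `replica_groups`
below are illustrative rather than required values:

    fun = fn x, y ->
      x
      |> Nx.add(y)
      |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]])
    end

    mesh = %Mesh{name: "mesh", shape: {2, 2}}
    input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0]}]

    # One list of input shards per partition (a {2, 2} mesh has 4 partitions)
    args = List.duplicate([Nx.iota({4, 1}), Nx.iota({4, 1})], 4)

    EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args)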

"""

@behaviour Nx.Defn.Compiler
21 changes: 21 additions & 0 deletions exla/lib/exla/defn.ex
@@ -1471,6 +1471,27 @@ defmodule EXLA.Defn do
EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type)
end

## to_operator collective ops

defp to_operator(:all_gather, [%Value{} = tensor, opts], ans, _state) do
all_gather_dim = Keyword.fetch!(opts, :all_gather_dim)
replica_groups = Keyword.fetch!(opts, :replica_groups)
use_global_device_ids = Keyword.get(opts, :use_global_device_ids, false)

# We might want to surface all_gather as an operation that takes a container of operands instead of a single one.
[result] =
Value.all_gather(
[tensor],
expr_to_typespec(ans),
all_gather_dim,
replica_groups,
use_global_device_ids,
opts[:channel_id]
)

result
end

defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do
n = opts[:length]
axis = opts[:axis]
23 changes: 23 additions & 0 deletions exla/lib/exla/mlir/value.ex
@@ -64,6 +64,29 @@ defmodule EXLA.MLIR.Value do
end
end

def all_gather(
      [%Value{function: func} | _] = operands,
      typespec,
      all_gather_dim,
      replica_groups,
      use_global_device_ids,
      channel_id \\ nil
    ) do
result_types = typespecs_to_mlir_types([typespec])

num_groups = length(replica_groups)
group_size = if num_groups > 0, do: length(hd(replica_groups)), else: 0
flat_groups = List.flatten(replica_groups)

attributes = [
all_gather_dim: attr_i64(all_gather_dim),
replica_groups: attr_dense_elements(flat_groups, {:s, 64}, {num_groups, group_size}),
use_global_device_ids: attr_boolean(use_global_device_ids)
]

attributes =
if channel_id do
Keyword.put(attributes, :channel_id, attr_i64(channel_id))
else
attributes
end

op(func, "stablehlo.all_gather", operands, result_types, attributes: attributes)
end

defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
%{type: lhs_type} = get_typespec(lhs)
%{type: rhs_type} = get_typespec(rhs)
73 changes: 72 additions & 1 deletion exla/test/exla/defn/sharding_test.exs
@@ -6,7 +6,8 @@ defmodule EXLA.Defn.ShardingTest do
describe "MLIR module generation with sharding" do
@moduletag :multi_device
test "generates correct MLIR with simple 2D mesh and sharding" do
fun = fn x, y -> Nx.add(x, y) end
fun = fn x, y ->
  Nx.add(x, y)
end

mesh = %Mesh{name: "mesh", shape: {2, 2}}
# First arg: shard dim 0 on mesh axis 0, dim 1 on mesh axis 1
@@ -737,5 +738,75 @@
assert result.mlir_module =~ ~r/"axis_0"/
assert result.mlir_module =~ ~r/"axis_1"/
end

@moduletag :multi_device
test "generates correct MLIR with all_gather" do
fun = fn x, y ->
  Nx.add(x, y)
  |> Nx.Defn.Kernel.all_gather(all_gather_dim: 0, replica_groups: [[0]])
  |> Nx.Defn.Kernel.all_gather(all_gather_dim: 1, replica_groups: [[0]])
end

mesh = %Mesh{name: "mesh", shape: {2, 2}}
# First arg: shard dim 0 on mesh axis 0, dim 1 on mesh axis 1
# Second arg: shard dim 0 on mesh axis 0, dim 1 not sharded
input_shardings = [%{0 => [0], 1 => [1]}, %{0 => [0]}]

# For mesh {2, 2}, we have 4 partitions
# Each partition gets a shard of the inputs
# First input: shape {8, 2} sharded as [[0], [1]] -> each partition gets {4, 1}
# Second input: shape {8, 1} sharded as [[0], []] -> each partition gets {4, 1}
args = [
# partition 0
[Nx.iota({4, 1}), Nx.iota({4, 1})],
# partition 1
[Nx.iota({4, 1}), Nx.iota({4, 1})],
# partition 2
[Nx.iota({4, 1}), Nx.iota({4, 1})],
# partition 3
[Nx.iota({4, 1}), Nx.iota({4, 1})]
]

result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings)

expected_mlir = """
module {
sdy.mesh @mesh = <["axis_0"=2, "axis_1"=2]>
func.func public @main(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {"axis_1", ?}p0]>}, %arg1: tensor<8x1xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", ?}p0, {?}p0]>}) -> tensor<8x2xi32> {
%0 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1] : (tensor<8x1xi32>) -> tensor<8x2xi32>
%1 = stablehlo.add %arg0, %0 : tensor<8x2xi32>
%2 = "stablehlo.all_gather"(%1) <{all_gather_dim = 0 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32>
%3 = "stablehlo.all_gather"(%2) <{all_gather_dim = 1 : i64, replica_groups = dense<0> : tensor<1x1xi64>}> : (tensor<8x2xi32>) -> tensor<8x2xi32>
return %3 : tensor<8x2xi32>
}
}
"""

assert expected_mlir == result.mlir_module

results = EXLA.shard_jit(fun, mesh, input_shardings: input_shardings).(args)

assert length(results) == 4

# After all_gather on both dims, each partition has the full tensor: add(iota, iota) -> 2*iota
# Each shard had iota({4,1}) = [[0],[1],[2],[3]], so add gives [[0],[2],[4],[6]]
# After gathering: replicated 8x2 with pattern [[0,0],[2,2],[4,4],[6,6],[0,0],[2,2],[4,4],[6,6]]
expected_result =
Nx.tensor([
[0, 0],
[2, 2],
[4, 4],
[6, 6],
[0, 0],
[2, 2],
[4, 4],
[6, 6]
])

for r <- results do
assert_equal(r, expected_result)
end
end
end
end
27 changes: 27 additions & 0 deletions nx/lib/nx/defn/expr.ex
@@ -1166,6 +1166,33 @@ defmodule Nx.Defn.Expr do
expr(out, context, :gather, [tensor, indices, opts])
end

def all_gather(tensor, opts) do
{[tensor], context} = to_exprs([tensor])

_all_gather_dim = opts[:all_gather_dim]
replica_groups = opts[:replica_groups]

# Calculate group size (number of replicas per group)
_group_size =
case replica_groups do
[first_group | _] -> length(first_group)
[] -> 1
end

# Calculate output shape by multiplying the gather dimension by group_size
input_shape = tensor.shape
output_shape =
input_shape
# |> Tuple.to_list()
# |> List.update_at(all_gather_dim, &(&1 * group_size))
# |> List.to_tuple()

# Create output tensor with the new shape
Comment on lines +1172 to +1190 (Contributor):

There are a few unused values here due to the stray comments, which should all be removed. Also, just pass `tensor` as `out` directly.

out = %{tensor | shape: output_shape}

expr(out, context, :all_gather, [tensor, opts])
end
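
For reference, a sketch of the cleanup suggested in the review comment above: drop the
commented-out pipeline and the now-unused bindings, and pass `tensor` as `out` directly,
since the output shape currently equals the input shape. Illustrative only, not the
author's final code:

```elixir
def all_gather(tensor, opts) do
  {[tensor], context} = to_exprs([tensor])

  # The output shape currently matches the input shape, so the input
  # expression can serve directly as the output template.
  expr(tensor, context, :all_gather, [tensor, opts])
end
```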

@impl true
def reverse(out, tensor, axes) do
tensor = to_expr(tensor)
18 changes: 18 additions & 0 deletions nx/lib/nx/defn/kernel.ex
@@ -1669,6 +1669,24 @@ defmodule Nx.Defn.Kernel do
end
end

@doc """
Gathers tensors from all replicas along a specified dimension.

This operation concatenates tensors from multiple replicas/devices along the
specified dimension. It requires a compiler and backend that support multi-device
operations (see the example below).

## Parameters

* `tensor` - The input tensor to gather

* `opts` - Optional keyword list. These are backend- and compiler-specific;
see your backend or compiler docs for supported options.
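
## Examples

Assuming a compiler that supports collective operations (for example, EXLA with
sharding enabled); the function name and option values below are illustrative:

    defn gather_rows(t) do
      Nx.Defn.Kernel.all_gather(t, all_gather_dim: 0, replica_groups: [[0, 1]])
    end

The default `Nx.Defn.Evaluator` does not implement this operation.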

"""
def all_gather(tensor, opts \\ []) do
Nx.Defn.Expr.all_gather(tensor, opts)
end

@definitions (Module.definitions_in(__MODULE__, :def) ++
Module.definitions_in(__MODULE__, :defmacro)) --
[
13 changes: 13 additions & 0 deletions nx/test/nx/defn_test.exs
@@ -2952,4 +2952,17 @@ defmodule Nx.DefnTest do
assert vectorized_metadata_tuple(x, z) == vec_nonvec_result
end
end

describe "sharding" do
defn all_gather_test(tensor) do
Nx.Defn.Kernel.all_gather(tensor, all_gather_dim: 0, replica_groups: [[0]])
end

@tag compiler: Evaluator
test "all_gather works" do
assert_raise UndefinedFunctionError, fn ->
all_gather_test(Nx.tensor([1, 2, 3, 4]))
end
end
end
end