Skip to content

Commit 6151de7

Browse files
committed
fix: check for ifrt_array_copy_to_host_buffer
1 parent f1b97bb commit 6151de7

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/xla/IFRT/Array.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,11 @@ function XLA.buffer_on_cpu(::Array)
138138
end
139139

140140
function XLA.to_host(buffer::Array, data, reactant_sharding)
141-
if length(XLA.devices(XLA.sharding(buffer))) == 1
141+
reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding)
142+
143+
if is_single_device_sharding(XLA.sharding(buffer)) ||
144+
is_fully_replicated(XLA.sharding(buffer)) ||
145+
reactant_sharding isa Reactant.Sharding.NoSharding
142146
GC.@preserve buffer data begin
143147
@ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
144148
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
@@ -147,7 +151,6 @@ function XLA.to_host(buffer::Array, data, reactant_sharding)
147151
return data
148152
end
149153

150-
reactant_sharding = Reactant.Sharding.unwrap_shardinfo(reactant_sharding)
151154
@assert reactant_sharding isa Reactant.Sharding.HloSharding
152155
client = XLA.client(buffer)
153156
all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids)

0 commit comments

Comments
 (0)