Skip to content

Commit 82ba809

Browse files
committed
fix: check for ifrt_array_copy_to_host_buffer
1 parent e70ef1d commit 82ba809

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/xla/IFRT/Array.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ function XLA.buffer_on_cpu(::Array)
138138
end
139139

140140
function XLA.to_host(buffer::Array, data, reactant_sharding)
141-
if length(XLA.devices(XLA.sharding(buffer))) == 1
141+
if is_single_device_sharding(XLA.sharding(buffer)) ||
142+
is_fully_replicated(XLA.sharding(buffer))
142143
GC.@preserve buffer data begin
143144
@ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
144145
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}

0 commit comments

Comments
 (0)