@@ -57,10 +57,14 @@ function Base.isempty(x::Union{WrappedConcretePJRTArray,WrappedConcreteIFRTArray
5757 return isempty (ancestor (x))
5858end
5959
60- function Base. convert (:: Type{<:Array} , X:: ConcretePJRTArray{T,N} ) where {T,N}
61- if Sharding. is_sharded (X)
62- data = Array {T,N} (undef, size (X)... )
60+ function Base. convert (:: Type{<:Array} , X:: AbstractConcreteArray{T,N} ) where {T,N}
61+ data = Array {T,N} (undef, size (X)... )
62+ write_to_host_buffer! (data, X)
63+ return data
64+ end
6365
66+ function write_to_host_buffer! (data:: Array , X:: ConcretePJRTArray{T,N} ) where {T,N}
67+ if Sharding. is_sharded (X)
6468 completed = Set {eltype(X.sharding.device_to_array_slices)} ()
6569 for idx in 1 : length (X. data)
6670 slice = X. sharding. device_to_array_slices[idx]
@@ -73,19 +77,15 @@ function Base.convert(::Type{<:Array}, X::ConcretePJRTArray{T,N}) where {T,N}
7377 XLA. to_host (X. data[idx], data_slice, Reactant. Sharding. NoSharding ())
7478 data[slice... ] .= data_slice
7579 end
76-
77- return data
7880 else
79- data = Array {T,N} (undef, size (X)... )
8081 XLA. to_host (XLA. synced_buffer (only (X. data)), data, Reactant. Sharding. NoSharding ())
81- return data
8282 end
83+ return nothing
8384end
8485
85- function Base. convert (:: Type{<:Array} , X:: ConcreteIFRTArray{T,N} ) where {T,N}
86- data = zeros (T, size (X)... )
86+ function write_to_host_buffer! (data:: Array , X:: ConcreteIFRTArray{T,N} ) where {T,N}
8787 XLA. to_host (X. data, data, X. sharding)
88- return data
88+ return nothing
8989end
9090
9191function Base. convert (
@@ -347,7 +347,9 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
347347 ),
348348 )
349349 end
350- aux = copyto! (similar (Array{ElType}, axes (bc)), bc)
350+ aux = copyto! (
351+ similar (Array{ElType}, axes (bc)), convert (Broadcast. Broadcasted{Nothing}, bc)
352+ )
351353 return ConcretePJRTArray (aux) # XXX : result should be on correct device?
352354 end
353355
@@ -367,6 +369,32 @@ function Base.copyto!(dest::AbstractConcreteArray, src::AbstractConcreteArray)
367369 return dest
368370end
369371
372+ for aType in (:ConcretePJRTArray , :ConcreteIFRTArray )
373+ @eval begin
374+ function Base. copyto! (
375+ dest:: AbstractConcreteArray ,
376+ src:: Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}} ,
377+ )
378+ dest. data = copy (src). data
379+ return dest
380+ end
381+
382+ function Base. copyto! (
383+ dest:: Array , src:: Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}}
384+ )
385+ write_to_host_buffer! (dest, copy (src))
386+ return dest
387+ end
388+
389+ function Base. copyto! (
390+ dest:: AbstractArray , src:: Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}}
391+ )
392+ copyto! (dest, convert (Array, copy (src)))
393+ return dest
394+ end
395+ end
396+ end
397+
370398Base. collect (x:: AbstractConcreteArray ) = convert (Array, x)
371399
372400function Base. mapreduce (
0 commit comments