@@ -57,10 +57,14 @@ function Base.isempty(x::Union{WrappedConcretePJRTArray,WrappedConcreteIFRTArray
57
57
return isempty (ancestor (x))
58
58
end
59
59
60
- function Base. convert (:: Type{<:Array} , X:: ConcretePJRTArray{T,N} ) where {T,N}
61
- if Sharding. is_sharded (X)
62
- data = Array {T,N} (undef, size (X)... )
60
+ function Base. convert (:: Type{<:Array} , X:: AbstractConcreteArray{T,N} ) where {T,N}
61
+ data = Array {T,N} (undef, size (X)... )
62
+ write_to_host_buffer! (data, X)
63
+ return data
64
+ end
63
65
66
+ function write_to_host_buffer! (data:: Array , X:: ConcretePJRTArray{T,N} ) where {T,N}
67
+ if Sharding. is_sharded (X)
64
68
completed = Set {eltype(X.sharding.device_to_array_slices)} ()
65
69
for idx in 1 : length (X. data)
66
70
slice = X. sharding. device_to_array_slices[idx]
@@ -73,19 +77,15 @@ function Base.convert(::Type{<:Array}, X::ConcretePJRTArray{T,N}) where {T,N}
73
77
XLA. to_host (X. data[idx], data_slice, Reactant. Sharding. NoSharding ())
74
78
data[slice... ] .= data_slice
75
79
end
76
-
77
- return data
78
80
else
79
- data = Array {T,N} (undef, size (X)... )
80
81
XLA. to_host (XLA. synced_buffer (only (X. data)), data, Reactant. Sharding. NoSharding ())
81
- return data
82
82
end
83
+ return nothing
83
84
end
84
85
85
- function Base. convert (:: Type{<:Array} , X:: ConcreteIFRTArray{T,N} ) where {T,N}
86
- data = zeros (T, size (X)... )
86
+ function write_to_host_buffer! (data:: Array , X:: ConcreteIFRTArray{T,N} ) where {T,N}
87
87
XLA. to_host (X. data, data, X. sharding)
88
- return data
88
+ return nothing
89
89
end
90
90
91
91
function Base. convert (
@@ -347,7 +347,9 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
347
347
),
348
348
)
349
349
end
350
- aux = copyto! (similar (Array{ElType}, axes (bc)), bc)
350
+ aux = copyto! (
351
+ similar (Array{ElType}, axes (bc)), convert (Broadcast. Broadcasted{Nothing}, bc)
352
+ )
351
353
return ConcretePJRTArray (aux) # XXX : result should be on correct device?
352
354
end
353
355
@@ -367,6 +369,32 @@ function Base.copyto!(dest::AbstractConcreteArray, src::AbstractConcreteArray)
367
369
return dest
368
370
end
369
371
372
+ for aType in (:ConcretePJRTArray , :ConcreteIFRTArray )
373
+ @eval begin
374
+ function Base. copyto! (
375
+ dest:: AbstractConcreteArray ,
376
+ src:: Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}} ,
377
+ )
378
+ dest. data = copy (src). data
379
+ return dest
380
+ end
381
+
382
+ function Base. copyto! (
383
+ dest:: Array , src:: Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}}
384
+ )
385
+ write_to_host_buffer! (dest, copy (src))
386
+ return dest
387
+ end
388
+
389
+ function Base. copyto! (
390
+ dest:: AbstractArray , src:: Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}}
391
+ )
392
+ copyto! (dest, convert (Array, copy (src)))
393
+ return dest
394
+ end
395
+ end
396
+ end
397
+
370
398
Base. collect (x:: AbstractConcreteArray ) = convert (Array, x)
371
399
372
400
function Base. mapreduce (
0 commit comments