Skip to content

Commit 97f0993

Browse files
authored
feat: broadcasting of concrete arrays (#913)
1 parent 8d3b81a commit 97f0993

File tree

2 files changed

+58
-11
lines changed

2 files changed

+58
-11
lines changed

src/ConcreteRArray.jl

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,14 @@ function Base.isempty(x::Union{WrappedConcretePJRTArray,WrappedConcreteIFRTArray
5757
return isempty(ancestor(x))
5858
end
5959

60-
function Base.convert(::Type{<:Array}, X::ConcretePJRTArray{T,N}) where {T,N}
61-
if Sharding.is_sharded(X)
62-
data = Array{T,N}(undef, size(X)...)
60+
function Base.convert(::Type{<:Array}, X::AbstractConcreteArray{T,N}) where {T,N}
61+
data = Array{T,N}(undef, size(X)...)
62+
write_to_host_buffer!(data, X)
63+
return data
64+
end
6365

66+
function write_to_host_buffer!(data::Array, X::ConcretePJRTArray{T,N}) where {T,N}
67+
if Sharding.is_sharded(X)
6468
completed = Set{eltype(X.sharding.device_to_array_slices)}()
6569
for idx in 1:length(X.data)
6670
slice = X.sharding.device_to_array_slices[idx]
@@ -73,19 +77,15 @@ function Base.convert(::Type{<:Array}, X::ConcretePJRTArray{T,N}) where {T,N}
7377
XLA.to_host(X.data[idx], data_slice, Reactant.Sharding.NoSharding())
7478
data[slice...] .= data_slice
7579
end
76-
77-
return data
7880
else
79-
data = Array{T,N}(undef, size(X)...)
8081
XLA.to_host(XLA.synced_buffer(only(X.data)), data, Reactant.Sharding.NoSharding())
81-
return data
8282
end
83+
return nothing
8384
end
8485

85-
function Base.convert(::Type{<:Array}, X::ConcreteIFRTArray{T,N}) where {T,N}
86-
data = zeros(T, size(X)...)
86+
function write_to_host_buffer!(data::Array, X::ConcreteIFRTArray{T,N}) where {T,N}
8787
XLA.to_host(X.data, data, X.sharding)
88-
return data
88+
return nothing
8989
end
9090

9191
function Base.convert(
@@ -347,7 +347,9 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
347347
),
348348
)
349349
end
350-
aux = copyto!(similar(Array{ElType}, axes(bc)), bc)
350+
aux = copyto!(
351+
similar(Array{ElType}, axes(bc)), convert(Broadcast.Broadcasted{Nothing}, bc)
352+
)
351353
return ConcretePJRTArray(aux) # XXX: result should be on correct device?
352354
end
353355

@@ -367,6 +369,32 @@ function Base.copyto!(dest::AbstractConcreteArray, src::AbstractConcreteArray)
367369
return dest
368370
end
369371

372+
for aType in (:ConcretePJRTArray, :ConcreteIFRTArray)
373+
@eval begin
374+
function Base.copyto!(
375+
dest::AbstractConcreteArray,
376+
src::Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}},
377+
)
378+
dest.data = copy(src).data
379+
return dest
380+
end
381+
382+
function Base.copyto!(
383+
dest::Array, src::Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}}
384+
)
385+
write_to_host_buffer!(dest, copy(src))
386+
return dest
387+
end
388+
389+
function Base.copyto!(
390+
dest::AbstractArray, src::Broadcast.Broadcasted{Broadcast.ArrayStyle{$(aType)}}
391+
)
392+
copyto!(dest, convert(Array, copy(src)))
393+
return dest
394+
end
395+
end
396+
end
397+
370398
Base.collect(x::AbstractConcreteArray) = convert(Array, x)
371399

372400
function Base.mapreduce(

test/basic.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,3 +961,22 @@ end
961961
@test Array(@jit(map!(abs2, y_ra, x_ra))) map!(abs2, y, x)
962962
@test Array(y_ra) y
963963
end
964+
965+
@testset "ConcreteRArray inplace broadcast" begin
966+
x = Reactant.to_rarray(zeros(Float32, 2, 3))
967+
y = Reactant.to_rarray(reshape(collect(Float32, 1:6), 2, 3))
968+
969+
x .= y ./ 2
970+
971+
@test Array(x) Array(y) ./ 2
972+
973+
x = zeros(Float32, 2, 3)
974+
x .= y ./ 2
975+
976+
@test Array(x) Array(y) ./ 2
977+
978+
x = view(zeros(Float32, 2, 5), :, 1:3)
979+
x .= y ./ 2
980+
981+
@test Array(x) Array(y) ./ 2
982+
end

0 commit comments

Comments
 (0)