Skip to content

Commit

Permalink
fix #40249, reshaping SharedArray on another process (#40286)
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffBezanson authored Apr 6, 2021
1 parent e34a904 commit 79e198b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
21 changes: 11 additions & 10 deletions stdlib/SharedArrays/src/SharedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,21 @@ size(S::SharedArray) = S.dims
elsize(::Type{SharedArray{T,N}}) where {T,N} = elsize(Array{T,N}) # aka fieldtype(T, :s)
IndexStyle(::Type{<:SharedArray}) = IndexLinear()

function local_array_by_id(refid)
if isa(refid, Future)
refid = remoteref_id(refid)
end
fetch(channel_from_id(refid))
end

function reshape(a::SharedArray{T}, dims::NTuple{N,Int}) where {T,N}
if length(a) != prod(dims)
throw(DimensionMismatch("dimensions must be consistent with array size"))
end
refs = Vector{Future}(undef, length(a.pids))
for (i, p) in enumerate(a.pids)
refs[i] = remotecall(p, a.refs[i], dims) do r,d
reshape(fetch(r),d)
refs[i] = remotecall(p, a.refs[i], dims) do r, d
reshape(local_array_by_id(r), d)
end
end

Expand Down Expand Up @@ -382,7 +389,7 @@ function shared_pids(pids)
# only use workers on the current host
pids = procs(myid())
if length(pids) > 1
pids = filter(x -> x != 1, pids)
pids = filter(!=(1), pids)
end

onlocalhost = true
Expand Down Expand Up @@ -419,13 +426,7 @@ sub_1dim(S::SharedArray, pidx) = view(S.s, range_1dim(S, pidx))
function init_loc_flds(S::SharedArray{T,N}, empty_local=false) where T where N
if myid() in S.pids
S.pidx = findfirst(isequal(myid()), S.pids)
if isa(S.refs[1], Future)
refid = remoteref_id(S.refs[S.pidx])
else
refid = S.refs[S.pidx]
end
c = channel_from_id(refid)
S.s = fetch(c)
S.s = local_array_by_id(S.refs[S.pidx])
S.loc_subarr_1d = sub_1dim(S, S.pidx)
else
S.pidx = 0
Expand Down
6 changes: 6 additions & 0 deletions stdlib/SharedArrays/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ d = SharedArrays.shmem_fill(1.0, (10,10,10))
@test fill(1., 100, 10) == reshape(d,(100,10))
d = SharedArrays.shmem_fill(1.0, (10,10,10))
@test_throws DimensionMismatch reshape(d,(50,))
# issue #40249, reshaping on another process
let m = SharedArray{ComplexF64}(10, 20, 30)
m2 = remotecall_fetch(() -> reshape(m, (100, :)), id_other)
@test size(m2) == (100, 60)
@test m2 isa SharedArray
end

# rand, randn
d = SharedArrays.shmem_rand(dims)
Expand Down

0 comments on commit 79e198b

Please sign in to comment.