Skip to content

Commit 8728a90

Browse files
committed
fix #40249, reshaping SharedArray on another process
1 parent 637f52b commit 8728a90

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

stdlib/SharedArrays/src/SharedArrays.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -295,14 +295,21 @@ size(S::SharedArray) = S.dims
295295
elsize(::Type{SharedArray{T,N}}) where {T,N} = elsize(Array{T,N}) # aka fieldtype(T, :s)
296296
IndexStyle(::Type{<:SharedArray}) = IndexLinear()
297297

298+
function local_array_by_id(refid)
299+
if isa(refid, Future)
300+
refid = remoteref_id(refid)
301+
end
302+
fetch(channel_from_id(refid))
303+
end
304+
298305
function reshape(a::SharedArray{T}, dims::NTuple{N,Int}) where {T,N}
299306
if length(a) != prod(dims)
300307
throw(DimensionMismatch("dimensions must be consistent with array size"))
301308
end
302309
refs = Vector{Future}(undef, length(a.pids))
303310
for (i, p) in enumerate(a.pids)
304-
refs[i] = remotecall(p, a.refs[i], dims) do r,d
305-
reshape(fetch(r),d)
311+
refs[i] = remotecall(p, a.refs[i], dims) do r, d
312+
reshape(local_array_by_id(r), d)
306313
end
307314
end
308315

@@ -382,7 +389,7 @@ function shared_pids(pids)
382389
# only use workers on the current host
383390
pids = procs(myid())
384391
if length(pids) > 1
385-
pids = filter(x -> x != 1, pids)
392+
pids = filter(!=(1), pids)
386393
end
387394

388395
onlocalhost = true
@@ -419,13 +426,7 @@ sub_1dim(S::SharedArray, pidx) = view(S.s, range_1dim(S, pidx))
419426
function init_loc_flds(S::SharedArray{T,N}, empty_local=false) where T where N
420427
if myid() in S.pids
421428
S.pidx = findfirst(isequal(myid()), S.pids)
422-
if isa(S.refs[1], Future)
423-
refid = remoteref_id(S.refs[S.pidx])
424-
else
425-
refid = S.refs[S.pidx]
426-
end
427-
c = channel_from_id(refid)
428-
S.s = fetch(c)
429+
S.s = local_array_by_id(S.refs[S.pidx])
429430
S.loc_subarr_1d = sub_1dim(S, S.pidx)
430431
else
431432
S.pidx = 0

stdlib/SharedArrays/test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,12 @@ d = SharedArrays.shmem_fill(1.0, (10,10,10))
176176
@test fill(1., 100, 10) == reshape(d,(100,10))
177177
d = SharedArrays.shmem_fill(1.0, (10,10,10))
178178
@test_throws DimensionMismatch reshape(d,(50,))
179+
# issue #40249, reshaping on another process
180+
let m = SharedArray{ComplexF64}(10, 20, 30)
181+
m2 = remotecall_fetch(() -> reshape(m, (100, :)), id_other)
182+
@test size(m2) == (100, 60)
183+
@test m2 isa SharedArray
184+
end
179185

180186
# rand, randn
181187
d = SharedArrays.shmem_rand(dims)

0 commit comments

Comments
 (0)