diff --git a/stdlib/SharedArrays/src/SharedArrays.jl b/stdlib/SharedArrays/src/SharedArrays.jl index 347d22180f7b5..85f1eb4fff150 100644 --- a/stdlib/SharedArrays/src/SharedArrays.jl +++ b/stdlib/SharedArrays/src/SharedArrays.jl @@ -295,14 +295,21 @@ size(S::SharedArray) = S.dims elsize(::Type{SharedArray{T,N}}) where {T,N} = elsize(Array{T,N}) # aka fieldtype(T, :s) IndexStyle(::Type{<:SharedArray}) = IndexLinear() +function local_array_by_id(refid) + if isa(refid, Future) + refid = remoteref_id(refid) + end + fetch(channel_from_id(refid)) +end + function reshape(a::SharedArray{T}, dims::NTuple{N,Int}) where {T,N} if length(a) != prod(dims) throw(DimensionMismatch("dimensions must be consistent with array size")) end refs = Vector{Future}(undef, length(a.pids)) for (i, p) in enumerate(a.pids) - refs[i] = remotecall(p, a.refs[i], dims) do r,d - reshape(fetch(r),d) + refs[i] = remotecall(p, a.refs[i], dims) do r, d + reshape(local_array_by_id(r), d) end end @@ -382,7 +389,7 @@ function shared_pids(pids) # only use workers on the current host pids = procs(myid()) if length(pids) > 1 - pids = filter(x -> x != 1, pids) + pids = filter(!=(1), pids) end onlocalhost = true @@ -419,13 +426,7 @@ sub_1dim(S::SharedArray, pidx) = view(S.s, range_1dim(S, pidx)) function init_loc_flds(S::SharedArray{T,N}, empty_local=false) where T where N if myid() in S.pids S.pidx = findfirst(isequal(myid()), S.pids) - if isa(S.refs[1], Future) - refid = remoteref_id(S.refs[S.pidx]) - else - refid = S.refs[S.pidx] - end - c = channel_from_id(refid) - S.s = fetch(c) + S.s = local_array_by_id(S.refs[S.pidx]) S.loc_subarr_1d = sub_1dim(S, S.pidx) else S.pidx = 0 diff --git a/stdlib/SharedArrays/test/runtests.jl b/stdlib/SharedArrays/test/runtests.jl index 7a4d46d4777b3..7f1bbb6891ce0 100644 --- a/stdlib/SharedArrays/test/runtests.jl +++ b/stdlib/SharedArrays/test/runtests.jl @@ -176,6 +176,12 @@ d = SharedArrays.shmem_fill(1.0, (10,10,10)) @test fill(1., 100, 10) == reshape(d,(100,10)) d = SharedArrays.shmem_fill(1.0, (10,10,10)) @test_throws DimensionMismatch reshape(d,(50,)) +# issue #40249, reshaping on another process +let m = SharedArray{ComplexF64}(10, 20, 30) + m2 = remotecall_fetch(() -> reshape(m, (100, :)), id_other) + @test size(m2) == (100, 60) + @test m2 isa SharedArray +end # rand, randn d = SharedArrays.shmem_rand(dims)