Skip to content

Commit 7ad6877

Browse files
dlfivefiftymbauman
authored andcommitted
Remove unnecessary restriction to StridedVecOrMat (#35929)
* Remove unnecessary restriction to `StridedVecOrMat` The "Strided array interface" https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays-1 means that this is useful beyond these types * Update adjtrans.jl * Add tests for adj/trans strides * Add tests, change strides(::Adjoint{<:Any,<:AbstractVector}) definition * stride(::AbstractrArray, k) for all k, add ConjPtr * Remove ConjPtr * Always throw an error if strides is not implemented * Update abstractarray.jl * Update blas.jl * Remove k < 1 special case * Also widen elsize to AbstractVecOrMat * Use strides for dim > ndims * Update stdlib/LinearAlgebra/test/blas.jl Co-authored-by: Matt Bauman <mbauman@gmail.com> (cherry picked from commit 6b2c7f1)
1 parent 24f033c commit 7ad6877

File tree

4 files changed

+65
-17
lines changed

4 files changed

+65
-17
lines changed

base/abstractarray.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,11 @@ julia> stride(A,3)
392392
12
393393
```
394394
"""
395-
stride(A::AbstractArray, k::Integer) = strides(A)[k]
395+
function stride(A::AbstractArray, k::Integer)
396+
st = strides(A)
397+
k ndims(A) && return st[k]
398+
return sum(st .* size(A))
399+
end
396400

397401
@inline size_to_strides(s, d, sz...) = (s, size_to_strides(s * d, sz...)...)
398402
size_to_strides(s, d) = (s,)

stdlib/LinearAlgebra/src/adjtrans.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,17 +199,17 @@ convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S,
199199
convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(convert(S, A.parent))
200200

201201
# Strides and pointer for transposed strided arrays — but only if the elements are actually stored in memory
202-
Base.strides(A::Adjoint{<:Real, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1))
203-
Base.strides(A::Transpose{<:Any, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1))
202+
Base.strides(A::Adjoint{<:Real, <:AbstractVector}) = (stride(A.parent, 2), stride(A.parent, 1))
203+
Base.strides(A::Transpose{<:Any, <:AbstractVector}) = (stride(A.parent, 2), stride(A.parent, 1))
204204
# For matrices it's slightly faster to use reverse and avoid calling stride twice
205-
Base.strides(A::Adjoint{<:Real, <:StridedMatrix}) = reverse(strides(A.parent))
206-
Base.strides(A::Transpose{<:Any, <:StridedMatrix}) = reverse(strides(A.parent))
205+
Base.strides(A::Adjoint{<:Real, <:AbstractMatrix}) = reverse(strides(A.parent))
206+
Base.strides(A::Transpose{<:Any, <:AbstractMatrix}) = reverse(strides(A.parent))
207207

208-
Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
209-
Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
208+
Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:AbstractVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
209+
Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:AbstractVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
210210

211-
Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:StridedVecOrMat} = Base.elsize(P)
212-
Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:StridedVecOrMat} = Base.elsize(P)
211+
Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:AbstractVecOrMat} = Base.elsize(P)
212+
Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:AbstractVecOrMat} = Base.elsize(P)
213213

214214
# for vectors, the semantics of the wrapped and unwrapped types differ
215215
# so attempt to maintain both the parent and wrapper type insofar as possible

stdlib/LinearAlgebra/test/blas.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,44 @@ Base.setindex!(A::WrappedArray, v, i::Int) = setindex!(A.A, v, i)
458458
Base.setindex!(A::WrappedArray{T, N}, v, I::Vararg{Int, N}) where {T, N} = setindex!(A.A, v, I...)
459459
Base.unsafe_convert(::Type{Ptr{T}}, A::WrappedArray{T}) where T = Base.unsafe_convert(Ptr{T}, A.A)
460460

461-
Base.stride(A::WrappedArray, i::Int) = stride(A.A, i)
461+
Base.strides(A::WrappedArray) = strides(A.A)
462+
463+
@testset "strided interface adjtrans" begin
464+
x = WrappedArray([1, 2, 3, 4])
465+
@test stride(x,1) == 1
466+
@test stride(x,2) == stride(x,3) == 4
467+
@test strides(x') == strides(transpose(x)) == (4,1)
468+
@test pointer(x') == pointer(transpose(x)) == pointer(x)
469+
@test_throws BoundsError stride(x,0)
470+
471+
A = WrappedArray([1 2; 3 4; 5 6])
472+
@test stride(A,1) == 1
473+
@test stride(A,2) == 3
474+
@test stride(A,3) == stride(A,4) >= 6
475+
@test strides(A') == strides(transpose(A)) == (3,1)
476+
@test pointer(A') == pointer(transpose(A)) == pointer(A)
477+
@test_throws BoundsError stride(A,0)
478+
479+
y = WrappedArray([1+im, 2, 3, 4])
480+
@test strides(transpose(y)) == (4,1)
481+
@test pointer(transpose(y)) == pointer(y)
482+
@test_throws MethodError strides(y')
483+
@test_throws ErrorException pointer(y')
484+
485+
B = WrappedArray([1+im 2; 3 4; 5 6])
486+
@test strides(transpose(B)) == (3,1)
487+
@test pointer(transpose(B)) == pointer(B)
488+
@test_throws MethodError strides(B')
489+
@test_throws ErrorException pointer(B')
490+
491+
@test_throws MethodError stride(1:5,0)
492+
@test_throws MethodError stride(1:5,1)
493+
@test_throws MethodError stride(1:5,2)
494+
@test_throws MethodError strides(transpose(1:5))
495+
@test_throws MethodError strides((1:5)')
496+
@test_throws ErrorException pointer(transpose(1:5))
497+
@test_throws ErrorException pointer((1:5)')
498+
end
462499

463500
@testset "strided interface blas" begin
464501
for elty in (Float32, Float64, ComplexF32, ComplexF64)

test/abstractarray.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,36 +1082,43 @@ end
10821082
Ap = Base.PermutedDimsArray(A, perm)
10831083
At = transpose(A)
10841084
Aa = adjoint(A)
1085+
St = transpose(A)
1086+
Sa = adjoint(A)
10851087
Sp = Base.PermutedDimsArray(S, perm)
10861088
Ps = Strider{Int, 2}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
10871089
@test pointer(Ap) == pointer(Sp) == pointer(Ps) == pointer(At) == pointer(Aa)
10881090
for i in 1:length(Ap)
10891091
# This is intentionally disabled due to ambiguity
1090-
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
1091-
@test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
1092-
@test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i]
1092+
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) == pointer(St, i) == pointer(Sa, i)
1093+
@test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) == pointer(St, i) == pointer(Sa, i)
1094+
@test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i] == St[i] == Sa[i]
10931095
end
10941096
Pv = view(P, idxs[collect(perm)]...)
10951097
Apv = view(Ap, idxs[collect(perm)]...)
10961098
Atv = view(At, idxs[collect(perm)]...)
10971099
Ata = view(Aa, idxs[collect(perm)]...)
1100+
Stv = view(St, idxs[collect(perm)]...)
1101+
Sta = view(Sa, idxs[collect(perm)]...)
10981102
Spv = view(Sp, idxs[collect(perm)]...)
10991103
Pvs = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv))
11001104
@test pointer(Apv) == pointer(Spv) == pointer(Pvs) == pointer(Atv) == pointer(Ata)
11011105
for i in 1:length(Apv)
1102-
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i)
1103-
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i]
1106+
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i) == pointer(Stv, i) == pointer(Sta, i)
1107+
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i] == Stv[i] == Sta[i]
11041108
end
11051109
Vp = permutedims(Av, perm)
11061110
Avp = Base.PermutedDimsArray(Av, perm)
11071111
Avt = transpose(Av)
11081112
Ava = adjoint(Av)
1113+
Svt = transpose(Sv)
1114+
Sva = adjoint(Sv)
11091115
Svp = Base.PermutedDimsArray(Sv, perm)
11101116
@test pointer(Avp) == pointer(Svp) == pointer(Avt) == pointer(Ava)
11111117
for i in 1:length(Avp)
11121118
# This is intentionally disabled due to ambiguity
1113-
@test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i)
1114-
@test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i]
1119+
@test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i) == pointer(Svt, i) == pointer(Sva, i)
1120+
@test pointer(Avt, i) == pointer(Ava, i) == pointer(Svt, i) == pointer(Sva, i)
1121+
@test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i] == Svt[i] == Sva[i]
11151122
end
11161123
end
11171124
end

0 commit comments

Comments
 (0)