Skip to content

Commit

Permalink
feat: fix generalized indexing (Nemocas#1585)
Browse files Browse the repository at this point in the history
- make A[i::Int, js::Array] return a Vector and support views
  • Loading branch information
thofma authored and ooinaruhugh committed Feb 15, 2024
1 parent 45ee11c commit c88cc49
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 9 deletions.
28 changes: 24 additions & 4 deletions src/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ end

_checkbounds(i::Int, j::Int) = 1 <= j <= i

function _checkbounds(A, i::Int, j::Int)
_checkbounds(i::Int, j::AbstractVector{Int}) = all(jj -> 1 <= jj <= i, j)

function _checkbounds(A, i::Union{Int, AbstractVector{Int}}, j::Union{Int, AbstractVector{Int}})
(_checkbounds(nrows(A), i) && _checkbounds(ncols(A), j)) ||
Base.throw_boundserror(A, (i, j))
end
Expand Down Expand Up @@ -386,18 +388,36 @@ function getindex(M::MatElem, rows::AbstractVector{Int}, cols::AbstractVector{In
return A
end

function getindex(M::MatElem, i::Int, cols::AbstractVector{Int})
_checkbounds(M, i, cols)
A = Vector{elem_type(base_ring(M))}(undef, length(cols))
for j in eachindex(cols)
A[j] = deepcopy(M[i, cols[j]])
end
return A
end

function getindex(M::MatElem, rows::AbstractVector{Int}, j::Int)
_checkbounds(M, rows, j)
A = Vector{elem_type(base_ring(M))}(undef, length(rows))
for i in eachindex(rows)
A[i] = deepcopy(M[rows[i], j])
end
return A
end

getindex(M::MatElem,
rows::Union{Int,Colon,AbstractVector{Int}},
cols::Union{Int,Colon,AbstractVector{Int}}) = M[_to_indices(M, rows, cols)...]

function _to_indices(x, rows, cols)
if rows isa Integer
rows = rows:rows
rows = rows
elseif rows isa Colon
rows = 1:nrows(x)
end
if cols isa Integer
cols = cols:cols
cols = cols
elseif cols isa Colon
cols = 1:ncols(x)
end
Expand Down Expand Up @@ -2519,7 +2539,7 @@ function trace_of_prod(M::MatElem, N::MatElem)
is_square(M) && is_square(N) || error("Not a square matrix in trace")
d = zero(base_ring(M))
for i = 1:nrows(M)
d += (M[i, :] * N[:, i])[1, 1]
d += (M[i:i, :] * N[:, i:i])[1, 1]
end
return d
end
Expand Down
5 changes: 5 additions & 0 deletions src/generic/GenericTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,11 @@ struct MatSpaceView{T <: NCRingElement, V, W} <: Mat{T}
base_ring::NCRing
end

struct MatSpaceVecView{T <: NCRingElement, V, W} <: AbstractVector{T}
entries::SubArray{T, 1, Matrix{T}, V, W}
base_ring::NCRing
end

###############################################################################
#
# MatRing / MatRingElem
Expand Down
19 changes: 18 additions & 1 deletion src/generic/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,18 @@ function deepcopy_internal(d::MatSpaceView{T}, dict::IdDict) where T <: NCRingEl
return MatSpaceView(deepcopy_internal(d.entries, dict), d.base_ring)
end

function Base.view(M::Mat{T}, rows::AbstractUnitRange{Int}, cols::AbstractUnitRange{Int}) where T <: NCRingElement
function Base.view(M::Mat{T}, rows::Union{Colon, AbstractVector{Int}}, cols::Union{Colon, AbstractVector{Int}}) where T <: NCRingElement
return MatSpaceView(view(M.entries, rows, cols), M.base_ring)
end

function Base.view(M::Mat{T}, rows::Int, cols::Union{Colon, AbstractVector{Int}}) where T <: NCRingElement
return MatSpaceVecView(view(M.entries, rows, cols), M.base_ring)
end

function Base.view(M::Mat{T}, rows::Union{Colon, AbstractVector{Int}}, cols::Int) where T <: NCRingElement
return MatSpaceVecView(view(M.entries, rows, cols), M.base_ring)
end

################################################################################
#
# Size, axes and is_square
Expand Down Expand Up @@ -228,3 +236,12 @@ function AbstractAlgebra.mul!(A::Mat{T}, B::Mat{T}, C::Mat{T}, f::Bool = false)
return A
end

Base.length(V::MatSpaceVecView) = length(V.entries)

Base.getindex(V::MatSpaceVecView, i::Int) = V.entries[i]

Base.setindex!(V::MatSpaceVecView{T}, z::T, i::Int) where {T} = (V.entries[i] = z)

Base.setindex!(V::MatSpaceVecView, z::RingElement, i::Int) = setindex!(V.entries, V.base_ring(z), i)

Base.size(V::MatSpaceVecView) = (length(V.entries), )
22 changes: 18 additions & 4 deletions test/generic/Matrix-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1420,8 +1420,8 @@ end
Q = inv(P)

PA = P*A
@test PA == reduce(vcat, [A[Q[i], :] for i in 1:nrows(A)])
@test PA == reduce(vcat, A[Q[i], :] for i in 1:nrows(A))
@test PA == reduce(vcat, [A[Q[i]:Q[i], :] for i in 1:nrows(A)])
@test PA == reduce(vcat, A[Q[i]:Q[i], :] for i in 1:nrows(A))
@test PA == S(reduce(vcat, A.entries[Q[i], :] for i in 1:nrows(A)))
@test A == Q*(P*A)
end
Expand Down Expand Up @@ -4022,20 +4022,34 @@ end
@test fflu(N3) == fflu(M) # tests that deepcopy is correct
@test M2 == M

for i in [ 1, 1:2, : ], j in [ 1, 1:2, : ]
for i in [ 1:1, 1:2, : ], j in [ 1:1, 1:2, : ]
v = @view M[i,j]
@test v isa Generic.MatSpaceView
@test M[i,j] == v
end

M2 = deepcopy(M)
M3 = @view M2[2, 1:2]
@test length(M3) == 2
@test M3 == [2, 3]
M3[2] = 5
@test M2 == ZZ[1 2 3; 2 5 4; 3 4 5]

M2 = deepcopy(M)
M3 = @view M2[1:3, 3]
@test length(M3) == 3
@test M3 == [3, 4, 5]
M3[1] = 10
@test M2 == ZZ[1 2 10; 2 3 4; 3 4 5]

# Test views over noncommutative ring
R = matrix_ring(ZZ, 2)

S = matrix_space(R, 4, 4)

M = rand(S, -10:10)

for i in [ 1, 1:2, : ], j in [ 1, 1:2, : ]
for i in [ 1:1, 1:2, : ], j in [ 1:1, 1:2, : ]
v = @view M[i,j]
@test v isa Generic.MatSpaceView
@test M[i,j] == v
Expand Down

0 comments on commit c88cc49

Please sign in to comment.