Skip to content

Commit aad7245

Browse files
authored
make one(::AbstractMatrix) use similar instead of zeros (#54162)
1 parent 2aa55e2 commit aad7245

File tree

3 files changed

+34
-15
lines changed

3 files changed

+34
-15
lines changed

base/abstractarray.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,23 @@ zero(x::AbstractArray{T}) where {T<:Number} = fill!(similar(x, typeof(zero(T))),
12211221
zero(x::AbstractArray{S}) where {S<:Union{Missing, Number}} = fill!(similar(x, typeof(zero(S))), zero(S))
12221222
zero(x::AbstractArray) = map(zero, x)
12231223

1224+
function _one(unit::T, mat::AbstractMatrix) where {T}
1225+
(rows, cols) = axes(mat)
1226+
(length(rows) == length(cols)) ||
1227+
throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
1228+
zer = zero(unit)::T
1229+
require_one_based_indexing(mat)
1230+
I = similar(mat, T)
1231+
fill!(I, zer)
1232+
for i rows
1233+
I[i, i] = unit
1234+
end
1235+
I
1236+
end
1237+
1238+
one(x::AbstractMatrix{T}) where {T} = _one(one(T), x)
1239+
oneunit(x::AbstractMatrix{T}) where {T} = _one(oneunit(T), x)
1240+
12241241
## iteration support for arrays by iterating over `eachindex` in the array ##
12251242
# Allows fast iteration by default for both IndexLinear and IndexCartesian arrays
12261243

base/array.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -598,21 +598,6 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
598598
end
599599
end
600600

601-
function _one(unit::T, x::AbstractMatrix) where T
602-
require_one_based_indexing(x)
603-
m,n = size(x)
604-
m==n || throw(DimensionMismatch("multiplicative identity defined only for square matrices"))
605-
# Matrix{T}(I, m, m)
606-
I = zeros(T, m, m)
607-
for i in 1:m
608-
I[i,i] = unit
609-
end
610-
I
611-
end
612-
613-
one(x::AbstractMatrix{T}) where {T} = _one(one(T), x)
614-
oneunit(x::AbstractMatrix{T}) where {T} = _one(oneunit(T), x)
615-
616601
## Conversions ##
617602

618603
convert(::Type{T}, a::AbstractArray) where {T<:Array} = a isa T ? a : T(a)::T

test/abstractarray.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2119,3 +2119,20 @@ end
21192119
end
21202120
end
21212121
end
2122+
2123+
@testset "one" begin
2124+
@test one([1 2; 3 4]) == [1 0; 0 1]
2125+
@test one([1 2; 3 4]) isa Matrix{Int}
2126+
2127+
struct Mat <: AbstractMatrix{Int}
2128+
p::Matrix{Int}
2129+
end
2130+
Base.size(m::Mat) = size(m.p)
2131+
Base.IndexStyle(::Type{<:Mat}) = IndexLinear()
2132+
Base.getindex(m::Mat, i::Int) = m.p[i]
2133+
Base.setindex!(m::Mat, v, i::Int) = m.p[i] = v
2134+
Base.similar(::Mat, ::Type{Int}, size::NTuple{2,Int}) = Mat(Matrix{Int}(undef, size))
2135+
2136+
@test one(Mat([1 2; 3 4])) == Mat([1 0; 0 1])
2137+
@test one(Mat([1 2; 3 4])) isa Mat
2138+
end

0 commit comments

Comments
 (0)