Skip to content

Commit 4093dbf

Browse files
mfaltKristofferC
authored andcommitted
Uniform scaling cat with zero dimensions (#29457)
* Uniform scaling cat with zero dimensions * Missing type assertion in test
1 parent 3b7be23 commit 4093dbf

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

stdlib/LinearAlgebra/src/uniformscaling.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -246,18 +246,18 @@ promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling}}})
246246
for (f,dim,name) in ((:hcat,1,"rows"), (:vcat,2,"cols"))
247247
@eval begin
248248
function $f(A::Union{AbstractVecOrMat,UniformScaling}...)
249-
n = 0
249+
n = -1
250250
for a in A
251251
if !isa(a, UniformScaling)
252252
@assert !has_offset_axes(a)
253253
na = size(a,$dim)
254-
n > 0 && n != na &&
254+
n >= 0 && n != na &&
255255
throw(DimensionMismatch(string("number of ", $name,
256256
" of each array must match (got ", n, " and ", na, ")")))
257257
n = na
258258
end
259259
end
260-
n == 0 && throw(ArgumentError($("$f of only UniformScaling objects cannot determine the matrix size")))
260+
n == -1 && throw(ArgumentError($("$f of only UniformScaling objects cannot determine the matrix size")))
261261
return $f(promote_to_arrays(fill(n,length(A)),1, promote_to_array_type(A), A...)...)
262262
end
263263
end
@@ -268,20 +268,20 @@ function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScalin
268268
@assert !has_offset_axes(A...)
269269
nr = length(rows)
270270
sum(rows) == length(A) || throw(ArgumentError("mismatch between row sizes and number of arguments"))
271-
n = zeros(Int, length(A))
271+
n = fill(-1, length(A))
272272
needcols = false # whether we also need to infer some sizes from the column count
273273
j = 0
274274
for i = 1:nr # infer UniformScaling sizes from row counts, if possible:
275-
ni = 0 # number of rows in this block-row
275+
ni = -1 # number of rows in this block-row, -1 indicates unknown
276276
for k = 1:rows[i]
277277
if !isa(A[j+k], UniformScaling)
278278
na = size(A[j+k], 1)
279-
ni > 0 && ni != na &&
279+
ni >= 0 && ni != na &&
280280
throw(DimensionMismatch("mismatch in number of rows"))
281281
ni = na
282282
end
283283
end
284-
if ni > 0
284+
if ni >= 0
285285
for k = 1:rows[i]
286286
n[j+k] = ni
287287
end
@@ -291,21 +291,22 @@ function hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScalin
291291
j += rows[i]
292292
end
293293
if needcols # some sizes still unknown, try to infer from column count
294-
nc = j = 0
294+
nc = -1
295+
j = 0
295296
for i = 1:nr
296297
nci = 0
297-
rows[i] > 0 && n[j+1] == 0 && continue # column count unknown in this row
298+
rows[i] > 0 && n[j+1] == -1 && (j += rows[i]; continue)
298299
for k = 1:rows[i]
299300
nci += isa(A[j+k], UniformScaling) ? n[j+k] : size(A[j+k], 2)
300301
end
301-
nc > 0 && nc != nci && throw(DimensionMismatch("mismatch in number of columns"))
302+
nc >= 0 && nc != nci && throw(DimensionMismatch("mismatch in number of columns"))
302303
nc = nci
303304
j += rows[i]
304305
end
305-
nc == 0 && throw(ArgumentError("sizes of UniformScalings could not be inferred"))
306+
nc == -1 && throw(ArgumentError("sizes of UniformScalings could not be inferred"))
306307
j = 0
307308
for i = 1:nr
308-
if rows[i] > 0 && n[j+1] == 0 # this row consists entirely of UniformScalings
309+
if rows[i] > 0 && n[j+1] == -1 # this row consists entirely of UniformScalings
309310
nci = nc ÷ rows[i]
310311
nci * rows[i] != nc && throw(DimensionMismatch("indivisible UniformScaling sizes"))
311312
for k = 1:rows[i]

stdlib/LinearAlgebra/test/uniformscaling.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,26 @@ end
195195
for T in (Matrix, SparseMatrixCSC)
196196
A = T(rand(3,4))
197197
B = T(rand(3,3))
198+
C = T(rand(0,3))
199+
D = T(rand(2,0))
198200
@test (hcat(A, 2I))::T == hcat(A, Matrix(2I, 3, 3))
199201
@test (vcat(A, 2I))::T == vcat(A, Matrix(2I, 4, 4))
202+
@test (hcat(C, 2I))::T == C
203+
@test (vcat(D, 2I))::T == D
200204
@test (hcat(I, 3I, A, 2I))::T == hcat(Matrix(I, 3, 3), Matrix(3I, 3, 3), A, Matrix(2I, 3, 3))
201205
@test (vcat(I, 3I, A, 2I))::T == vcat(Matrix(I, 4, 4), Matrix(3I, 4, 4), A, Matrix(2I, 4, 4))
202206
@test (hvcat((2,1,2), B, 2I, I, 3I, 4I))::T ==
203207
hvcat((2,1,2), B, Matrix(2I, 3, 3), Matrix(I, 6, 6), Matrix(3I, 3, 3), Matrix(4I, 3, 3))
208+
@test hvcat((3,1), C, C, I, 3I)::T == hvcat((2,1), C, C, Matrix(3I, 6,6))
209+
@test hvcat((2,2,2), I, 2I, 3I, 4I, C, C)::T ==
210+
hvcat((2,2,2), Matrix(I, 3, 3), Matrix(2I, 3,3 ), Matrix(3I, 3,3), Matrix(4I, 3,3), C, C)
211+
@test hvcat((2,2,4), C, C, I, 2I, 3I, 4I, 5I, D)::T ==
212+
hvcat((2,2,4), C, C, Matrix(I, 3, 3), Matrix(2I,3,3),
213+
Matrix(3I, 2, 2), Matrix(4I, 2, 2), Matrix(5I,2,2), D)
214+
@test (hvcat((2,3,2), B, 2I, C, C, I, 3I, 4I))::T ==
215+
hvcat((2,2,2), B, Matrix(2I, 3, 3), C, C, Matrix(3I, 3, 3), Matrix(4I, 3, 3))
216+
@test hvcat((3,2,1), C, C, I, B ,3I, 2I)::T ==
217+
hvcat((2,2,1), C, C, B, Matrix(3I,3,3), Matrix(2I,6,6))
204218
end
205219
end
206220

0 commit comments

Comments
 (0)