Skip to content

Commit

Permalink
Workaround for promote_type for GPU arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed Jul 19, 2024
1 parent ea4cb6c commit 88f0cbc
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
31 changes: 31 additions & 0 deletions src/LinearOperatorCollection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,35 @@ include("SamplingOp.jl")
include("NormalOp.jl")
include("DiagOp.jl")

function promote_storage_types(A, B)
A_type = storage_type(A)
B_type = storage_type(B)
S = promote_type(A_type, B_type)
if !isconcretetype(S)
# Find common eltype
elType = promote_type(eltype(A), eltype(B))
if !isconcretetype(elType)
throw(LinearOperatorException("Storage types cannot be promoted to a concrete type"))
end

# Same base type
A_base = Base.typename(A_type).wrapper
B_base = Base.typename(B_type).wrapper
if A_base != B_base
throw(LinearOperatorException("Storage types cannot be promoted to a common base type"))
end

# LinearOperators only accepts DataTypes, so we cant just do A_base{elType}, since that might be a UnionAll
# Check if either A_type or B_type have the fitting eltype
if eltype(A_type) == elType
S = A_type
elseif eltype(B_type) == elType
S = B_type
else
throw(LinearOperatorException("Storage types cannot be promoted to a common eltype"))
end
end
return S
end

end
3 changes: 1 addition & 2 deletions src/NormalOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ end
LinearOperators.storage_type(op::NormalOpImpl) = typeof(op.Mv5)

function NormalOpImpl(parent, weights)
S = promote_type(storage_type(parent), storage_type(weights))
isconcretetype(S) || throw(LinearOperatorException("Storage types cannot be promoted to a concrete type"))
S = promote_storage_types(parent, weights)
tmp = S(undef, size(parent, 1))
return NormalOpImpl(parent, weights, tmp)
end
Expand Down
3 changes: 1 addition & 2 deletions src/ProdOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ composition/product of two Operators. Differs with * since it can handle normal
function ProdOp(A, B)
nrow = size(A, 1)
ncol = size(B, 2)
S = promote_type(LinearOperators.storage_type(A), LinearOperators.storage_type(B))
isconcretetype(S) || throw(LinearOperatorException("Storage types cannot be promoted to a concrete type"))
S = promote_storage_types(A, B)
tmp_ = S(undef, size(B, 1))

function produ!(res, x::AbstractVector{T}, tmp) where T<:Union{Real,Complex}
Expand Down

0 comments on commit 88f0cbc

Please sign in to comment.