Skip to content

Commit

Permalink
Use type hierachy for matrix iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed Aug 13, 2024
1 parent 99299af commit a7c4cf1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/MultiThreading.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export SequentialState, MultiThreadingState
export SequentialState, MultiThreadingState, prepareMultiStates
abstract type AbstractMatrixSolverState{S} <: AbstractSolverState{S} end
mutable struct SequentialState{S, ST <: AbstractSolverState{S}} <: AbstractMatrixSolverState{S}
states::Vector{ST}
Expand Down Expand Up @@ -30,14 +30,14 @@ function prepareMultiStates(solver::AbstractLinearSolver, state::AbstractSolverS
end
prepareMultiStates(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix) = prepareMultiStates(solver, first(state.states), b)

function init!(solver::AbstractLinearSolver, state::Union{SequentialState, MultiThreadingState}, b::AbstractMatrix; kwargs...)
function init!(solver::AbstractLinearSolver, state::AbstractMatrixSolverState, b::AbstractMatrix; kwargs...)
for (i, s) in enumerate(state.states)
init!(solver, s, b[:, i]; kwargs...)
end
state.active .= true
end

function iterate(solver::S, state::Union{SequentialState, MultiThreadingState}) where {S <: AbstractLinearSolver}
function iterate(solver::S, state::AbstractMatrixSolverState) where {S <: AbstractLinearSolver}
activeIdx = findall(state.active)
if isempty(activeIdx)
return nothing
Expand Down
2 changes: 1 addition & 1 deletion test/testMultiThreading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function testMultiThreadingSolver(; arrayType = Array, scheduler = MultiDataStat
x = rand(ComplexF32, 2, 4)
b = A * x

solvers = [CGNR] # linearSolverList()
solvers = linearSolverList()
@testset "$(solvers[i])" for i = 1:length(solvers)
S = createLinearSolver(solvers[i], arrayType(A), iterations = 100)

Expand Down

0 comments on commit a7c4cf1

Please sign in to comment.