Skip to content

Commit 171edd0

Browse files
fix: fix memory aliasing with array discrete parameters
1 parent e83ff9b commit 171edd0

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/systems/parameter_buffer.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -830,19 +830,27 @@ function SciMLBase.create_parameter_timeseries_collection(
830830
isempty(ps.discrete) && return nothing
831831
num_discretes = only(blocksize(ps.discrete[1]))
832832
buffers = []
833-
partition_type = Tuple{(typeof(parent(buf)) for buf in ps.discrete)...}
833+
partition_type = typeof(SciMLBase.get_saveable_values(sys, ps, 1))
834834
for i in 1:num_discretes
835835
ts = eltype(tspan)[]
836-
us = NestedGetIndex{partition_type}[]
836+
us = partition_type[]
837837
push!(buffers, DiffEqArray(us, ts, (1, 1)))
838838
end
839839

840840
return ParameterTimeseriesCollection(Tuple(buffers), copy(ps))
841841
end
842842

843+
@inline __get_blocks(tsidx::Int) = ()
844+
@inline function __get_blocks(tsidx::Int, buffer::BlockedArray, buffers...)
845+
(buffer[Block(tsidx)], __get_blocks(tsidx, buffers...)...)
846+
end
847+
@inline function __get_blocks(tsidx::Int, buffer::BlockedArray{<:AbstractArray}, buffers...)
848+
(copy.(buffer[Block(tsidx)]), __get_blocks(tsidx, buffers...)...)
849+
end
850+
843851
function SciMLBase.get_saveable_values(
844852
sys::AbstractSystem, ps::MTKParameters, timeseries_idx)
845-
return NestedGetIndex(Tuple(buffer[Block(timeseries_idx)] for buffer in ps.discrete))
853+
return NestedGetIndex(__get_blocks(timeseries_idx, ps.discrete...))
846854
end
847855

848856
function save_callback_discretes!(integ::SciMLBase.DEIntegrator, callback)

0 commit comments

Comments
 (0)