Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interpolated stats #203

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/Output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,25 @@ show(io::IO, o::MemoryOutput) = print(io, "MemoryOutput$(collect(keys(o.data)))"
"""
function (o::MemoryOutput)(y, t, dt, yfun)
save, ts = o.save_cond(y, t, dt, o.saved)
append_stats!(o, o.statsfun(y, t, dt))
!haskey(o.data, o.yname) && initialise(o, y)
while save
ys = yfun(ts)
append_stats!(o, o.statsfun(ys, ts, dt))
s = size(o.data[o.yname])
if s[end] < o.saved+1
o.data[o.yname] = fastcat(o.data[o.yname], yfun(ts))
o.data[o.yname] = fastcat(o.data[o.yname], ys)
push!(o.data[o.tname], ts)
else
idcs = fill(:, ndims(y))
o.data[o.yname][idcs..., o.saved+1] = yfun(ts)
o.data[o.yname][idcs..., o.saved+1] = ys
o.data[o.tname][o.saved+1] = ts
end
o.saved += 1
save, ts = o.save_cond(y, t, dt, o.saved)
end
if ts != t
append_stats!(o, o.statsfun(y, t, dt))
end
end

function append_stats!(o::MemoryOutput, d)
Expand Down Expand Up @@ -209,7 +213,7 @@ function HDF5Output(fpath::AbstractString)
HDF5Output(fpath, 0, 0, 1; readonly=true)
end

function initialise(o::HDF5Output, y)
function initialise(o::HDF5Output, y, statsnames)
ydims = size(y)
idims = init_dims(ydims, o.save_cond)
cdims = collect(idims)
Expand All @@ -228,7 +232,6 @@ function initialise(o::HDF5Output, y)
end
HDF5.create_dataset(file, o.tname, HDF5.datatype(Float64), ((dims[end],), (-1,)),
chunk=(1,))
statsnames = sort(collect(keys(o.stats_tmp[end])))
o.cachehash = hash((statsnames, size(y)))
file["meta"]["cachehash"] = o.cachehash
if o.cache
Expand Down Expand Up @@ -300,22 +303,26 @@ end
function (o::HDF5Output)(y, t, dt, yfun)
o.readonly && error("Cannot add data to read-only output!")
save, ts = o.save_cond(y, t, dt, o.saved)
push!(o.stats_tmp, o.statsfun(y, t, dt))
if save
@hlock HDF5.h5open(o.fpath, "r+") do file
!HDF5.haskey(file, o.yname) && initialise(o, y)
statsnames = sort(collect(keys(o.stats_tmp[end])))
cachehash = hash((statsnames, size(y)))
cachehash == o.cachehash || error(
"the hash for this propagation does not agree with cache in file")
statsnames = sort(collect(keys(o.statsfun(y, t, dt))))
if !HDF5.haskey(file, o.yname)
initialise(o, y, statsnames)
else
cachehash = hash((statsnames, size(y)))
cachehash == o.cachehash || error(
"the hash for this propagation does not agree with cache in file")
end
while save
s = collect(size(file[o.yname]))
idcs = fill(:, length(s)-1)
if s[end] < o.saved+1
s[end] += 1
HDF5.set_extent_dims(file[o.yname], Tuple(s))
end
file[o.yname][idcs..., o.saved+1] = yfun(ts)
ys = yfun(ts)
push!(o.stats_tmp, o.statsfun(ys, ts, dt))
file[o.yname][idcs..., o.saved+1] = ys
s = collect(size(file[o.tname]))
if s[end] < o.saved+1
s[end] += 1
Expand All @@ -335,6 +342,9 @@ function (o::HDF5Output)(y, t, dt, yfun)
end
end
end
if ts != t
push!(o.stats_tmp, o.statsfun(y, t, dt))
end
end

function append_stats!(parent, a::Array{Dict{String,Any},1})
Expand Down
5 changes: 5 additions & 0 deletions test/test_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ fpath_comp = joinpath(dirpath, "test_comp.h5")
ω, Eω, zac = Processing.getEω(o, 5e-2)
@test (ω, Eω[:, 1]) == Processing.getEω(grid, mem["Eω"][:, 51])
@test zac[1] == 5e-2
# test stat save locations
@test length(unique(mem["stats"]["z"])) == length(mem["stats"]["z"])
@test length(unique(o["stats"]["z"])) == length(o["stats"]["z"])
@test all([z in o["stats"]["z"] for z in o["z"]])
@test all([z in mem["stats"]["z"] for z in mem["z"]])
end
rm(fpath)
rm(fpath_comp)
Expand Down