Skip to content

Commit

Permalink
Add deprecated_names to the IndexEntry and include a search funct…
Browse files Browse the repository at this point in the history
…ion for selection entries by fullname or deprecated name.
  • Loading branch information
rofinn committed Jul 12, 2022
1 parent a708f90 commit e8b6484
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 6 deletions.
91 changes: 85 additions & 6 deletions src/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
IndexEntry(checkpoint_path, base_dir)
IndexEntry(checkpoint_path, base_dir, depnames_lookup=_depnames_lookup())
IndexEntry(checkpoint_path, checkpoint_name, prefixes, tags)
This is an index entry describing the output file from a checkpoint.
Expand All @@ -22,9 +22,10 @@ struct IndexEntry
checkpoint_name::AbstractString
prefixes::NTuple{<:Any, AbstractString}
tags::NTuple{<:Any, Pair{Symbol, <:AbstractString}}
deprecated_names::NTuple{<:Any, AbstractString}
end

function IndexEntry(filepath::AbstractPath, base_dir)
function IndexEntry(filepath::AbstractPath, base_dir, depnames_lookup=_depnames_lookup())
if dirname(filepath) == base_dir
# workaround for relpath erroring on equal S3Paths
# https://github.com/rofinn/FilePathsBase.jl/issues/156
Expand All @@ -39,8 +40,23 @@ function IndexEntry(filepath::AbstractPath, base_dir)
return Symbol(tag)=>val
end
end

checkpoint_name = filename(filepath)
return IndexEntry(filepath, checkpoint_name, prefixes, tags)
checkpoint_fullname = join((prefixes..., checkpoint_name), ".")
if haskey(depnames_lookup, checkpoint_fullname)
deprecated_names = Tuple(depnames_lookup[checkpoint_fullname])
else
filtered = filter(e -> checkpoint_fullname in last(e), depnames_lookup)
if isempty(filtered)
deprecated_names = ()
else
k, v = only(filtered)
checkpoint_name = last(split(k, "."))
deprecated_names = Tuple(v)
end
end

return IndexEntry(filepath, checkpoint_name, prefixes, tags, deprecated_names)
end


Expand Down Expand Up @@ -103,11 +119,27 @@ Note that if the tags are unique, then their values can also be accessed via a
"""
function tags(x::IndexEntry)
    # `getfield` reads the raw struct field, bypassing the `getproperty` overload
    # defined for `IndexEntry`.
    return getfield(x, :tags)
end

"""
deprecated_names(x::IndexEntry)
Previous `checkpoint_name`s that have since been renamed.
If the checkpoint was previously saved using `checkpoint(Forecasters, "predictions", ...)`,
but has since been renamed to `checkpoint(Forecasters, "forecasts", ...)` then
"predictions" would live in this list.
"""
function deprecated_names(x::IndexEntry)
    # `getfield` reads the raw struct field, bypassing the `getproperty` overload
    # defined for `IndexEntry`.
    return getfield(x, :deprecated_names)
end

# The name (first element) of every tag pair stored on the entry, in stored order.
_tag_names(x::IndexEntry) = map(first, tags(x))

#Tables.columnnames(x::IndexEntry) = propertynames(x)
# Property names exposed by an `IndexEntry`: the fixed entries plus one name per tag.
# Tag names are spliced in between `:checkpoint_name` and `:checkpoint_path`, and
# `:deprecated_names` is listed last.
function Base.propertynames(x::IndexEntry)
    return [
        :prefixes,
        :checkpoint_name,
        _tag_names(x)...,
        :checkpoint_path,
        :deprecated_names,
    ]
end

function Base.getproperty(x::IndexEntry, name::Symbol)
Expand Down Expand Up @@ -162,8 +194,10 @@ You can also work with it directly, say you wanted to get all checkpoints files
"""
function index_checkpoint_files(dir::AbstractPath)
    isdir(dir) || throw(ArgumentError("Need an existing directory."))

    # Build the deprecated-name lookup once and hand it to every entry, rather than
    # letting each `IndexEntry` call fall back to its default argument and rebuild it.
    depnames_lookup = _depnames_lookup()

    # Only files whose extension is "jlso" are indexed.
    checkpoint_paths = Iterators.filter(==("jlso") ∘ extension, walkpath(dir))

    return map(checkpoint_paths) do checkpoint_path
        return IndexEntry(checkpoint_path, dir, depnames_lookup)
    end
end

Expand All @@ -176,9 +210,54 @@ Constructs an index for all the files located within `dir`.
Same as [`index_checkpoint_files`] except not restricted to files created by Checkpoints.jl.
"""
function index_files(dir::AbstractPath)
    # Build the deprecated-name lookup once and hand it to every entry, rather than
    # letting each `IndexEntry` call fall back to its default argument and rebuild it.
    depnames_lookup = _depnames_lookup()

    # Every regular file found while walking `dir` gets an entry.
    return map(Iterators.filter(isfile, walkpath(dir))) do path
        return IndexEntry(path, dir, depnames_lookup)
    end
end

# Fallback for non-path arguments: wrap `dir` in a `Path` and re-dispatch.
function index_files(dir)
    return index_files(Path(dir))
end

"""
search(name::AbstractString, index)
Returns elements where `name` matches either the full checkpoint name or deprecated names.
If `name` is deprecated then a deprecation warning is emitted.
# Arguments
- `name`: The full checkpoint name to search for (ie: `"Forecasters.forecasts"`)
- `index`: Iterable of `IndexEntry` elements
"""
function search(name::AbstractString, index)
    # Keep entries that match either by current full name or by a deprecated name.
    matches = filter(index) do entry
        return checkpoint_fullname(entry) == name || name in entry.deprecated_names
    end

    if !isempty(matches)
        # Only the first match is consulted to decide whether `name` is outdated.
        current = checkpoint_fullname(first(matches))
        if !(name == current)
            Base.depwarn("$name has been deprecated to $current.", :search)
        end
    end

    return matches
end


# Utility function for generating a checkpoint name lookup table from the current registry.
#
# Returns a `Dict{String, Set{String}}` whose keys are current checkpoint names and whose
# values are the sets of deprecated names that resolve (possibly through a chain of
# renames) to that key. Names with no deprecated aliases map to an empty set.
function _depnames_lookup()
    # Iterated below as `previous name => replacement name` pairs.
    deps = deprecated_checkpoints()

    # Seed an empty alias set for every available name that is not itself deprecated.
    results = Dict{String, Set{String}}(
        x => Set{String}() for x in setdiff(available(), keys(deps))
    )

    # Simple recursive find_name function to find the original non-deprecated name
    find_name(x) = haskey(deps, x) ? find_name(deps[x]) : x

    for (prev, curr) in deps
        k = find_name(curr)
        # `get!` instead of `results[k]`: if the resolved name was not seeded above
        # (i.e. it is not reported by `available()`), create its alias set on demand
        # rather than failing with a `KeyError`.
        push!(get!(Set{String}, results, k), prev)
    end

    return results
end
20 changes: 20 additions & 0 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@
end
end

@testset "Searching for deprecated checkpoint" begin
    mktempdir(SystemPath) do path
        # Configuring through the old name is expected to emit a deprecation warning.
        @test_deprecated Checkpoints.config("TestPkg.quuz", path)

        a = Dict(zip(
            [Symbol(randstring(4)) for _ in 1:10],
            [rand(10) for _ in 1:10]
        ))
        b = rand(10)
        TestPkg.qux(a, b)

        index = index_checkpoint_files(path)
        # Exactly one checkpoint file is expected under `path`.
        entry = only(index)
        @test checkpoint_name(entry) == "qux_b"
        @test checkpoint_fullname(entry) == "TestPkg.qux_b"
        @test Checkpoints.deprecated_names(entry) == ("TestPkg.quuz",)

        # Searching by the deprecated name still finds the entry, but warns.
        found = @test_deprecated Checkpoints.search("TestPkg.quuz", index)
        @test found == index
    end
end

@testset "files not saved by Checkpoints.jl" begin
mktempdir(SystemPath) do path
Checkpoints.config("TestPkg.bar", path)
Expand Down

0 comments on commit e8b6484

Please sign in to comment.