Skip to content

Commit

Permalink
Introduce an IntervalSet type.
Browse files Browse the repository at this point in the history
  • Loading branch information
rofinn committed Jun 2, 2022
1 parent b6272ec commit 573ebe9
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 127 deletions.
1 change: 1 addition & 0 deletions src/Intervals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export Bound,
Unbounded,
AbstractInterval,
Interval,
IntervalSet,
AnchoredInterval,
HourEnding,
HourBeginning,
Expand Down
4 changes: 4 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,8 @@ function HB(anchor, inc::Inclusivity)
return HourBeginning{L,R}(floor(anchor, Hour))
end

@deprecate union(intervals::AbstractVector{<:AbstractInterval}) collect(union(IntervalSet(intervals)))
@deprecate union!(intervals::AbstractVector{<:AbstractInterval}) collect(union!(IntervalSet(intervals)))
@deprecate superset(intervals::AbstractVector{<:AbstractInterval}) superset(IntervalSet(intervals))

# END Intervals 1.X.Y deprecations
100 changes: 62 additions & 38 deletions src/interval_sets.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@

###### Set-related Helpers #####

const IntervalSet = AbstractVector{<:AbstractInterval}
struct IntervalSet{T<:AbstractInterval}
items::Vector{<:AbstractInterval}
end

IntervalSet(v::AbstractVector) = IntervalSet{eltype(v)}(v)
IntervalSet(interval::T) where T <: AbstractInterval = IntervalSet{T}([interval])
IntervalSet(interval::IntervalSet) = interval
IntervalSet(itr) = IntervalSet{eltype(itr)}(collect(itr))

Base.copy(intervals::IntervalSet{T}) where T = IntervalSet{T}(copy(intervals.items))
Base.length(intervals::IntervalSet) = length(intervals.items)
Base.iterate(intervals::IntervalSet, args...) = iterate(intervals.items, args...)
Base.eltype(::IntervalSet{T}) where T = T
Base.:(==)(a::IntervalSet, b::IntervalSet) = a.items == b.items
Base.isequal(a::IntervalSet, b::IntervalSet) = isequal(a, b)

const AbstractIntervals = Union{AbstractInterval, IntervalSet}

# During merge operations used to compute unions, intersections etc...,
Expand Down Expand Up @@ -39,9 +54,13 @@ function endpoint_tracking(
)
return TrackEachEndpoint()
end
endpoint_tracking(a::AbstractVector, b::AbstractVector) = endpoint_tracking(eltype(a), eltype(b))

endpoint_tracking(a::IntervalSet, b::IntervalSet) = endpoint_tracking(eltype(a), eltype(b))
endpoint_tracking(a::AbstractInterval, b::AbstractInterval) = endpoint_tracking(typeof(a), typeof(b))

# TODO: Delete once union deprecation is gone.
endpoint_tracking(a::AbstractVector, b::AbstractVector) = endpoint_tracking(eltype(a), eltype(b))

# track: run a thunk, but only if we are tracking endpoints dynamically
track(fn::Function, ::TrackEachEndpoint, args...) = fn(args...)
track(_, tracking::TrackStatically, args...) = tracking
Expand All @@ -58,11 +77,11 @@ interval_type(::TrackRightOpen{T}) where T = Interval{T, Closed, Open}
# all vectors of sets be represented by their endpoints. The functions unbunch
# and bunch convert between an interval and an endpoint representation

function unbunch(interval::AbstractInterval, tracking::EndpointTracking; lt=isless)
function unbunch(interval::AbstractInterval, tracking::EndpointTracking; lt=isless)
return endpoint_type(tracking)[LeftEndpoint(interval), RightEndpoint(interval)]
end
unbunch_by_fn(_) = identity
function unbunch(intervals::Union{AbstractIntervals, Base.Iterators.Enumerate{<:AbstractIntervals}},
function unbunch(intervals::Union{AbstractIntervals, Base.Iterators.Enumerate{<:AbstractIntervals}},
tracking::EndpointTracking; lt=isless)
by = unbunch_by_fn(intervals)
filtered = Iterators.filter(!isempty by, intervals)
Expand All @@ -72,7 +91,7 @@ function unbunch(intervals::Union{AbstractIntervals, Base.Iterators.Enumerate{<:
end
# support for `unbunch(enumerate(vcat(x)))` (transforming [(i, interval)] -> [(i, endpoint), (i,endpoint)])
unbunch_by_fn(::Base.Iterators.Enumerate) = last
function unbunch((i, interval)::Tuple, tracking; lt=isless)
function unbunch((i, interval)::Tuple, tracking; lt=isless)
eltype = Tuple{Int, endpoint_type(tracking)}
return eltype[(i, LeftEndpoint(interval)), (i, RightEndpoint(interval))]
end
Expand All @@ -84,19 +103,25 @@ function unbunch(a::AbstractIntervals, b::AbstractIntervals; kwargs...)
return a_, b_, tracking
end

# TODO: Delete fallback once union deprecation is removed
function unbunch(a::Vector{<:AbstractInterval}, b::Vector{<:AbstractInterval}; kwargs...)
return unbunch(IntervalSet(a), IntervalSet(b); kwargs...)
end

# represent a sequence of endpoints as a sequence of one or more intervals
function bunch(endpoints, tracking)
@assert iseven(length(endpoints))
isempty(endpoints) && return interval_type(tracking)[]
return map(Iterators.partition(endpoints, 2)) do pair
isempty(endpoints) && return IntervalSet(interval_type(tracking)[])
res = map(Iterators.partition(endpoints, 2)) do pair
return Interval(pair..., tracking)
end
return IntervalSet(res)
end
Interval(a::Endpoint, b::Endpoint, ::TrackEachEndpoint) = Interval(a, b)
Interval(a::Endpoint, b::Endpoint, ::TrackLeftOpen{T}) where T = Interval{T,Open,Closed}(a.endpoint, b.endpoint)
Interval(a::Endpoint, b::Endpoint, ::TrackRightOpen{T}) where T = Interval{T,Closed,Open}(a.endpoint, b.endpoint)

# the sentinel endpoint reduces the number of edgecases
# the sentinel endpoint reduces the number of edgecases
# we have to deal with when comparing endpoints during a merge
# NOTE: it's tempting to replace this with an unbounded endpoint
# but if we ever want to support unbounded endpoints in mergesets
Expand All @@ -108,7 +133,7 @@ function first_endpoint(x)
return eltype(x) <: Tuple ? last(first(x)) : first(x)
end
function last_endpoint(x)
isempty(x) && return SentinelEndpoint()
isempty(x) && return SentinelEndpoint()
# if the endpoints are enumerated, eltype will be a tuple
return eltype(x) <: Tuple ? last(last(x)) : last(x)
end
Expand Down Expand Up @@ -139,7 +164,7 @@ isleft(::RightEndpoint) = false
# will remain unchanged moving left to right along the real-number line until we encounter a
# new endpoint.
#
# For each endpoint, we determine two things:
# For each endpoint, we determine two things:
# 1. whether subsequent points should be included in the merge operation or not (based on
# its membership in both `x` and `y`) by using `op`
# 2. whether the next step will define a left (start including) or right endpoint (stop
Expand Down Expand Up @@ -264,35 +289,35 @@ right_endpoint(t, ::TrackRightOpen{T}) where T = RightEndpoint{T,Open}(endpoint(

# There is power in a union.
"""
union(intervals::AbstractVector{<:AbstractInterval})
union(intervals::IntervalSets)
Flattens a vector of overlapping intervals into a new, smaller vector containing only
non-overlapping intervals.
Flattens any overlapping intervals within the `IntervalSet` into a new, smaller set
containing only non-overlapping intervals.
"""
function Base.union(intervals::AbstractVector{<:AbstractInterval})
return union!(convert(Vector{AbstractInterval}, intervals))
end
Base.union(intervals::IntervalSet{<:Interval}) = union!(copy(intervals))

# allow a concretely typed array for `Interval` objects (as opposed to e.g. anchored intervals
# which may change type during the union process)
function Base.union(intervals::AbstractVector{T}) where T <: Interval
return union!(copy(intervals))
# In the case where we're dealing with a non-concrete interval type like AnchoredIntervals then simply
# allocate a AbstractInterval vector
function Base.union(intervals::IntervalSet{<:AbstractInterval})
T = AbstractInterval
return union!(IntervalSet{T}(convert(Vector{T}, intervals.items)))
end

"""
union!(intervals::AbstractVector{<:AbstractInterval})
union!(intervals::IntervalSet)
Flattens a vector of overlapping intervals in-place to be a smaller vector containing only
non-overlapping intervals.
"""
function Base.union!(intervals::AbstractVector{<:AbstractInterval})
sort!(intervals)
function Base.union!(intervals::IntervalSet)
items = intervals.items
sort!(items)

i = 2
n = length(intervals)
n = length(items)
while i <= n
prev = intervals[i - 1]
curr = intervals[i]
prev = items[i - 1]
curr = items[i]

# If the current and previous intervals don't meet then move along
if !overlaps(prev, curr) && !contiguous(prev, curr)
Expand All @@ -301,8 +326,8 @@ function Base.union!(intervals::AbstractVector{<:AbstractInterval})
# If the two intervals meet then we absorb the current interval into
# the previous one.
else
intervals[i - 1] = merge(prev, curr)
deleteat!(intervals, i)
items[i - 1] = merge(prev, curr)
deleteat!(items, i)
n -= 1
end
end
Expand All @@ -311,13 +336,13 @@ function Base.union!(intervals::AbstractVector{<:AbstractInterval})
end

"""
superset(intervals::AbstractArray{<:AbstractInterval}) -> Interval
superset(intervals::IntervalSet) -> Interval
Create the smallest single interval which encompasses all of the provided intervals.
"""
function superset(intervals::AbstractArray{<:AbstractInterval})
left = minimum(LeftEndpoint.(intervals))
right = maximum(RightEndpoint.(intervals))
function superset(intervals::IntervalSet)
left = minimum(LeftEndpoint.(intervals.items))
right = maximum(RightEndpoint.(intervals.items))

return Interval(left, right)
end
Expand All @@ -332,14 +357,14 @@ Base.issubset(x::AbstractIntervals, y::AbstractIntervals) = isempty(setdiff(x, y
Base.isdisjoint(x::AbstractIntervals, y::AbstractIntervals) = isempty(intersect(x, y))

function Base.issetequal(x::AbstractIntervals, y::AbstractIntervals)
x, y, tracking = unbunch(union(vcat(x)), union(vcat(y)))
x, y, tracking = unbunch(union(IntervalSet(x)), union(IntervalSet(y)))
return x == y || all(isempty, bunch(x, tracking)) && all(isempty, bunch(y, tracking))
end

# order edges so that closed boundaries are on the outside: e.g. [( )]
intersection_order(x::Endpoint) = isleft(x) ? !isclosed(x) : isclosed(x)
intersection_isless_fn(::TrackStatically) = isless
function intersection_isless_fn(::TrackEachEndpoint)
function intersection_isless_fn(::TrackEachEndpoint)
function (x,y)
if isequal(x, y)
return isless(intersection_order(x), intersection_order(y))
Expand All @@ -351,15 +376,15 @@ end

"""
find_intersections(
x::Union{AbstractInterval, AbstractVector{<:AbstractInterval}},
y::Union{AbstractInterval, AbstractVector{<:AbstractInterval}},
x::Union{AbstractInterval, IntervalSet},
y::Union{AbstractInterval, IntervalSet},
)
Returns a `Vector{Vector{Int}}` where the value at index `i` gives the indices to all
intervals in `y` that intersect with `x[i]`.
"""
function find_intersections(x_::AbstractIntervals, y_::AbstractIntervals)
xa, ya = vcat(x_), vcat(y_)
xa, ya = IntervalSet(x_), IntervalSet(y_)
tracking = endpoint_tracking(xa, ya)
lt = intersection_isless_fn(tracking)
x = unbunch(enumerate(xa), tracking; lt)
Expand Down Expand Up @@ -402,4 +427,3 @@ function find_intersections_helper!(result, x, y, lt)

return unique!.(result)
end

Loading

0 comments on commit 573ebe9

Please sign in to comment.