Optimize find_intersections for closed intervals #203

Merged
merged 15 commits into from
Jun 14, 2023

ericphanson
Contributor

@ericphanson ericphanson commented Oct 13, 2022

This adds a method for find_intersections with closed endpoints which is much faster, at least in my example problem, for which the current performance is too slow. The algorithm is just sorting plus binary searching, and is based on some code I wrote for a Beacon-internal project last week, which I didn't realize was basically a solution to this problem.

This speeds up the problem in beacon-biosignals/DataFrameIntervals.jl#24 from 22s to 1s (with N_POINTS = 10^5 and N_INTERVALS = 1000), and more importantly to me, the N_POINTS = 10^6 and N_INTERVALS = 3000 version now takes 35s. I don't have a timing with the current code on that exact problem, but a very similar one (off which the MWE was based) took 2.7 hours.

This is covered by the current tests (the ones with closed-closed intersections), but I can add specific tests if someone thinks something in particular would be good to cover.

This algorithm works very differently depending on which argument comes first, but I get similar performance (on my example problem at least) no matter which argument is first. The number of individual allocations does change by a lot, however:

julia> @time find_intersections(L.interval, R.interval);
  4.324992 seconds (98.94 k allocations: 1.765 GiB, 1.87% gc time)

julia> @time find_intersections(R.interval, L.interval);
  4.504758 seconds (11.01 M allocations: 1.834 GiB, 1.95% gc time)

(this is on the N_POINTS = 10^6 and N_INTERVALS = 3000 version of the example problem; the timings are ~5s instead of 35s because this is only the intersection portion. The DataFrameIntervals overhead now dominates, whereas it used to be the find_intersections time.)

I am not sure whether this is so much faster in all cases or only on this particular problem. However, it also speeds up interval_join(L, L; on=:interval, makeunique=true) from 22s to 8s. That case has a very different shape from the L, R case, so since it speeds up both I suspect it's just a faster approach. If someone suggests other problems, I can try them too.


I only did the closed-closed case because I didn't want to think about all the endpoints. I would appreciate if we could leave other endpoints to future work, since this should already be a big improvement in some cases at least.

@codecov

codecov bot commented Oct 13, 2022

Codecov Report

Merging #203 (7c31a48) into master (ba938f6) will increase coverage by 0.23%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #203      +/-   ##
==========================================
+ Coverage   84.31%   84.55%   +0.23%     
==========================================
  Files          12       12              
  Lines         848      874      +26     
==========================================
+ Hits          715      739      +24     
- Misses        133      135       +2     
Impacted Files Coverage Δ
src/interval_sets.jl 92.10% <100.00%> (-0.05%) ⬇️

... and 2 files with indirect coverage changes

@ericphanson
Contributor Author

CI failures are #199

@haberdashPI
Contributor

Awesome! The commenting really helped. I think this is pretty clear, and I didn't notice any bugs.

So, I want to make sure I understand this correctly, since I'd like to generalize this to all interval types at some point.

In essence what's going on is that rather than scanning along for each interval that intersects with A, you can sorted-search for the last such interval, get its index, and then compute intersections in the UnitRange domain. Cool!

In principle you could have just 1 such interval in A per interval in B so I think the complexity class remains: O(N log N) + O(M log M) + O(K log max(N,M)) where N and M are the lengths of A and B interval sets and M the length of the output. However for use cases where there are many intervals in B for the intervals in A (or vice versa) this will be faster and that's actually a pretty common use case!
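The sorted-search idea described above can be sketched as follows. This is a minimal illustration, not the code in this PR: closed intervals are modeled as plain `(lo, hi)` tuples rather than `Interval` objects, and the function name `closed_intersections` and its local variables are hypothetical.

```julia
# Sketch only: closed intervals are (lo, hi) tuples; all names are illustrative.
# Two closed intervals I and J intersect iff first(I) <= last(J) and first(J) <= last(I).
function closed_intersections(x, y)
    starts = first.(y)
    stops = last.(y)
    starts_perm = sortperm(starts)      # indices of y ordered by left endpoint
    stops_perm = sortperm(stops)        # indices of y ordered by right endpoint
    starts_sorted = starts[starts_perm]
    stops_sorted = stops[stops_perm]

    results = Vector{Vector{Int}}(undef, length(x))
    for (i, (lo, hi)) in enumerate(x)
        # y's whose right endpoint is >= lo form a suffix of the stop-sorted order.
        idx_first = searchsortedfirst(stops_sorted, lo)
        # y's whose left endpoint is <= hi form a prefix of the start-sorted order.
        idx_last = searchsortedlast(starts_sorted, hi)
        # A y intersects this x iff it appears in both index sets.
        candidates = intersect(stops_perm[idx_first:end], starts_perm[1:idx_last])
        results[i] = sort!(candidates)
    end
    return results
end
```

For example, `closed_intersections([(1, 3), (10, 12)], [(3, 5), (0, 0), (11, 20), (4, 9)])` returns `[[1], [3]]`: the first query touches only `(3, 5)` at the shared endpoint 3, and the second overlaps only `(11, 20)`.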

@ericphanson
Contributor Author

ericphanson commented Oct 14, 2022

> In essence what's going on is that rather than scanning along for each interval that intersects with A, you can sorted-search for the last such interval, get its index, and then compute intersections in the UnitRange domain. Cool!

Yep!

> O(N log N) + O(M log M) + O(K log max(N,M)) where N and M are the lengths of A and B interval sets and M the length of the output

What is K here? The length of the flattened output?

I think this sounds roughly right. I think of it as: we need to sort B, so that's M*log(M), then for each element of A we do log(M) work in binary searching, so that's N*log(M), plus some amount of work for the number of intersections we actually find (to do our permuting). I guess that adds up to K, once we sum over all elements in A, if K is the total number of intersections (length of flattened output). So overall, M*log(M) + N*log(M) + K, I think, taking K as the total number of intersections. Does that seem right?

edit: BTW I think in 1.9 we get a radix sort in Base, which means sort complexity for integers and floats should be linear, bringing the complexity to M + N*log(M) + K, for a lil speedup in the sorting.
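As a rough worked estimate of the terms in the cost model above (a back-of-the-envelope calculation, assuming base-2 logarithms and ignoring constant factors), take the larger example problem with the $N = 10^6$ point intervals as A and the $M = 3000$ intervals as B:

$$
\begin{aligned}
\log_2 M &= \log_2 3000 \approx 11.55,\\
M \log_2 M &\approx 3000 \times 11.55 \approx 3.5 \times 10^{4},\\
N \log_2 M &\approx 10^{6} \times 11.55 \approx 1.2 \times 10^{7}.
\end{aligned}
$$

Under these assumptions the binary-search term exceeds the sorting term by a factor of about $N/M \approx 333$. The $K$ term depends on how many intersections the data actually contains.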

@haberdashPI
Contributor

Yes, by K I mean the length of the output.

Cool, that makes sense. I think that should be enough for me to work on a version that generalizes to all interval types.

https://github.com/JuliaCollections/SortingAlgorithms.jl would be another option if we want to use a radix sort.

@ericphanson
Contributor Author

ericphanson commented Oct 17, 2022

I think we can just wait for the one in Base; sorting was not the bottleneck when I was profiling this.

If you were able to give it a look, could you do a review here @haberdashPI? That might help get it in. (I can’t request review on this repo)

Co-authored-by: Curtis Vogt <curtis.vogt@gmail.com>
@ericphanson
Contributor Author

On the Julia slack (#performance-helpdesk), @miguelraz pointed out:

> couldn't you sort the x by the last(I), and if any of the idx_last > len then you know the search is done

I think that could speed things up a lot in the case where all of the y's are to the left of many of the xs, since once we find the first such case, we know the rest are empty.

I think we could also go the other way, sorting x by first, to handle the case in which all of the y's are to the right of many of the x's.

I think a generalization of this is: if the x's are sorted by first, then as we iterate through the loop over x, we know idx_last should be monotone-increasing, and we can do shorter and shorter binary searches by starting after the previous idx_last. Likewise, if the xs were sorted by last, idx_first should be monotone-decreasing. Then we recover the "all to the right" and "all to the left" cases by noticing once idx_first or idx_last becomes extremal.
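The shrinking-search idea can be illustrated in isolation. The sketch below is not taken from this PR or from the follow-up branch; `monotone_searchsortedfirst` is a hypothetical helper that assumes `queries` is already sorted ascending, so each answer is at least as large as the previous one and every search can start where the last one stopped.

```julia
# Hypothetical helper: for ascending `queries`, return for each query q the first
# index i such that sorted_vals[i] >= q (or length(sorted_vals) + 1 if none exists).
function monotone_searchsortedfirst(sorted_vals, queries)
    n = length(sorted_vals)
    out = Vector{Int}(undef, length(queries))
    prev = 1
    for (k, q) in enumerate(queries)
        # Only the tail sorted_vals[prev:n] can contain the answer, because the
        # previous (smaller or equal) query already ruled out everything before prev.
        tail = view(sorted_vals, prev:n)
        prev = prev - 1 + searchsortedfirst(tail, q)
        out[k] = prev
    end
    return out
end
```

Once `prev` reaches `length(sorted_vals) + 1`, every remaining search runs on an empty view, which corresponds to the "all to the left" or "all to the right" early exit described above.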

I am not quite sure if we could somehow use two sorted copies of x like we do with y, in order to get the benefits of both sortings. I think perhaps we could by delaying the intersections. In other words, we first compute: for each x_start, what are all the y_stops that occur after it, and for each x_stop, what are all the y_starts that occur before it. We should be able to work in sorted coordinates such that these are all UnitRanges, so they aren't too heavy. Each of these lists is in a different order, corresponding to the sort-x-by-first and sort-x-by-last orders. Then we un-permute both orders to align back to the original order of x, and perform the intersections there.

This will be slightly more allocation heavy, but not incredibly so: it means we need to additionally allocate 2 UnitRanges per element of x, but that's a lot better than say 2 Vectors. We will also have a few more copies of x lying around, due to the additional sorting and permutation vectors. We can allocate everything upfront however, so we don't need to allocate in the loops themselves. I think this is probably worth it. It should be much faster in cases where many xs don't intersect with many ys, which I am guessing is most of the time that folks use find_intersections. At least in my use cases, I often have small intervals (one second or less) spread over hours.

@ericphanson
Contributor Author

I tried the above approach in ericphanson/Intervals.jl@eph/faster-intersections...ericphanson:Intervals.jl:eph/optimize-find-intersections-2. Disappointingly, it doesn't seem much faster, even when I try to shift around the intervals' ranges to better hit these optimizations:

using DataFrames, StableRNGs, Intervals
using Intervals: find_intersections

function make_example(n_intervals, n_points; point_max=10_000, interval_max=10_000)
    rng = StableRNG(123)
    points = [rand(rng, 1:point_max) for _ in 1:n_points]
    L = DataFrame(:point => points)
    transform!(L, :point => ByRow(p -> Interval{Int,Closed,Closed}(p, p)) => :interval)

    indices = map(1:n_intervals) do _
        i = rand(rng, 1:interval_max)
        return i:i+100
    end
    R = DataFrame(:indices => indices)
    transform!(R, :indices => ByRow(I -> Interval{Int,Closed,Closed}(first(I), last(I))) => :interval)
    return L, R
end

L, R = make_example(3000, 10^6; point_max=100_000, interval_max=10_000)


@time find_intersections(L.interval, R.interval);
@time find_intersections(R.interval, L.interval);

For this PR, I get

julia> @time find_intersections(L.interval, R.interval);
  0.640732 seconds (2.01 M allocations: 256.068 MiB, 21.22% gc time)

julia> @time find_intersections(R.interval, L.interval);
  4.522413 seconds (72.02 k allocations: 294.140 MiB, 3.77% gc time)

For that branch, I get

julia> @time find_intersections(L.interval, R.interval);
  0.693489 seconds (2.01 M allocations: 317.866 MiB, 8.34% gc time)

julia> @time find_intersections(R.interval, L.interval);
  4.437335 seconds (72.04 k allocations: 294.323 MiB, 1.13% gc time)

@ericphanson
Contributor Author

BTW, I found https://link.springer.com/article/10.1007/s00778-020-00639-0, and the usual find_intersections algorithm looks similar to the endpoint-based interval join described there.

@haberdashPI
Contributor

haberdashPI commented Oct 21, 2022

Wow... it's crazy to me that the first, most obvious way I thought of to make this a little more efficient was published in 2016! I would have expected some publications before that...

Looks like that paper might have some more optimizations to consider when we're ready to make this even faster.

@ericphanson
Contributor Author

bump

@ericphanson
Contributor Author

bump

@haberdashPI
Contributor

haberdashPI commented Jan 17, 2023

I've made a few small changes here: ericphanson#1

I believe this would be a simple way to get the faster algorithm working for all interval endpoints. (Since #205 is set up without including this change, I have yet to run it on that more comprehensive set of tests.) Just wanted to post here for transparency; I suspect that it should ultimately become a PR to this version of the repo, once this and #205 merge.

@ericphanson
Contributor Author

Yeah, I would really prefer we got this PR in as soon as possible and then incrementally expanded the functionality. We are currently forced to pirate this function in a lot of downstream code because the current performance is untenable.

@ericphanson
Contributor Author

bump

@omus
Collaborator

omus commented Jun 14, 2023

Actively reviewing this PR today

@omus
Collaborator

omus commented Jun 14, 2023

Current code fails with the tests added in #205. Specifically the "equal -0.0/0.0" testset. Looks like using lt=< as part of the searchsorted(first/last) functions can address this
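The difference comes from how the two comparisons treat signed zeros. `searchsortedfirst` and `searchsortedlast` default to `lt=isless`, and `isless(-0.0, 0.0)` is `true` while `-0.0 < 0.0` is `false`. A small illustration follows; the vector and query are made up for the example and are not taken from the test suite.

```julia
stops_sorted = [-0.0]   # e.g. the right endpoint of a closed interval [-1.0, -0.0]
query = 0.0             # e.g. the left endpoint of a closed interval [0.0, 1.0]

# The default ordering (isless) treats -0.0 as strictly less than 0.0,
# so the search skips past the endpoint and no candidate remains.
default_idx = searchsortedfirst(stops_sorted, query)

# With lt=(<), -0.0 and 0.0 compare as equal, so the endpoint is kept as a candidate.
fixed_idx = searchsortedfirst(stops_sorted, query; lt=(<))
```

Here `default_idx` is 2 (one past the end of the vector) and `fixed_idx` is 1, so only the `lt=(<)` search treats the two intervals as touching at zero.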

@omus
Collaborator

omus commented Jun 14, 2023

@haberdashPI pointed out ericphanson#1 which generalizes this for all intervals. I'll make a follow up PR for that which will be the 1.10.0 release

@omus
Collaborator

omus commented Jun 14, 2023

Re-ran the benchmark posted above.

With the lt=< fix (7c31a48):

julia> @time find_intersections(L.interval, R.interval);
  0.655631 seconds (2.64 M allocations: 287.892 MiB, 6.09% gc time, 20.41% compilation time)

julia> @time find_intersections(R.interval, L.interval);
  5.724341 seconds (72.02 k allocations: 294.140 MiB, 0.94% gc time)

This PR before the fix (7294a56):

julia> @time find_intersections(L.interval, R.interval);
  0.686747 seconds (2.61 M allocations: 286.352 MiB, 7.78% gc time, 19.47% compilation time)

julia> @time find_intersections(R.interval, L.interval);
  5.218349 seconds (72.02 k allocations: 294.140 MiB, 0.85% gc time)

@omus omus merged commit 4b9f831 into invenia:master Jun 14, 2023