
help type inference for logical indexing #29633

Merged
merged 1 commit into JuliaLang:master on Oct 29, 2018

Conversation

chethega
Contributor

Before:

julia> using BenchmarkTools
julia> d=rand(10_000); m=rand(Bool, 10_000);
julia> @btime getindex($d,$m);
  742.062 μs (23906 allocations: 643.13 KiB)

After:

julia> using BenchmarkTools
julia> d=rand(10_000); m=rand(Bool, 10_000);
julia> @btime getindex($d,$m);
  51.177 μs (4 allocations: 39.06 KiB)

Not sure whether this is really related to #29418. See also https://discourse.julialang.org/t/is-boolean-indexing-100-times-slower-in-1-0/16286/8.

The problem was that the loop state changes type between the first and subsequent iterations over a logical index, and this confused inference.
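For illustration, here is a minimal iterator with the same shape of problem (a hypothetical `CountToN`, not the actual Base code): the first state is a 2-tuple, every later state is a 3-tuple, so inference has to widen the loop state to a `Union`.

```julia
# Hypothetical minimal iterator mimicking the pre-fix LogicalIndex:
# the state tuple grows after the first iteration.
struct CountToN
    n::Int
end
Base.length(c::CountToN) = max(c.n, 0)

# First iteration: 2-tuple state
Base.iterate(c::CountToN) = c.n < 1 ? nothing : (1, (2, c.n))
# Subsequent iterations: 3-tuple state, so the inferred state type is a Union
function Base.iterate(c::CountToN, s)
    i = s[1]
    i > c.n && return nothing
    return (i, (i + 1, c.n, 0))
end
```

Running `@code_warntype` on a loop over such an iterator shows the widened state type that caused the allocations above.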

@Keno
Member

Keno commented Oct 14, 2018

Would it make sense to write an unroll-once macro instead and use it here? I just want to avoid the two iterations getting out of sync (plus the macro might come in handy elsewhere with the same issue).

@chethega
Contributor Author

I don't think readability, maintainability, or line count would improve from turning this into a macro, unless we hit the issue in many places. If many places have already unrolled the first iteration by hand, then this could pay off.

Long-term, it would be nice if the compiler could learn to better deal with loops that are stable only after the first iteration.
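The hand-unrolling pattern being discussed can be sketched generically: peel the first iteration out of the loop so the loop body only ever sees the later state type (a sketch with a hypothetical `sum_stable` helper, not the PR's code):

```julia
# Peel the first iteration out of the loop by hand; after that the state
# type is fixed, so inference sees a concrete state inside the loop.
function sum_stable(itr)
    y = iterate(itr)            # first iteration, handled outside the loop
    y === nothing && return 0
    acc, state = y
    while true
        y = iterate(itr, state) # from here on the state type is stable
        y === nothing && return acc
        v, state = y
        acc += v
    end
end
```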

@nalimilan
Member

Thanks! Out of curiosity, why does the type change? I still don't get it.

It doesn't look like BaseBenchmarks cover this, which may explain why this regression wasn't noticed.

@chethega
Contributor Author

When calling iterate(L::LogicalIndex{Int}) we then call iterate(L, (1, LinearIndices(L.mask))).

Now we get a problem: Initially, s = (1, LinearIndices(L.mask)) is a two-element tuple, but subsequently it is a three-element tuple s = (n+1, LinearIndices(L.mask), i).

In theory, it should be possible to teach the compiler to split and infer loops whenever the transition graph between types of internal states is acyclic. Apparently we're not there yet, and I am unsure whether such a thing should be a priority.

@KristofferC added the labels "performance: Must go faster" and "potential benchmark: Could make a good benchmark in BaseBenchmarks" on Oct 14, 2018
@nalimilan
Member

Thanks, I see. Then maybe changing iterate(tail(s)...) to iterate(s[2]) and iterate(s[2], s[3]) would be more explicit. Not a big deal though.

@chethega
Contributor Author

@KristofferC You added the "needs benchmark" label, presumably in order to catch any possible future regressions?

I'm not familiar with the procedure here. Should I make a separate PR to JuliaCI/BaseBenchmarks.jl? Or could you do that, so I can read your PR and know what to do next time?

@vchuravy
Member

I'm not familiar with the procedure here. Should I make a separate PR to JuliaCI/BaseBenchmarks.jl?

Yes, a separate PR like JuliaCI/BaseBenchmarks.jl#233 would do the trick.

@KristofferC
Member

KristofferC commented Oct 15, 2018

The "needs benchmark" label is just a tag to make sure we at some point make a PR to BaseBenchmarks.jl to add a regression test for this (e.g. JuliaCI/BaseBenchmarks.jl#233). It is not a requirement for merging.

@chethega
Contributor Author

Thanks! I'll leave the benchmark to you, then.

@KristofferC
Member

Help is of course appreciated :)

@mbauman
Member

mbauman left a comment

Thanks! This is definitely needed — but it would indeed be much clearer if we actually named the components of the state tuple as Milan suggested. I think we had initially used r and i here (see the comments) but those got lost in the translation to the new iteration spec.

@chethega
Contributor Author

chethega commented Oct 15, 2018

The problem with being more explicit about the components of the state tuple is that the type change only occurs in iterate(L), not in subsequent iterations. That is, iterating over a logical index is actually a double iteration, and the type change only occurs from first-first to first-second, not from second-first to second-second.

In other words, starting with second-*, the unroll is actually bad: larger code size, and we execute the same number of branches but distributed over two code positions, which is bad for branch prediction. I am kinda hoping that this either gets optimized out by LLVM or that it doesn't matter.

Generally, I think a similar issue should appear in any code that flattens iterators.
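For instance, a hand-rolled flatten has the same shape: the initial state has no inner-iterator position yet, while later states carry one (a sketch with a hypothetical `Flat` type, not Base's Iterators.flatten):

```julia
# Sketch of the analogous flattening problem: the state alternates
# between a 1-tuple (no inner state yet) and a 2-tuple, so inference
# must widen the state type to a Union here too.
struct Flat{T}
    items::T
end
Base.IteratorSize(::Type{<:Flat}) = Base.SizeUnknown()

# Initial state: just the outer index, no inner-iterator state yet
Base.iterate(f::Flat) = iterate(f, (1,))
function Base.iterate(f::Flat, s)
    o = s[1]
    o > length(f.items) && return nothing
    inner = length(s) == 1 ? iterate(f.items[o]) : iterate(f.items[o], s[2])
    inner === nothing && return iterate(f, (o + 1,))  # advance the outer index
    x, istate = inner
    return (x, (o, istate))    # later states: outer index plus inner state
end
```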

The current code has the nice property of being agnostic about the types of the inner iterator states. So we could either become very explicit, with unneeded type annotations and running the risk of having bugs with too-narrow dispatch, or stay with this.

Alternatively, we could of course call iterate(L, (1, r, 0)), but that would break for logical indices over AbstractArray with weirder iterator states. Or we could introduce a branch in iterate(L) that checks the first element of L.mask. But these solutions look far less clean to me.

The latter fix would look like:

@inline function iterate(L::LogicalIndex{Int})
    r = LinearIndices(L.mask)
    L.count == 0 && return nothing
    s = (2, r, 1)
    L.mask[1] && return (1, s)
    iterate(L, s)
end

@mbauman
Member

mbauman commented Oct 15, 2018

Ah, right. Thanks for the thorough explanation, that makes sense. I still would love to see if the inference folks could fix this case for us, but in the meantime this is a much-needed patch.

I find the nested tuple version a little easier to reason about, but of course it hits the same inference issue:

@inline function iterate(L::LogicalIndex{Int})
    r = LinearIndices(L.mask)
    iterate(L, (1, (r,)))
end
@inline function iterate(L::LogicalIndex{<:CartesianIndex})
    r = CartesianIndices(axes(L.mask))
    iterate(L, (1, (r,)))
end
@propagate_inbounds function iterate(L::LogicalIndex, s)
    # We're looking for the n-th true element, using the `inner` iterator
    n, inner = s
    n > length(L) && return nothing
    while true
        idx, i = iterate(inner...)
        inner = (inner[1], i)
        L.mask[idx] && return (idx, (n+1, inner))
    end
end

@StefanKarpinski
Member

StefanKarpinski commented Oct 16, 2018

Let's have the low-risk performance fix that can go into 1.0.2 now and have the more general but higher risk inference-fix version in 1.1 or 1.2.

@chethega
Contributor Author

Bump. Anything else needed?

The inference failure on 1.0 is arguably a bug, even though the behavior is correct. Can we add the backport label, as Stefan implied?

@StefanKarpinski
Member

It would be good if someone with more inference-fu could take a look and merge if this is ok.

@KristofferC
Member

Small change, demonstrated performance improvements, approved by Mr. Bauman, three-of-a-kind CI. I'm merging :P

@KristofferC KristofferC merged commit aa72f72 into JuliaLang:master Oct 29, 2018
KristofferC pushed a commit that referenced this pull request Oct 29, 2018
@vtjnash
Member

vtjnash commented Oct 29, 2018

I can make almost the same improvement with less code, although I'm surprised it's not faster (since it should also be removing a couple of branches):

diff --git a/base/multidimensional.jl b/base/multidimensional.jl
index d79715299f..614b23ba8c 100644
--- a/base/multidimensional.jl
+++ b/base/multidimensional.jl
@@ -507,20 +507,20 @@ show(io::IO, r::LogicalIndex) = print(io, "Base.LogicalIndex(", r.mask, ")")
 # should be -- this way we don't need to look at future indices to check done.
 @inline function iterate(L::LogicalIndex{Int})
     r = LinearIndices(L.mask)
-    iterate(L, (1, r))
+    iterate(L, (r,))
 end
 @inline function iterate(L::LogicalIndex{<:CartesianIndex})
     r = CartesianIndices(axes(L.mask))
-    iterate(L, (1, r))
+    iterate(L, (r,))
 end
 @propagate_inbounds function iterate(L::LogicalIndex, s)
     # We're looking for the n-th true element, using iterator r at state i
-    n = s[1]
-    n > length(L) && return nothing
-    while true
-        idx, i = iterate(tail(s)...)
-        s = (n+1, s[2], i)
-        L.mask[idx] && return (idx, s)
+    r = s[1]
+    next = iterate(s...)
+    while next !== nothing
+        idx, i = next
+        L.mask[idx] && return idx, (r, i)
+        next = iterate(r, i)
     end
 end
 # When wrapping a BitArray, lean heavily upon its internals.

This is like Matt's code, but also removes the hack to make the old iteration protocol work (it used to require a separate count of the elements, now that's optional).

This would also be somewhat easier on the compiler (compilation cost scales super-linearly with the number of statements) and might then permit more inlining.

@mbauman
Member

mbauman commented Oct 29, 2018

the hack to make the old iteration protocol work (it used to require a separate count of the elements…)

Oh, well that wasn't just a hack for iteration's sake — it also enabled short-circuiting for indices with lots of falses at the end. Is that where you're seeing a performance loss?

This should be a very well-predicted branch until it isn't, at which point we leave the iteration loop entirely. It should be a fairly cheap trick for big wins in some situations.
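The situation in question can be sketched like this (hypothetical sizes): all the trues at the front and a long tail of falses, where the stored element count lets iteration stop early.

```julia
# With the element count, the LogicalIndex iterator can stop as soon as
# the last true has been produced, never scanning the tail of falses.
m = falses(1_000_000)
m[1:10] .= true          # trues only at the front
d = rand(1_000_000)
v = d[m]                 # produces 10 elements, then short-circuits
```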

KristofferC pushed a commit that referenced this pull request Oct 31, 2018
KristofferC pushed a commit that referenced this pull request Nov 2, 2018
KristofferC pushed a commit that referenced this pull request Feb 11, 2019
KristofferC pushed a commit that referenced this pull request Feb 20, 2020