Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type stability of cat_pullback #610

Merged
merged 1 commit into from
Apr 26, 2022
Merged

Conversation

ToucheSir
Copy link
Contributor

Currently,

index = ntuple(ndimsY) do d
if d in cdims
d > ndimsX ? (prev[d]+1) : (prev[d]+1:prev[d]+sizeX[d])
else
d > ndimsX ? 1 : (:)
end
end
will generate a Tuple{Union{UnitRange{Int64}, Colon}, ...}. When the number of input dims >=3, this appears to fail some union-splitting threshold for getindex on certain array types. Here's a MWE with FillArrays, which Zygote uses somewhat extensively:

using Test, ChainRules, ChainRulesCore, FillArrays

function grad(x1, x2, sens, dims::Val)
  y, back = rrule(cat, x1, x2; dims)
  back(sens)
end

# Infers
@inferred grad(ones(1, 1, 1), ones(1, 1, 1), ones(1, 2, 1), Val(2))
@inferred grad(Ones(1, 1), Ones(1, 1), ones(1, 2), Val(2))

# Does not infer
@inferred grad(Ones(1, 1, 1), Ones(1, 1, 1), Ones(1, 2, 1), Val(2))
@inferred grad(ones(1, 1, 1), ones(1, 1, 1), Ones(1, 2, 1), Val(2))

This PR resolves the instability by always creating a UnitRange. However, I am not sure if this is sufficiently semantically close to Colon to pass muster, so RFC :) .

@mcabbott
Copy link
Member

This seems fine to me. Could you add something like this example as a test?

The rules for hcat and vcat have a similar branch, d > ndimsX ? 1 : (:). But perhaps that never matters.

@ToucheSir
Copy link
Contributor Author

The rules for hcat and vcat have a similar branch, d > ndimsX ? 1 : (:). But perhaps that never matters.

It also doesn't matter for 1 and 2D inputs here, which is why I suspect union splitting is bailing out. Cthulhu shows 4 variants of getindex(::Fill, ...) generated for the 2D case, so I guess the threshold is somewhere between 2^2 and 2^3?

@ToucheSir ToucheSir force-pushed the patch-2 branch 2 times, most recently from 76c6caa to ce98fd4 Compare April 24, 2022 04:24
@mcabbott mcabbott merged commit 7b5f4d1 into JuliaDiff:main Apr 26, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants