rrules do not support chunked mode #566

Open
oscardssmith opened this issue Jul 18, 2022 · 2 comments
Labels: design (Requires some design before changes are made), documentation (Improvements or additions to documentation), help wanted (Extra attention is needed)

Comments

@oscardssmith
Member

oscardssmith commented Jul 18, 2022

This is a general issue, but here is one specific incarnation: https://github.com/JuliaDiff/ChainRules.jl/blob/8073c7c4638bdd46f4e822d2ab72423c051c5e4b/src/rulesets/Base/array.jl#L40

function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), NTuple{N}(ȳ)...)
    return Base.vect(X...), vect_pullback
end

This rule implicitly assumes that ȳ is a Vector, but if you are taking a Jacobian, ȳ will be a Matrix, in which case the rule should be

function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
    vect_pullback(ȳ) = (NoTangent(), ȳ...)
    return Base.vect(X...), vect_pullback
end
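A minimal sketch of one way a pullback could handle both cases by dispatching on the cotangent's dimensionality. It assumes (this is not stated in the issue) that a chunked cotangent arrives as an N×k matrix with one seed per column, so input i should receive row i. The name `chunked_vect_pullback` is hypothetical, and `nothing` stands in for ChainRulesCore's `NoTangent()` so the snippet runs without the package.

```julia
# Hypothetical, standalone sketch; not ChainRules.jl's actual rule.
# `nothing` stands in for ChainRulesCore.NoTangent() so no package is needed.

# Ordinary reverse mode: ȳ is a length-N vector, one scalar cotangent per input.
chunked_vect_pullback(ȳ::AbstractVector) = (nothing, Tuple(ȳ)...)

# Assumed chunked layout: Ȳ is N×k, one column per seed,
# so input i receives row i (a length-k view).
chunked_vect_pullback(Ȳ::AbstractMatrix) = (nothing, Tuple(eachrow(Ȳ))...)

# Usage: three scalar inputs, first with a single seed, then a chunk of two seeds.
single = chunked_vect_pullback([1.0, 2.0, 3.0])
chunk = chunked_vect_pullback([1.0 4.0; 2.0 5.0; 3.0 6.0])
```

In both cases the pullback returns N + 1 entries (one for the function, one per input); only the type of each per-input cotangent changes.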

Similar problems also exist for the getindex rrules, and I'm sure there are a bunch of other similar cases.
Is there a good general solution to this?

@mcabbott
Member

mcabbott commented Jul 19, 2022

I think you're asking whether there's a scheme for chunked reverse mode. There is not: at present, (co)tangents match the size of the primal. #92 has some discussion; see also JuliaDiff/Diffractor.jl#54.

Edit: most rules will enforce this via projection:

julia> x = [1,2,3];  # vector primal

julia> ProjectTo(x)([4;5;6;;])  # allows 1-column matrix, converts to vector
3-element Vector{Float64}:
 4.0
 5.0
 6.0

julia> ProjectTo(x)([4 5 6])  # does not allow worse shapes
ERROR: DimensionMismatch: variable with size(x) == (3,) cannot have a gradient with size(dx) == (1, 3)

@mcabbott mcabbott changed the title from "Incorrect rrules for vector valued functions" to "rrules do not support chunked mode" Jul 19, 2022
@mcabbott
Member

For now, the current status should be clearly documented, perhaps on these pages:

https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/tangents.html

https://juliadiff.org/ChainRulesCore.jl/dev/maths/propagators.html

@mcabbott mcabbott added the documentation, help wanted, and design labels Jul 19, 2022