Moving Buffer to ZygoteRules #6

Open
sethaxen opened this issue Oct 15, 2019 · 7 comments · May be fixed by FluxML/Zygote.jl#645

@sethaxen

Is Zygote.Buffer too hefty to move to ZygoteRules? I work with several packages that use mutation extensively, and it would be nice to add custom rules using Buffer without adding Zygote as a full dependency.

@MikeInnes
Member

I think it might make sense to make Buffers.jl a separate package. If you're willing to put it together, we can move it into FluxML, tag it, etc.

@sethaxen
Author

sethaxen commented Nov 9, 2019

Sounds good! Happy to do it.

@sethaxen
Author

sethaxen commented Nov 9, 2019

Here's an initial repo with a failing test: https://github.com/sethaxen/Buffers.jl. Buffer's adjoints require Zygote.grad_mut, which depends on Zygote.Context and Zygote.cache, not just ZygoteRules.AContext. I don't really understand how grad_mut works. Any advice on how to replace it?

@MikeInnes
Member

grad_mut is just a convenience that checks for a cached gradient for the given object, and creates one if it needs to. You could just copy-paste whatever definitions you need to make it work.
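
A simplified sketch of the "look up a cached gradient, create one if missing" pattern described above, assuming a context that owns an identity-keyed cache. `SketchContext`, `sketch_cache` and `sketch_grad_mut` are hypothetical names, not Zygote's definitions, and Zygote's real grad_mut may build its initial accumulator differently per type:

```julia
# Hypothetical stand-in for a differentiation context that owns a gradient cache.
struct SketchContext
    cache::IdDict{Any,Any}   # keyed by object identity, not by value
end
SketchContext() = SketchContext(IdDict{Any,Any}())

sketch_cache(cx::SketchContext) = cx.cache

# Fresh accumulator for a mutable object: a Ref holding `nothing` (no gradient yet).
sketch_grad_mut(x) = Ref{Any}(nothing)

# Return the cached accumulator for `x`, creating and storing one on first use.
sketch_grad_mut(cx::SketchContext, x) = get!(() -> sketch_grad_mut(x), sketch_cache(cx), x)

# Usage: the same object always maps to the same accumulator within one context.
cx = SketchContext()
xs = zeros(3)
acc = sketch_grad_mut(cx, xs)   # first call creates and caches the accumulator
```

Because the cache is an IdDict, two arrays with equal contents still get separate accumulators, which is what a per-object mutation gradient needs.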

@sethaxen
Author

Sorry it took a while to get back to this. I've finished porting the Buffer-related code and tests to https://github.com/sethaxen/Buffers.jl, and the tests pass. However, after loading Buffers, Buffer seems to get pulled into adjoints that have nothing to do with it. Here's an example using these functions from Zygote's tests:

function pow_mut(x, n)
  r = Ref(one(x))
  while n > 0
    n -= 1
    r[] = r[] * x
  end
  return r[]
end

struct Foo{T}
  a::T
  b::T
end

function mul_struct(a, b)
  c = Foo(a, b)
  c.a * c.b
end

kwmul(; a = 1, b) = a*b

mul_kw(a, b) = kwmul(a = a, b = b)

julia> using Zygote

julia> gradient(pow_mut, 2, 3)
(nothing, 12)

julia> gradient(mul_struct, 2, 3)
(3, 2)

julia> gradient(mul_kw, 2, 3)
(3, 2)

julia> using Buffers

julia> gradient(pow_mut, 2, 3)
ERROR: MethodError: no method matching Buffer(::Int64)
Closest candidates are:
  Buffer(::A, ::Bool) where {T, A<:(AbstractArray{T,N} where N)} at /Users/saxen/projects/Buffers.jl/src/buffer.jl:38
  Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
 [1] _pullback(::Zygote.Context, ::UnionAll, ::Int64) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
 [2] pow_mut at ./REPL[1]:2 [inlined]
 [3] _pullback(::Zygote.Context, ::typeof(pow_mut), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
 [5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
 [6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
 [7] top-level scope at REPL[11]:1

julia> gradient(mul_struct, 2, 3)
ERROR: MethodError: no method matching Buffer(::Int64, ::Int64)
Closest candidates are:
  Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
 [1] _pullback(::Zygote.Context, ::UnionAll, ::Int64, ::Int64) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
 [2] mul_struct at ./REPL[3]:2 [inlined]
 [3] _pullback(::Zygote.Context, ::typeof(mul_struct), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
 [5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
 [6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
 [7] top-level scope at REPL[12]:1

julia> gradient(mul_kw, 2, 3)
ERROR: MethodError: no method matching Buffer(::Tuple{Int64,Int64})
Closest candidates are:
  Buffer(::A, ::Bool) where {T, A<:(AbstractArray{T,N} where N)} at /Users/saxen/projects/Buffers.jl/src/buffer.jl:38
  Buffer(::AbstractArray, ::Any...) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:42
Stacktrace:
 [1] _pullback(::Zygote.Context, ::UnionAll, ::Tuple{Int64,Int64}) at /Users/saxen/projects/Buffers.jl/src/buffer.jl:91
 [2] mul_kw at ./REPL[5]:1 [inlined]
 [3] _pullback(::Zygote.Context, ::typeof(mul_kw), ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface2.jl:0
 [4] _pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:29
 [5] pullback(::Function, ::Int64, ::Int64) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:35
 [6] gradient(::Function, ::Int64, ::Vararg{Int64,N} where N) at /Users/saxen/.julia/packages/Zygote/4tJp5/src/compiler/interface.jl:44
 [7] top-level scope at REPL[13]:1

Do you have any idea what can be going wrong?
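
In all three traces `_pullback` receives a `::UnionAll` as its function argument (`Ref`, `Foo`, and presumably the `NamedTuple{(:a, :b)}` constructor produced by keyword-argument lowering) and lands in the same method at buffer.jl:91. One possible explanation, an assumption not confirmed in this thread, is a rule whose signature ends up typed on `typeof(Buffer)`: for a parametric struct that is just `UnionAll`, so it matches every parametric constructor. The sketch below uses hypothetical names (`MyBuf`, `which_rule`, `which_rule_narrow`) to show the dispatch difference against `Type{MyBuf}`:

```julia
# A parametric struct: the bare name `MyBuf` is a UnionAll, not a DataType.
struct MyBuf{T}
    data::Vector{T}
end

# Rule keyed on typeof(MyBuf). Since typeof(MyBuf) === UnionAll, this
# matches ANY UnionAll first argument, e.g. Ref or NamedTuple{(:a, :b)}.
which_rule(::typeof(MyBuf), args...) = :mybuf_rule
which_rule(f, args...) = :generic

# Rule keyed on Type{MyBuf}: only matches when the first argument is MyBuf itself.
which_rule_narrow(::Type{MyBuf}, args...) = :mybuf_rule
which_rule_narrow(f, args...) = :generic
```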

@MikeInnes
Member

These methods are all type-piratical. Not sure if that's the direct cause but would be worth fixing.
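
Type piracy means adding a method to a function you don't own for argument types you don't own either. Its practical effect is that merely loading the pirating module changes behavior for callers that never mention it. Ownership is judged by which module defines the function and the types, so it is not specific to Zygote. A minimal sketch with hypothetical modules `Owner`, `Pirate` and `WellBehaved`:

```julia
# `Owner` defines the function `describe`.
module Owner
describe(x) = "generic"
end

# Piracy: extends Owner.describe for Int, a type this module doesn't own.
# Every call to Owner.describe(::Int) changes once this module is loaded.
module Pirate
import ..Owner
Owner.describe(::Int) = "pirated"
end

# Not piracy: the new method is restricted to a type this module owns.
module WellBehaved
import ..Owner
struct Mine end
Owner.describe(::Mine) = "mine"
end
```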

@sethaxen
Author

> These methods are all type-piratical. Not sure if that's the direct cause but would be worth fixing.

Those are all created in Zygote, but this package only depends on ZygoteRules, so I don't see how they can be type-pirating Zygote. Besides, renaming them entirely produces the same errors.
