map overreach? #646

Open
willtebbutt opened this issue May 14, 2020 · 3 comments

Comments

@willtebbutt
Member

Zygote's current map adjoint is arguably too optimistic about the range of types it can handle well.

For example, this issue in KernelFunctions.jl cropped up because we define a custom AbstractVector type that wraps a Matrix, and lets it masquerade as a vector-of-vectors.

Under the hood, this type implements various operations efficiently on the wrapped matrix. It would be reasonable to expect Zygote to exploit these efficient implementations, since AD should compose with them. Instead it hits the map adjoint and literally treats the object as a vector-of-vectors, which is bad for performance.
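To make the pattern concrete, here is a minimal sketch of such a wrapper using only Base Julia. The names ColumnsAsVectors, sq_norms_fast and sq_norms_generic are hypothetical and chosen for illustration; this is not the actual KernelFunctions.jl type. It contrasts an operation written against the wrapped matrix with the element-by-element path that a generic map takes.

```julia
# Hypothetical wrapper (illustration only): presents the columns of a
# matrix as a vector of vectors.
struct ColumnsAsVectors{T} <: AbstractVector{Vector{T}}
    X::Matrix{T}
end

Base.size(v::ColumnsAsVectors) = (size(v.X, 2),)
Base.getindex(v::ColumnsAsVectors, i::Int) = v.X[:, i]  # copies column i

# Efficient path: a single reduction over the wrapped matrix.
sq_norms_fast(v::ColumnsAsVectors) = vec(sum(abs2, v.X; dims=1))

# Generic path: what treating the object as a vector-of-vectors amounts to.
# Each element access allocates a fresh column vector.
sq_norms_generic(v::AbstractVector) = map(x -> sum(abs2, x), v)

v = ColumnsAsVectors([1.0 2.0 3.0; 4.0 5.0 6.0])
fast = sq_norms_fast(v)
generic = sq_norms_generic(v)
```

Both paths give the same values, but the generic one materialises every column separately, which is the cost a map adjoint pays when it treats the wrapper as a plain collection.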

I propose restricting the types that the map adjoint accepts, perhaps to StridedArray or DenseArray, whichever is deemed the better target. @MikeInnes @dhairyagandhi96 any thoughts?
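As a rough sketch of what that restriction would mean for dispatch, the example below uses a hypothetical function map_path in place of the real adjoint definition, and a hypothetical Wrapped type in place of a custom array. It does not use Zygote's API; it only shows which arguments would reach a rule constrained to StridedArray, and how DenseArray differs for views.

```julia
# Hypothetical custom array type (illustration only).
struct Wrapped{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int) = w.data[i]

# Hypothetical stand-in for a custom map rule restricted to strided arrays,
# with a fallback for everything else.
map_path(f, xs::StridedArray) = :custom_rule
map_path(f, xs::AbstractArray) = :generic_fallback

plain = [1.0, 2.0, 3.0]
sub = view(plain, 1:2)
wrapped = Wrapped(plain)

paths = (map_path(identity, plain), map_path(identity, sub), map_path(identity, wrapped))

# A contiguous view is a StridedArray but not a DenseArray, so the two
# candidate constraints differ in whether views keep the custom rule.
view_is_strided = sub isa StridedArray
view_is_dense = sub isa DenseArray
```

Under either constraint the wrapper falls through to the generic path, where its own efficient methods can be differentiated directly.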

@MikeInnes
Member

Sounds fine to me. This is kind of a tricky tradeoff unfortunately; probably the only real answer is to delete the adjoint entirely and support differentiating through map, but that won't work right now.

@willtebbutt
Member Author

I kind of agree. There's definitely something to be exploited in knowing that map acts independently on each element of its input, whether in compile time or run-time performance, so writing custom rules feels like something of a no-brainer here. It's just a question of the level of generality at which you implement them.
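A minimal sketch of how a rule can exploit that independence, written in plain Julia rather than against Zygote's internals: the forward pass records one pullback per element, and the reverse pass applies each pullback to the matching element of the incoming gradient. The names map_with_pullback and square_with_pullback are hypothetical, and the per-element derivative is written by hand here rather than produced by AD.

```julia
# Hypothetical per-element function returning (value, pullback).
square_with_pullback(x) = (x^2, dy -> 2x * dy)

# Hypothetical map rule: because elements are independent, the pullback of
# the whole map is just the element pullbacks applied pointwise.
function map_with_pullback(f_with_pullback, xs::AbstractVector)
    results = map(f_with_pullback, xs)   # one (value, pullback) per element
    ys = map(first, results)
    backs = map(last, results)
    pullback(dys) = map((back, dy) -> back(dy), backs, dys)
    return ys, pullback
end

ys, back = map_with_pullback(square_with_pullback, [1.0, 2.0, 3.0])
dxs = back([1.0, 1.0, 1.0])
```

The open design question is the one raised above: how general the argument type xs should be before a rule like this stops being a safe default.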

I'll try to remember to make a PR on this soon.

@ToucheSir
Member

Perhaps we could kill two birds with one stone if JuliaDiff/ChainRules.jl#314 gets implemented. Moving to ChainRules would also get us ProjectTo, which in theory could handle more array types (modulo efficiency concerns).
