fastmath support #90

Open
chriselrod opened this issue Mar 6, 2019 · 2 comments
Labels
help wanted Extra attention is needed


@chriselrod

```julia
julia> foo(x) = 2.0x
foo (generic function with 1 method)

julia> bar(x) = @fastmath 2.0x
bar (generic function with 1 method)

julia> Zygote.derivative(foo, 1e6)
2.0

julia> Zygote.derivative(bar, 1e6)
ERROR: Non-differentiable function IntrinsicFunction
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.Pullback{Tuple{Core.IntrinsicFunction,Float64,Float64},Tuple{Core.IntrinsicFunction}})(::Int8) at /home/chriselrod/.julia/packages/Zygote/wYgcz/src/compiler/interface2.jl:0
 [3] mul_fast at ./fastmath.jl:163 [inlined]
 [4] (::Zygote.Pullback{Tuple{typeof(Base.FastMath.mul_fast),Float64,Float64},Tuple{Zygote.Pullback{Tuple{Core.IntrinsicFunction,Float64,Float64},Tuple{Core.IntrinsicFunction}}}})(::Int8) at /home/chriselrod/.julia/packages/Zygote/wYgcz/src/compiler/interface2.jl:0
 [5] bar at ./REPL[55]:1 [inlined]
 [6] (::Zygote.Pullback{Tuple{typeof(bar),Float64},Tuple{getfield(Zygote, Symbol("##265#back#163")){getfield(Zygote, Symbol("#back#162")){:mul_fast,Zygote.Context,Module,typeof(Base.FastMath.mul_fast)}},Zygote.Pullback{Tuple{typeof(Base.FastMath.mul_fast),Float64,Float64},Tuple{Zygote.Pullback{Tuple{Core.IntrinsicFunction,Float64,Float64},Tuple{Core.IntrinsicFunction}}}}}})(::Int8) at /home/chriselrod/.julia/packages/Zygote/wYgcz/src/compiler/interface2.jl:0
 [7] (::getfield(Zygote, Symbol("##73#74")){Zygote.Pullback{Tuple{typeof(bar),Float64},Tuple{getfield(Zygote, Symbol("##265#back#163")){getfield(Zygote, Symbol("#back#162")){:mul_fast,Zygote.Context,Module,typeof(Base.FastMath.mul_fast)}},Zygote.Pullback{Tuple{typeof(Base.FastMath.mul_fast),Float64,Float64},Tuple{Zygote.Pullback{Tuple{Core.IntrinsicFunction,Float64,Float64},Tuple{Core.IntrinsicFunction}}}}}}})(::Int8) at /home/chriselrod/.julia/packages/Zygote/wYgcz/src/compiler/interface.jl:38
 [8] gradient(::Function, ::Float64) at /home/chriselrod/.julia/packages/Zygote/wYgcz/src/compiler/interface.jl:44
 [9] derivative(::typeof(bar), ::Float64) at /home/chriselrod/.julia/packages/Zygote/wYgcz/src/compiler/interface.jl:47
 [10] top-level scope at REPL[57]:1
```

`@fastmath` is convenient for enabling the autovectorizer and contractions (i.e., FMA instructions) on unrolled expressions. It would be nice to add adjoints for `mul_fast` and friends.
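To see why the pullback in the stack trace ends up inside `mul_fast`, one can expand the macro. This is a minimal plain-Julia sketch (no Zygote needed); the exact printed form of the expression may vary between Julia versions:

```julia
# @fastmath rewrites ordinary operators into their Base.FastMath counterparts.
ex = @macroexpand @fastmath 2.0x
println(ex)  # the `*` call has become a call to Base.FastMath.mul_fast

# For ordinary Float64 inputs the fast version still computes the product;
# the difference lies in the compiler flags on the underlying operation.
bar(x) = @fastmath 2.0x
println(bar(3.0))
```

Since Zygote has no adjoint for `mul_fast`, it differentiates through its body and hits the `Core.IntrinsicFunction` call shown in frame [2] of the trace.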

MikeInnes added the help wanted label Mar 7, 2019
@MikeInnes
Member

This should be an easy contribution for anyone interested. We just need to go through the list of `FastMath` operators (`add_fast` etc.) and define a simple adjoint for each one.
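As a rough illustration of what a simple adjoint for each operator involves, here is a plain-Julia sketch for the four arithmetic operators. `rule` is a hypothetical helper, not Zygote's API: it returns the primal value together with a pullback closure, which is the same shape a Zygote definition such as `@adjoint mul_fast(x, y) = mul_fast(x, y), Δ -> (Δ * y, Δ * x)` would take.

```julia
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast

# Hypothetical helper (not Zygote's API): returns (value, pullback), where the
# pullback maps the output sensitivity Δ to the sensitivities of (x, y).
rule(::typeof(add_fast), x, y) = (add_fast(x, y), Δ -> (Δ, Δ))
rule(::typeof(sub_fast), x, y) = (sub_fast(x, y), Δ -> (Δ, -Δ))
rule(::typeof(mul_fast), x, y) = (mul_fast(x, y), Δ -> (Δ * y, Δ * x))
rule(::typeof(div_fast), x, y) = (div_fast(x, y), Δ -> (Δ / y, -Δ * x / y^2))

# bar(x) = @fastmath 2.0x calls mul_fast(2.0, x); differentiate at x = 1e6.
y, back = rule(mul_fast, 2.0, 1e6)
dx = back(1.0)[2]  # sensitivity with respect to the second argument, x
println((y, dx))
```

The derivative obtained this way agrees with `Zygote.derivative(foo, 1e6)` from the report, since `2.0x` and `mul_fast(2.0, x)` have the same derivative.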

@Satya758

Satya758 commented Oct 29, 2019

@MikeInnes I am picking up this issue.
Idea:
Use DiffRules to fetch the differentiation rule expression for each operation, then recursively replace each operator in it with its fast-math counterpart where one exists (using MacroTools/`@fastmath` and `Base.FastMath.fast_op`).
The logic will look similar to https://github.com/FluxML/Zygote.jl/blob/master/src/lib/number.jl#L7
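A minimal sketch of this idea in plain Julia, under two assumptions: a tiny hand-written table (`UNARY_RULES`, hypothetical) stands in for DiffRules, and the generated function is a hypothetical `derivative_of` rather than a Zygote adjoint. The operator mapping comes from `Base.FastMath.fast_op`, the internal table `@fastmath` itself uses, so it is not a stable public API.

```julia
# Stand-in for DiffRules (assumption): derivative expressions written in terms of `x`.
const UNARY_RULES = Dict(
    :sin  => :(cos(x)),
    :exp  => :(exp(x)),
    :log  => :(inv(x)),
    :sqrt => :(inv(2.0 * sqrt(x))),
)

const FAST_OP = Base.FastMath.fast_op  # e.g. :sin => :sin_fast, :* => :mul_fast

# Recursively replace each operator symbol by its Base.FastMath counterpart,
# leaving symbols without a fast version (such as `x`) and literals untouched.
fastify(ex) = ex
fastify(s::Symbol) = haskey(FAST_OP, s) ? :(Base.FastMath.$(FAST_OP[s])) : s
fastify(ex::Expr) = Expr(ex.head, map(fastify, ex.args)...)

# Generate one method of the hypothetical `derivative_of` per fast operator.
for (f, dexpr) in UNARY_RULES
    f_fast = FAST_OP[f]
    @eval derivative_of(::typeof(Base.FastMath.$f_fast), x) = $(fastify(dexpr))
end

println(fastify(:(inv(2.0 * sqrt(x)))))  # derivative of sqrt, rewritten with fast ops
println(derivative_of(Base.FastMath.sin_fast, 0.0))
```

Generating the methods with `@eval` in a loop keeps the rule table as the single source of truth; swapping the stand-in table for DiffRules would change only where the expressions come from.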

bors bot added a commit that referenced this issue Nov 4, 2019
388: Add Fastmath operators r=MikeInnes a=Satya758

Fixes issue #90 

1. The basic idea is to apply the `@adjoint` macro to all defined fastmath operators.
2. This is done by looping over the fastmath [operators](https://github.com/JuliaLang/julia/blob/master/base/fastmath.jl#L31) and retrieving the defined differentiation expression for each one from [DiffRules](https://github.com/JuliaDiff/DiffRules.jl/blob/master/src/rules.jl).
3. Each differentiation expression is then used to create an adjoint function, similar to https://github.com/FluxML/Zygote.jl/blob/master/src/lib/number.jl#L7

I have added a bunch of test cases covering unary and binary operators, using the normal (non-fast) operators as the expected values.
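A sketch of that testing strategy, treating each normal operator as the expected value for its fast-math counterpart. To stay self-contained it uses a hypothetical central-difference helper `fd` in place of `Zygote.gradient`; with Zygote loaded, one would compare the gradients it returns instead.

```julia
using Test
using Base.FastMath: fast_op

# Hypothetical stand-in for an AD gradient: central finite difference.
fd(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

@testset "fast-math operators agree with normal operators" begin
    x, y = 0.7, 1.3
    for op in (:sin, :cos, :exp, :log, :sqrt)  # unary operators
        f, f_fast = getfield(Base, op), getfield(Base.FastMath, fast_op[op])
        @test f_fast(x) ≈ f(x)
        @test fd(f_fast, x) ≈ fd(f, x)
    end
    for op in (:+, :-, :*, :/)  # binary operators: check both partials
        g, g_fast = getfield(Base, op), getfield(Base.FastMath, fast_op[op])
        @test g_fast(x, y) ≈ g(x, y)
        @test fd(a -> g_fast(a, y), x) ≈ fd(a -> g(a, y), x)
        @test fd(b -> g_fast(x, b), y) ≈ fd(b -> g(x, b), y)
    end
end
```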
 


Co-authored-by: Satya <satyap.kommaraju@gmail.com>