Improve activation functions, make fast versions accessible #371
Conversation
Well done! This looks great.
Takes an optional 2nd argument, so that you can disable
this replacement for some array or element types.
"""
@inline fast_act(f::F, ::AbstractArray = 1:0) where {F<:Function} = f
👀 I've never seen this ::AbstractArray = 1:0 pattern before, what does it do? Edit: figured that part out, is there a benefit to 1:0 over 0:0 or 0:1?
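For reference, 1:0 is Julia's conventional empty range, so presumably the default is just a zero-length placeholder that still dispatches as an AbstractArray, whereas 0:0 and 0:1 are not empty; a quick check:

```julia
length(1:0), length(0:0), length(0:1)   # (0, 1, 2): only 1:0 is empty
(1:0) isa AbstractArray                 # true, so it still works for dispatch
```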
My thinking is that we may later decide that it's better to skip tanh_fast on the GPU. I can't measure a difference so who knows. To do that, we can add a method in NNlibCUDA like fast_act(::typeof(tanh), ::CuArray) = tanh. Provided you call this like fast = NNlib.fast_act(fun, x), with an example of the array you plan to broadcast it over.
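A minimal sketch of that pattern, assuming the methods this PR adds (the CuArray opt-out is hypothetical, shown only to illustrate the dispatch):

```julia
using NNlib

# Pick the activation based on an example of the array it will be broadcast over:
x = randn(Float32, 128, 32)
fast = NNlib.fast_act(tanh, x)   # tanh_fast for ordinary CPU arrays
y = fast.(x)

# Hypothetical opt-out in NNlibCUDA, if tanh_fast turned out not to pay off on GPU:
# NNlib.fast_act(::typeof(tanh), ::CuArray) = tanh
```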
My branch for this is mcabbott/Flux.jl@217eb1a, which I see I did on top of FluxML/Flux.jl#1761.
Flux is not happy: https://github.com/FluxML/Flux.jl/runs/4770890956?check_suite_focus=true. I think it's time to look into reverse CI.
Oh no, sorry. Probably not hard to fix though: either we change the supertype, or we widen the method.
cc @darsnack for his thoughts.
I think we tried making […]. Widening […]
Yes I remember something about AbstractFloat causing problems. But does Real cause problems? Should probably widen here anyway, now that I think about it, in case someone is crazy enough to use Flux with complex numbers.
Although in fact, now that I look at the file, few activation functions will work on complex numbers.
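That is easy to check, since most of the piecewise activations rely on comparisons, which Julia does not define for complex values (a simplified stand-in below, not NNlib's actual relu):

```julia
relu_like(z) = max(z, zero(z))   # comparison-based, like relu / leakyrelu / hardtanh
relu_like(1.0 + 2.0im)           # MethodError: isless has no method for Complex
```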
Yeah, just tested that subtyping […]. Ideally, […]
This upgrades many of the activation functions. For instance, the new gelu is about 7x faster (when broadcast), and has worst-case error nextfloat(true, 9) instead of 7675 before. You can check my working below if so inclined.

It also defines a function which does fast_act(tanh) == tanh_fast etc., to allow Flux to automatically substitute the faster-than-Base variants (a short usage sketch is below). For all except tanh and sigmoid, however, there seems to be no point in keeping the slow versions around at all.

The exception is softplus, for which a fast version is surely possible, but not worked out yet, so not included here.

Paging @oscardssmith, since we discussed things in #345
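A rough sketch of what the substitution looks like from the caller's side, using the names introduced here (the accuracy check at the end is just an illustration, not a claim from the PR):

```julia
using NNlib

NNlib.fast_act(tanh) == NNlib.tanh_fast   # true: the default method maps tanh to tanh_fast

x = randn(Float32, 1024)
fast = NNlib.fast_act(tanh, x)            # Flux would pass the array it plans to broadcast over
maximum(abs, fast.(x) .- tanh.(x))        # small: tanh_fast trades some accuracy for speed
```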