Preserve the type in differentiation #149

Merged · 1 commit · Jan 24, 2020

Conversation

matsueushi (Contributor)

The current implementations of leakyrelu, elu and selu return Float64 gradients for Float32 inputs.

julia> using NNlib, Zygote

julia> leakyrelu'(1f0)
1.0

julia> leakyrelu'(-1f0)
0.009999999776482582

julia> elu'(1f0)
1.0

julia> elu'(-1f0)
0.3678794503211975

julia> selu'(1f0)
1.0507010221481323

julia> selu'(-1f0)
0.6467686295509338

cf. FluxML/Flux.jl#963. This PR makes differentiation preserve the input's floating-point precision.
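
As a rough illustration of the general pattern (a hedged sketch, not necessarily the exact diff merged here), one way to keep types stable is to derive every constant from the input, for example oftype(x, 0.01) instead of a bare 0.01 and one(x) instead of the literal 1, so that neither the value nor the Zygote gradient is promoted to Float64. The function myleakyrelu below is a hypothetical stand-in, not NNlib's definition:

using Zygote

# Hypothetical stand-in, not NNlib's actual code: every constant is derived
# from the input, so both the value and the Zygote gradient keep the
# input's precision.
myleakyrelu(x) = max(oftype(x, 0.01) * x, x / one(x))

myleakyrelu'(-1f0)  # 0.01f0 (Float32)
myleakyrelu'(-1.0)  # 0.01   (Float64)

The script below checks that the derivative type matches the input type for every activation function: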

using NNlib, Zygote, Test

ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh];
function test_deliv_float_precision_preserving(a)
    @testset "$(a): " begin
        for T in [Float32, Float64]
            for val in [-10, -1, 0, 1, 10]
                val = @inferred a'(T(val))
                @test typeof(val) == T
            end
        end
    end
end

@testset "Float derivative inference" begin
    test_deliv_float_precision_preserving.(ACTIVATION_FUNCTIONS)
end

Before

Test Summary:              | Pass  Fail  Total
Float derivative inference |   85    15    100
  σ:                       |   10           10
  relu:                    |   10           10
  leakyrelu:               |    5     5     10
  elu:                     |    5     5     10
  gelu:                    |   10           10
  swish:                   |   10           10
  selu:                    |    5     5     10
  softplus:                |   10           10
  softsign:                |   10           10
  logcosh:                 |   10           10
ERROR: Some tests did not pass: 85 passed, 15 failed, 0 errored, 0 broken.

After

Test Summary:              | Pass  Total
Float derivative inference |  100    100

@codecov-io commented Dec 18, 2019

Codecov Report

Merging #149 into master will not change coverage.
The diff coverage is 100%.


@@           Coverage Diff           @@
##           master     #149   +/-   ##
=======================================
  Coverage   74.86%   74.86%           
=======================================
  Files          24       24           
  Lines         768      768           
=======================================
  Hits          575      575           
  Misses        193      193
Impacted Files Coverage Δ
src/activation.jl 94.11% <100%> (ø) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 136aa82...82e1fde.

@staticfloat (Contributor)

Thanks @matsueushi!

@matsueushi deleted the diff_precision branch on January 25, 2020 00:19
@DhairyaLGandhi (Member)

Should we also add this testset to our regular tests?

@matsueushi (Contributor, Author)

Should I add the test to test/activation.jl? NNlib.jl currently doesn't have a Zygote.jl dependency. Or would it be better to update Zygote.jl's tests instead?

@staticfloat (Contributor)

I would add a type-stability test to NNlib (just ensure that the activation functions don't change the type unnecessarily).

@matsueushi (Contributor, Author) commented Jan 26, 2020

If you mean the values of the activation functions, such a test is already defined in test/activation.jl, and the previous definitions passed it. I modified it to test gradients.

function test_value_float_precision_preserving(a)
    @testset "$(a): " begin
        for T in [Float32, Float64]
            for val in [-10, -1, 0, 1, 10]
                val = @inferred a(T(val))
                @test typeof(val) == T
            end
        end
    end
end

@staticfloat (Contributor)

Right, so what I mean is that we should have that same test for the gradients of the activation functions. :)

@matsueushi (Contributor, Author)

Thanks, I see what you mean.
