Skip to content

Commit 60b2b92

Browse files
authored
Merge pull request #94 from FluxML/sf/overhaul
Major overhaul of NNlib
2 parents 11f840d + 936e71a commit 60b2b92

35 files changed

+3726
-1767
lines changed

.travis.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ notifications:
99
email: false
1010
git:
1111
depth: 99999999
12+
env:
13+
# Disable test fuzzing for the moment, as we're a little too slow for Travis
14+
- NNLIB_TEST_FUZZING=false
1215

1316
# Submit to Codecov
1417
after_success:

Manifest.toml

Lines changed: 14 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,22 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
13
[[Base64]]
24
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
35

4-
[[Compat]]
5-
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
6-
git-tree-sha1 = "ff2595695fc4f14427358ce2593f867085c45dcb"
7-
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
8-
version = "1.2.0"
9-
10-
[[Dates]]
11-
deps = ["Printf"]
12-
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
13-
14-
[[DelimitedFiles]]
15-
deps = ["Mmap"]
16-
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
6+
[[Crayons]]
7+
deps = ["Test"]
8+
git-tree-sha1 = "3017c662a988bcb8a3f43306a793617c6524d476"
9+
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
10+
version = "1.0.0"
1711

1812
[[Distributed]]
19-
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
13+
deps = ["Random", "Serialization", "Sockets"]
2014
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2115

2216
[[InteractiveUtils]]
23-
deps = ["LinearAlgebra", "Markdown"]
17+
deps = ["Markdown"]
2418
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
2519

26-
[[LibGit2]]
27-
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
28-
2920
[[Libdl]]
3021
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
3122

@@ -36,31 +27,14 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3627
[[Logging]]
3728
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
3829

39-
[[MacroTools]]
40-
deps = ["Compat"]
41-
git-tree-sha1 = "c443e1c8d58a4e9f61b708ad0a88286c7042145b"
42-
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
43-
version = "0.4.4"
44-
4530
[[Markdown]]
4631
deps = ["Base64"]
4732
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
4833

49-
[[Mmap]]
50-
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
51-
52-
[[Pkg]]
53-
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
54-
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
55-
5634
[[Printf]]
5735
deps = ["Unicode"]
5836
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
5937

60-
[[REPL]]
61-
deps = ["InteractiveUtils", "Markdown", "Sockets"]
62-
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
63-
6438
[[Random]]
6539
deps = ["Serialization"]
6640
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -71,16 +45,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
7145
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
7246
version = "0.5.2"
7347

74-
[[SHA]]
75-
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
76-
7748
[[Serialization]]
7849
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
7950

80-
[[SharedArrays]]
81-
deps = ["Distributed", "Mmap", "Random", "Serialization"]
82-
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
83-
8451
[[Sockets]]
8552
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
8653

@@ -96,9 +63,11 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
9663
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
9764
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9865

99-
[[UUIDs]]
100-
deps = ["Random"]
101-
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
66+
[[TimerOutputs]]
67+
deps = ["Crayons", "Printf", "Test", "Unicode"]
68+
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
69+
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
70+
version = "0.5.0"
10271

10372
[[Unicode]]
10473
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
44
[deps]
55
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7-
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
87
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
8+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
99
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10+
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

REQUIRE

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
1-
julia 0.7-
1+
julia 1.0
22
Requires
3-
MacroTools

src/NNlib.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
module NNlib
2+
using Requires, TimerOutputs
23

3-
using Requires, Libdl
4-
5-
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ, logsigmoid,
6-
softmax, logsoftmax, maxpool, meanpool
7-
8-
include("numeric.jl")
4+
# Include APIs
5+
include("dim_helpers.jl")
96
include("activation.jl")
107
include("softmax.jl")
11-
include("logsoftmax.jl")
12-
include("linalg.jl")
8+
include("gemm.jl")
139
include("conv.jl")
14-
include("cubroadcast.jl")
10+
include("pooling.jl")
11+
12+
## Include implementations
13+
include("impl/padding_edges.jl")
14+
15+
# Direct implementations of convolutional and depthwise-convolutional algorithms
16+
include("impl/conv_direct.jl")
17+
include("impl/depthwiseconv_direct.jl")
18+
# im2col implementations of convolutional and depthwise-convolutional algorithms
19+
include("impl/conv_im2col.jl")
20+
include("impl/depthwiseconv_im2col.jl")
21+
22+
# Direct implementations of pooling
23+
include("impl/pooling_direct.jl")
24+
25+
to = TimerOutput()
1526

16-
end # module
27+
end # module NNlib

src/activation.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1+
export σ, sigmoid, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logσ,
2+
logsigmoid
3+
14
"""
25
σ(x) = 1 / (1 + exp(-x))
36
47
Classic [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) activation
58
function.
69
"""
710
σ(x) = one(x) / (one(x) + exp(-x))
8-
911
const sigmoid = σ
1012

1113
# ForwardDiff numerical stability hack
1214
σ_stable(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
13-
1415
σ(x::Float32) = σ_stable(x)
15-
1616
@init @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
1717
σ(x::ForwardDiff.Dual{T,Float32}) where T = σ_stable(x)
1818
end
1919

20+
2021
"""
2122
logσ(x)
2223
@@ -31,13 +32,13 @@ Return `log(σ(x))` which is computed in a numerically stable way.
3132
-0.0
3233
"""
3334
function logσ(x)
34-
max_v = max(zero(x), -x)
35-
z = exp(-max_v) + exp(-x-max_v)
36-
-(max_v + log(z))
35+
max_v = max(zero(x), -x)
36+
z = exp(-max_v) + exp(-x-max_v)
37+
return -(max_v + log(z))
3738
end
38-
3939
const logsigmoid = logσ
4040

41+
4142
"""
4243
relu(x) = max(0, x)
4344
@@ -56,6 +57,7 @@ You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
5657
"""
5758
leakyrelu(x, a = oftype(x/1, 0.01)) = max(a*x, x/1)
5859

60+
5961
"""
6062
elu(x, α = 1) =
6163
x > 0 ? x : α * (exp(x) - 1)
@@ -66,6 +68,7 @@ You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
6668
"""
6769
elu(x, α = one(x)) = ifelse(x 0, x/1, α * (exp(x) - one(x)))
6870

71+
6972
"""
7073
gelu(x) = 0.5x*(1 + tanh(√(2/π)*(x + 0.044715x^3)))
7174
@@ -103,6 +106,7 @@ function selu(x)
103106
λ * ifelse(x > 0, x/1, α * (exp(x) - 1))
104107
end
105108

109+
106110
"""
107111
softsign(x) = x / (1 + |x|)
108112

0 commit comments

Comments
 (0)