Skip to content

Commit

Permalink
Add NNPACK support (#67)
Browse files Browse the repository at this point in the history
Add NNPACK support
  • Loading branch information
staticfloat authored Apr 30, 2019
2 parents 40cee4b + ee86fbb commit ec79173
Show file tree
Hide file tree
Showing 19 changed files with 670 additions and 32 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
*.dll
*~
\#*
deps/usr
deps.jl
*.log
28 changes: 28 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"

[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -17,6 +27,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

Expand All @@ -31,10 +44,18 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"

[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -45,6 +66,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"

[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

Expand All @@ -69,5 +93,9 @@ git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"

[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.6.0"

[deps]
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
2 changes: 2 additions & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
julia 1.0
Requires
MacroTools
BinaryProvider
TimerOutputs
41 changes: 41 additions & 0 deletions deps/build.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using BinaryProvider

# Parse command-line arguments: `--verbose` switches on verbose output, and the
# first remaining argument (if any) overrides the install prefix, which otherwise
# defaults to the `usr` directory next to this script.
const verbose = "--verbose" in ARGS
const prefix = Prefix(get([a for a in ARGS if a != "--verbose"], 1, joinpath(@__DIR__, "usr")))
# Products that the build must provide; the library is exposed as `libnnpack`.
products = [
LibraryProduct(prefix, ["libnnpack"], :libnnpack),
]

# Base URL of the hosted release that holds the prebuilt tarballs
bin_prefix = "https://github.com/JuliaPackaging/Yggdrasil/releases/download/NNPACK-v2018.06.22-0"

# Listing of files generated by BinaryBuilder: each supported platform maps to a
# (tarball URL, hash) pair. The hashes are presumably SHA-256 (64 hex digits);
# platforms that are not listed here have no download available.
download_info = Dict(
Linux(:aarch64, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.aarch64-linux-gnu.tar.gz", "e0c6e21ba4c47acfd5a3d3e3510e8786474080f654338f4583b88860296c1437"),
Linux(:i686, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.i686-linux-gnu.tar.gz", "e9b6685001bc5a5d17acef15f3f6ffeb7beb6081926300f23ed4a442beac71ca"),
Linux(:i686, libc=:musl) => ("$bin_prefix/NNPACK.v2018.6.22.i686-linux-musl.tar.gz", "36c1d3c30b3bc3e0b34f215945bb46319f88e28f011fc758f21ba888b1fd9e25"),
MacOS(:x86_64) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-apple-darwin14.tar.gz", "b30046223a11470b15a2ceb0d0df6f7d8a43260fe52f4a2f8ebe5f0b2df822ca"),
Linux(:x86_64, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-linux-gnu.tar.gz", "150d5b6ca81fa72bfdc8bbda2428f0d3483fd11a5813724646c6d6c6a7ef969f"),
Linux(:x86_64, libc=:musl) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-linux-musl.tar.gz", "d961a104f814ec5b356519a82746a70a1df193ae37fc8130f38ffb61336def16"),
)

# Install unsatisfied or updated dependencies:
unsatisfied = any(!satisfied(p; verbose=verbose) for p in products)
platform = platform_key_abi()
dl_info = choose_download(download_info, platform)
if dl_info === nothing && unsatisfied
    # If we don't have a compatible .tar.gz to download, complain.
    # Alternatively, you could attempt to install from a separate provider,
    # build from source or something even more ambitious here.
    error("Your platform (\"$(Sys.MACHINE)\", parsed as \"$(triplet(platform))\") is not supported by this package!")
end

# If we have a download, and we are unsatisfied (or the version we're
# trying to install is not itself installed) then load it up!
# `dl_info` can still be `nothing` here when the products are already satisfied
# on a platform without a hosted tarball; splatting `nothing` would throw, so
# the install step is skipped in that case.
if dl_info !== nothing && (unsatisfied || !isinstalled(dl_info...; prefix=prefix))
    # Download and install binaries
    install(dl_info...; prefix=prefix, force=true, verbose=verbose)
end

# Write out a deps.jl file that will contain mappings for our products
write_deps_file(joinpath(@__DIR__, "deps.jl"), products, verbose=verbose)
9 changes: 9 additions & 0 deletions src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@ using Requires, TimerOutputs

const to = TimerOutput()


# Include APIs
include("dim_helpers.jl")

# NNPACK support
# The NNPACK bindings are only loaded on Linux and macOS. On every other system
# a stub `is_nnpack_available()` that always returns `false` is defined instead,
# so the function exists on every platform.
if Sys.islinux() || Sys.isapple()
include("nnpack/NNPACK.jl")
else
is_nnpack_available() = false
end

include("activation.jl")
include("softmax.jl")
include("gemm.jl")
Expand Down
10 changes: 10 additions & 0 deletions src/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,13 @@ for backend in (Symbol(), :_direct, :_im2col)
end
end
end


# Use NNPACK if it is available and the operation is supported.
# The method only applies to 4-dimensional `Array` inputs and to `DenseConvDims`
# whose fifth and seventh type parameters are both `(1, 1)` (presumably stride and
# dilation -- confirm against the `DenseConvDims` definition). Everything else
# keeps using the generic `conv` methods.
if is_nnpack_available()
    function conv(x::Array{xT, 4}, w::Array{wT, 4},
                  cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
                  kwargs...) where {xT, wT, K, C_in, C_out, P, F}
        # Only type variables that appear in the signature are declared: an unused
        # one makes Julia warn when the method is defined.
        return conv_nnpack(x, w, cdims; kwargs...)
    end
end
16 changes: 15 additions & 1 deletion src/dim_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,18 @@ function predilate(x::AbstractArray{T,N}, dilation::NTuple{M}) where {T, N, M}
# zeros between each element of `x` along each spatial dimension.
x_dil[(1:dilation[idx]:size(x_dil,idx) for idx in 1:(N-2))..., :, :] .= x
return x_dil
end
end

"""
    flipweight(w::AbstractArray)

Reorder the weight tensor `w` for supporting both convolution and cross-correlation
operations.

- For a 4-dimensional array the first two dimensions are reversed.
- For a 5-dimensional array the first three dimensions are reversed.
- For an array of any other dimensionality `w` itself is returned (no copy).

The 4- and 5-dimensional methods return a new array and leave `w` untouched.
"""
@inline flipweight(w::AbstractArray) = w

# Reverse both leading dimensions; the last two dimensions are kept in order.
@inline flipweight(w::AbstractArray{T, 4}) where {T} = w[end:-1:1, end:-1:1, :, :]

# Reverse the three leading dimensions; the last two dimensions are kept in order.
@inline flipweight(w::AbstractArray{T, 5}) where {T} = w[end:-1:1, end:-1:1, end:-1:1, :, :]
60 changes: 60 additions & 0 deletions src/nnpack/NNPACK.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
include("libnnpack_types.jl")
include("error.jl")
include("libnnpack.jl")
include("performance.jl")
include("interface.jl")

# Path of the `deps.jl` file expected two directories above this one, under
# `deps/`. Loading is aborted with an explanatory error when it does not exist.
const depsjl_path = joinpath(@__DIR__, "..", "..", "deps", "deps.jl")
isfile(depsjl_path) ||
    error("NNPACK not installed properly, run Pkg.build(\"NNlib\"), restart Julia and try again")
include(depsjl_path)

# Thread pools filled in by `allocate_threadpool`, keyed by the thread count that
# was passed to `pthreadpool_create`; each pool is stored wrapped in a `Ref`.
const shared_threadpool_dict = Dict{UInt64, Base.RefValue}()

"""
    is_nnpack_available()

Run `check_deps()` and `nnp_initialize()`, then return `false` if the reported
status is `nnp_status_unsupported_hardware` and `true` for every other status.
"""
function is_nnpack_available()
    check_deps()
    return nnp_initialize() != nnp_status_unsupported_hardware
end

"""
    allocate_threadpool()

Create one thread pool per power-of-two thread count, starting at 2 and going up
to the configured limit, and store each of them in `shared_threadpool_dict` under
its thread count.

The limit is read from the global `NNPACK_CPU_THREADS` and capped at 8 threads, so
at most the pools for 2, 4 and 8 threads are created. As a side effect the global
is overwritten with `floor(log2(limit))`, i.e. the exponent of the largest pool.
"""
function allocate_threadpool()
    global NNPACK_CPU_THREADS
    # Cap the thread *count* before taking the logarithm. Storing the cap itself
    # where an exponent is expected would make the loop below create pools with up
    # to 2^8 threads whenever more than 8 threads are requested.
    max_exponent = floor(log2(min(NNPACK_CPU_THREADS, 8)))
    NNPACK_CPU_THREADS = max_exponent
    for i in 1:Int(max_exponent)
        threads = UInt64(2^i)
        push!(shared_threadpool_dict, threads => Ref(pthreadpool_create(threads)))
    end
    return nothing
end

@init begin
    check_deps()
    if nnp_initialize() == nnp_status_unsupported_hardware
        @warn "Hardware is unsupported by NNPACK so falling back to default NNlib"
    end

    # The thread count comes from the NNPACK_CPU_THREADS environment variable. When
    # it is unset or cannot be parsed as an unsigned integer, 4 is used instead.
    # Sys.CPU_THREADS might be a better default when tuning the benchmark suite on a
    # particular machine, but the runtime threadpool is fixed to a maximum of 4
    # threads here, so anything above would be ignored anyway.
    requested = tryparse(UInt64, get(ENV, "NNPACK_CPU_THREADS", ""))
    global NNPACK_CPU_THREADS = requested === nothing ? UInt64(4) : requested

    allocate_threadpool()
end
83 changes: 83 additions & 0 deletions src/nnpack/error.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
    NNPACKError(code, msg)

Exception that pairs an NNPACK status `code` with a human-readable `msg`.
"""
struct NNPACKError <: Exception
# Status value reported by NNPACK
code::nnp_status
# Text describing `code`
msg::AbstractString
end

# Print the error as `NNPACKError(code <code>, <msg>)`.
function Base.show(io::IO, err::NNPACKError)
    rendered = "NNPACKError(code $(err.code), $(err.msg))"
    return print(io, rendered)
end

# Build an `NNPACKError` from a bare status value by looking up its message.
# The table is scanned in order and the first matching status wins; a status that
# matches none of the entries gets the "NNPACK STATUS SUCCESS" message.
function NNPACKError(status::nnp_status)
    known = (
        nnp_status_invalid_batch_size => "NNPACK STATUS INVALID BATCH SIZE",
        nnp_status_invalid_channels => "NNPACK STATUS INVALID CHANNELS",
        nnp_status_invalid_input_channels => "NNPACK STATUS INVALID INPUT CHANNELS",
        nnp_status_invalid_output_channels => "NNPACK STATUS INVALID OUTPUT CHANNELS",
        nnp_status_invalid_input_size => "NNPACK STATUS INVALID INPUT SIZE",
        nnp_status_invalid_input_stride => "NNPACK STATUS INVALID INPUT STRIDE",
        nnp_status_invalid_input_padding => "NNPACK STATUS INVALID INPUT PADDING",
        nnp_status_invalid_kernel_size => "NNPACK STATUS INVALID KERNEL SIZE",
        nnp_status_invalid_pooling_size => "NNPACK STATUS INVALID POOLING SIZE",
        nnp_status_invalid_pooling_stride => "NNPACK STATUS INVALID POOLING STRIDE",
        nnp_status_invalid_algorithm => "NNPACK STATUS INVALID ALGORITHM",
        nnp_status_invalid_transform_strategy => "NNPACK STATUS INVALID TRANSFORM STRATEGY",
        nnp_status_invalid_output_subsampling => "NNPACK STATUS INVALID OUTPUT SUBSAMPLING",
        nnp_status_invalid_activation => "NNPACK STATUS INVALID ACTIVATION",
        nnp_status_invalid_activation_parameters => "NNPACK STATUS INVALID ACTIVATION PARAMETERS",
        nnp_status_unsupported_input_size => "NNPACK STATUS UNSUPPORTED INPUT SIZE",
        nnp_status_unsupported_input_stride => "NNPACK STATUS UNSUPPORTED INPUT STRIDE",
        nnp_status_unsupported_input_padding => "NNPACK STATUS UNSUPPORTED INPUT PADDING",
        nnp_status_unsupported_kernel_size => "NNPACK STATUS UNSUPPORTED KERNEL SIZE",
        nnp_status_unsupported_pooling_size => "NNPACK STATUS UNSUPPORTED POOLING SIZE",
        nnp_status_unsupported_pooling_stride => "NNPACK STATUS UNSUPPORTED POOLING STRIDE",
        nnp_status_unsupported_algorithm => "NNPACK STATUS UNSUPPORTED ALGORITHM",
        nnp_status_unsupported_transform_strategy => "NNPACK STATUS UNSUPPORTED TRANSFORM STRATEGY",
        nnp_status_unsupported_activation => "NNPACK STATUS UNSUPPORTED ACTIVATION",
        nnp_status_unsupported_activation_parameters => "NNPACK STATUS UNSUPPORTED ACTIVATION PARAMETERS",
        nnp_status_uninitialized => "NNPACK STATUS UNINITIALIZED",
        nnp_status_unsupported_hardware => "NNPACK STATUS UNSUPPORTED HARDWARE",
        nnp_status_out_of_memory => "NNPACK STATUS OUT OF MEMORY",
        nnp_status_insufficient_buffer => "NNPACK STATUS INSUFFICIENT BUFFER",
        nnp_status_misaligned_buffer => "NNPACK STATUS MISALIGNED BUFFER",
    )
    for (code, text) in known
        status == code && return NNPACKError(status, text)
    end
    return NNPACKError(status, "NNPACK STATUS SUCCESS")
end

# Evaluate `nnp_func` in the caller's scope, convert the result to `nnp_status`
# and throw an `NNPACKError` unless it equals `nnp_status_success`. The status is
# the value of the expanded expression.
macro nnpack_check(nnp_func)
    quote
        local status::nnp_status = $(esc(nnp_func))
        status == nnp_status_success || throw(NNPACKError(status))
        status
    end
end
50 changes: 50 additions & 0 deletions src/nnpack/impl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Max pooling through NNPACK: validates the sizes of `x` and `y` against `pdims`,
# picks a thread pool based on `pdims` and `size(y, 4)`, and returns whatever
# `nnp_max_pooling_output` returns.
function maxpool_nnpack!(y::A, x::A, pdims::PoolDims) where {A<:Array{Float32, 4}}
    check_dims(size(x), size(y), pdims)
    pool = select_threadpool(pdims, size(y, 4))
    return nnp_max_pooling_output(y, x, kernel_size(pdims);
                                  padding = padding(pdims),
                                  stride = stride(pdims),
                                  threadpool = pool)
end

# Convolution through NNPACK.
#
# Arguments:
#   y, x, w -- 4-dimensional `Float32` arrays (output, input, weights)
#   cdims   -- convolution description; also used to validate the array sizes
# Keywords:
#   b    -- bias vector, defaults to zeros with one entry per `size(y, 3)`
#   algo -- algorithm selector forwarded to NNPACK
#
# Returns whatever `nnp_convolution_output` returns. `x` and `w` are not modified.
@timeit_debug to function conv_nnpack!(y::A1, x::A1, w::A1, cdims::ConvDims;
                                       b::A2 = zeros(Float32, size(y, 3)),
                                       algo = UInt32(0)) where {A1<:Array{Float32, 4},
                                                                A2<:Array{Float32, 1}}
    # NOTE(review): the default bias is sized by the third dimension of `y` because
    # NNPACK reads one bias value per output channel; this assumes dimension 3 is
    # the channel axis of `y`, as it is for `x` -- confirm against the wrapper.
    check_dims(size(x), size(w), size(y), cdims)
    threadpool = select_threadpool(cdims, size(y, 4))

    # Flip into a fresh array instead of writing back into `w`: overwriting the
    # caller's weights would silently flip them again on every call.
    kernel = flipkernel(cdims) == 0 ? flipweight(w) : w

    return nnp_convolution_output(y, x, kernel, b; algo = algo, padding = padding(cdims),
                                  stride = stride(cdims), threadpool = threadpool)
end

# Gradient of the convolution with respect to its input, computed through NNPACK.
#
# Arguments:
#   dx, dy, w -- 4-dimensional `Float32` arrays (input gradient, output gradient,
#                weights)
#   cdims     -- convolution description; also used to validate the array sizes
# Keywords:
#   algo -- algorithm selector forwarded to NNPACK
#
# Returns whatever `nnp_convolution_input_gradient` returns. `w` is not modified.
@timeit_debug to function ∇conv_data_nnpack!(dx::A, dy::A, w::A, cdims::ConvDims;
                                             algo = UInt32(0)) where {A<:Array{Float32, 4}}
    check_dims(size(dx), size(w), size(dy), cdims)
    # The thread pool is chosen from the fourth dimension of `dy`; there is no `y`
    # in this function.
    threadpool = select_threadpool(cdims, size(dy, 4))

    # Flip into a fresh array instead of writing back into `w`: overwriting the
    # caller's weights would silently flip them again on every call.
    kernel = flipkernel(cdims) == 0 ? flipweight(w) : w

    return nnp_convolution_input_gradient(dx, dy, kernel; algo = algo,
                                          padding = padding(cdims),
                                          stride = stride(cdims),
                                          threadpool = threadpool)
end

# Gradient of the convolution with respect to its weights, computed through NNPACK.
#
# Arguments:
#   dw, x, dy -- 4-dimensional `Float32` arrays (weight gradient, input, output
#                gradient)
#   cdims     -- convolution description; also used to validate the array sizes
# Keywords:
#   algo -- algorithm selector forwarded to NNPACK
#
# Returns `dw`, which is written in place.
@timeit_debug to function ∇conv_filter_nnpack!(dw::A, x::A, dy::A, cdims::ConvDims;
                                               algo = UInt32(0)) where {A<:Array{Float32, 4}}
    check_dims(size(x), size(dw), size(dy), cdims)
    # The thread pool is chosen from the fourth dimension of `dy`; there is no `y`
    # in this function.
    threadpool = select_threadpool(cdims, size(dy, 4))

    nnp_convolution_kernel_gradient(dw, x, dy; algo = algo, padding = padding(cdims),
                                    stride = stride(cdims), threadpool = threadpool)

    # `dw` is this function's own output, so flipping it in place is intended.
    # `flipweight` builds a new array first, so the assignment never reads
    # elements it has already overwritten.
    if flipkernel(cdims) == 0
        dw .= flipweight(dw)
    end

    return dw
end

Loading

0 comments on commit ec79173

Please sign in to comment.