-
-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add NNPACK support
- Loading branch information
Showing
19 changed files
with
670 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,6 @@ | |
*.dll | ||
*~ | ||
\#* | ||
deps/usr | ||
deps.jl | ||
*.log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
julia 1.0 | ||
Requires | ||
MacroTools | ||
BinaryProvider | ||
TimerOutputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
using BinaryProvider | ||
|
||
# Parse some basic command-line arguments | ||
const verbose = "--verbose" in ARGS | ||
const prefix = Prefix(get([a for a in ARGS if a != "--verbose"], 1, joinpath(@__DIR__, "usr"))) | ||
products = [ | ||
LibraryProduct(prefix, ["libnnpack"], :libnnpack), | ||
] | ||
|
||
# Download binaries from hosted location | ||
bin_prefix = "https://github.com/JuliaPackaging/Yggdrasil/releases/download/NNPACK-v2018.06.22-0" | ||
|
||
# Listing of files generated by BinaryBuilder: | ||
download_info = Dict( | ||
Linux(:aarch64, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.aarch64-linux-gnu.tar.gz", "e0c6e21ba4c47acfd5a3d3e3510e8786474080f654338f4583b88860296c1437"), | ||
Linux(:i686, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.i686-linux-gnu.tar.gz", "e9b6685001bc5a5d17acef15f3f6ffeb7beb6081926300f23ed4a442beac71ca"), | ||
Linux(:i686, libc=:musl) => ("$bin_prefix/NNPACK.v2018.6.22.i686-linux-musl.tar.gz", "36c1d3c30b3bc3e0b34f215945bb46319f88e28f011fc758f21ba888b1fd9e25"), | ||
MacOS(:x86_64) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-apple-darwin14.tar.gz", "b30046223a11470b15a2ceb0d0df6f7d8a43260fe52f4a2f8ebe5f0b2df822ca"), | ||
Linux(:x86_64, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-linux-gnu.tar.gz", "150d5b6ca81fa72bfdc8bbda2428f0d3483fd11a5813724646c6d6c6a7ef969f"), | ||
Linux(:x86_64, libc=:musl) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-linux-musl.tar.gz", "d961a104f814ec5b356519a82746a70a1df193ae37fc8130f38ffb61336def16"), | ||
) | ||
|
||
# Install unsatisfied or updated dependencies:
unsatisfied = any(!satisfied(p; verbose=verbose) for p in products)
dl_info = choose_download(download_info, platform_key_abi())
if dl_info === nothing
    # No compatible .tar.gz exists for this platform. That is only fatal when
    # the products are not already available; otherwise there is nothing to
    # install, and `dl_info` must not be splatted (splatting `nothing` throws).
    # Alternatively, you could attempt to install from a separate provider,
    # build from source or something even more ambitious here.
    if unsatisfied
        error("Your platform (\"$(Sys.MACHINE)\", parsed as \"$(triplet(platform_key_abi()))\") is not supported by this package!")
    end
elseif unsatisfied || !isinstalled(dl_info...; prefix=prefix)
    # We have a download, and we are unsatisfied (or the version we're
    # trying to install is not itself installed): download and install binaries.
    install(dl_info...; prefix=prefix, force=true, verbose=verbose)
end
|
||
# Write out a deps.jl file that will contain mappings for our products | ||
write_deps_file(joinpath(@__DIR__, "deps.jl"), products, verbose=verbose) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
include("libnnpack_types.jl") | ||
include("error.jl") | ||
include("libnnpack.jl") | ||
include("performance.jl") | ||
include("interface.jl") | ||
|
||
# Path of the generated deps.jl, resolved relative to this source file
# (two directories up, then into deps/).
const depsjl_path = joinpath(dirname(@__FILE__), "..", "..", "deps", "deps.jl")

# Refuse to load when the generated file is missing, then pull it in.
isfile(depsjl_path) ||
    error("NNPACK not installed properly, run Pkg.build(\"NNlib\"), restart Julia and try again")
include(depsjl_path)
|
||
const shared_threadpool_dict = Dict{UInt64, Base.RefValue}() | ||
|
||
"""
    is_nnpack_available()

Run `check_deps()`, initialize NNPACK and return `false` only when the
initialization status is `nnp_status_unsupported_hardware`; any other status
yields `true`.
"""
function is_nnpack_available()
    check_deps()
    return nnp_initialize() != nnp_status_unsupported_hardware
end
|
||
"""
    allocate_threadpool()

Allocate several threadpools, bounded by the upper limit on the number of
threads stored in the global `NNPACK_CPU_THREADS`.

The limit is first normalized to a `UInt64` power of two between 1 and 8 and
written back to `NNPACK_CPU_THREADS`. One pool is then created for every power
of two from 2 up to that limit and stored in `shared_threadpool_dict`, keyed by
its thread count, so that a suitable pool can be picked per operation.
"""
function allocate_threadpool()
    global NNPACK_CPU_THREADS
    # Clamp to 1..8 so that log2 is always finite (a limit of 0 would
    # otherwise produce -Inf) and no pool larger than 8 threads is created.
    capped = max(min(UInt64(NNPACK_CPU_THREADS), UInt64(8)), UInt64(1))
    levels = floor(Int, log2(capped))
    # Always store a thread count (never an exponent) with a stable type.
    NNPACK_CPU_THREADS = UInt64(2)^levels
    for i in 1:levels
        threads = UInt64(2)^i
        # NOTE(review): calling this twice replaces existing entries without
        # releasing the previously created pools.
        push!(shared_threadpool_dict, threads => Ref(pthreadpool_create(threads)))
    end
end
|
||
# Module initialization: verify the binary dependency, initialize NNPACK,
# pick the thread limit and allocate the shared threadpools.
@init begin
    check_deps()
    status = nnp_initialize()
    if status == nnp_status_unsupported_hardware
        @warn "Hardware is unsupported by NNPACK so falling back to default NNlib"
    end
    # Read the limit from the environment without using exceptions for control
    # flow: a missing variable or an unparsable value both yield `nothing`.
    requested = tryparse(UInt64, get(ENV, "NNPACK_CPU_THREADS", ""))
    # Sys.CPU_THREADS should be a better default if we are tuning the benchmark suite on
    # a particular machine. However, we default the runtime threadpool limit to
    # 4 threads when nothing usable was requested.
    global NNPACK_CPU_THREADS = requested === nothing ? UInt64(4) : requested
    allocate_threadpool()
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
"""
    NNPACKError(code, msg)

Exception that pairs an NNPACK status `code` with a descriptive message `msg`.
"""
struct NNPACKError <: Exception
    code::nnp_status
    msg::AbstractString
end

# Render the error as `NNPACKError(code <code>, <msg>)`.
function Base.show(io::IO, err::NNPACKError)
    rendered = "NNPACKError(code $(err.code), $(err.msg))"
    print(io, rendered)
end
|
||
"""
    NNPACKError(status::nnp_status)

Build an `NNPACKError` whose message describes `status`. Statuses that match
none of the listed codes get the message `"NNPACK STATUS SUCCESS"`.
"""
function NNPACKError(status::nnp_status)
    # Status code => message, checked in order; the first match wins.
    descriptions = (
        nnp_status_invalid_batch_size => "NNPACK STATUS INVALID BATCH SIZE",
        nnp_status_invalid_channels => "NNPACK STATUS INVALID CHANNELS",
        nnp_status_invalid_input_channels => "NNPACK STATUS INVALID INPUT CHANNELS",
        nnp_status_invalid_output_channels => "NNPACK STATUS INVALID OUTPUT CHANNELS",
        nnp_status_invalid_input_size => "NNPACK STATUS INVALID INPUT SIZE",
        nnp_status_invalid_input_stride => "NNPACK STATUS INVALID INPUT STRIDE",
        nnp_status_invalid_input_padding => "NNPACK STATUS INVALID INPUT PADDING",
        nnp_status_invalid_kernel_size => "NNPACK STATUS INVALID KERNEL SIZE",
        nnp_status_invalid_pooling_size => "NNPACK STATUS INVALID POOLING SIZE",
        nnp_status_invalid_pooling_stride => "NNPACK STATUS INVALID POOLING STRIDE",
        nnp_status_invalid_algorithm => "NNPACK STATUS INVALID ALGORITHM",
        nnp_status_invalid_transform_strategy => "NNPACK STATUS INVALID TRANSFORM STRATEGY",
        nnp_status_invalid_output_subsampling => "NNPACK STATUS INVALID OUTPUT SUBSAMPLING",
        nnp_status_invalid_activation => "NNPACK STATUS INVALID ACTIVATION",
        nnp_status_invalid_activation_parameters => "NNPACK STATUS INVALID ACTIVATION PARAMETERS",
        nnp_status_unsupported_input_size => "NNPACK STATUS UNSUPPORTED INPUT SIZE",
        nnp_status_unsupported_input_stride => "NNPACK STATUS UNSUPPORTED INPUT STRIDE",
        nnp_status_unsupported_input_padding => "NNPACK STATUS UNSUPPORTED INPUT PADDING",
        nnp_status_unsupported_kernel_size => "NNPACK STATUS UNSUPPORTED KERNEL SIZE",
        nnp_status_unsupported_pooling_size => "NNPACK STATUS UNSUPPORTED POOLING SIZE",
        nnp_status_unsupported_pooling_stride => "NNPACK STATUS UNSUPPORTED POOLING STRIDE",
        nnp_status_unsupported_algorithm => "NNPACK STATUS UNSUPPORTED ALGORITHM",
        nnp_status_unsupported_transform_strategy => "NNPACK STATUS UNSUPPORTED TRANSFORM STRATEGY",
        nnp_status_unsupported_activation => "NNPACK STATUS UNSUPPORTED ACTIVATION",
        nnp_status_unsupported_activation_parameters => "NNPACK STATUS UNSUPPORTED ACTIVATION PARAMETERS",
        nnp_status_uninitialized => "NNPACK STATUS UNINITIALIZED",
        nnp_status_unsupported_hardware => "NNPACK STATUS UNSUPPORTED HARDWARE",
        nnp_status_out_of_memory => "NNPACK STATUS OUT OF MEMORY",
        nnp_status_insufficient_buffer => "NNPACK STATUS INSUFFICIENT BUFFER",
        nnp_status_misaligned_buffer => "NNPACK STATUS MISALIGNED BUFFER",
    )
    for (code, text) in descriptions
        status == code && return NNPACKError(status, text)
    end
    return NNPACKError(status, "NNPACK STATUS SUCCESS")
end
|
||
# Evaluate `nnp_func` in the caller's scope, throw an `NNPACKError` unless the
# resulting status equals `nnp_status_success`, and otherwise yield the status.
macro nnpack_check(nnp_func)
    call = esc(nnp_func)
    return quote
        local status::nnp_status = $call
        status == nnp_status_success || throw(NNPACKError(status))
        status
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Max-pooling through NNPACK: validates the dimensions, picks a threadpool
# from the size of `y` along its fourth dimension, and returns the result of
# `nnp_max_pooling_output`.
function maxpool_nnpack!(y::A, x::A, pdims::PoolDims) where {A<:Array{Float32, 4}}
    check_dims(size(x), size(y), pdims)
    pool = select_threadpool(pdims, size(y, 4))
    return nnp_max_pooling_output(
        y, x, kernel_size(pdims);
        padding = padding(pdims), stride = stride(pdims), threadpool = pool
    )
end
|
||
# Convolution through NNPACK: validates the dimensions, picks a threadpool
# from the size of `y` along its fourth dimension, and returns the result of
# `nnp_convolution_output`. The caller's `w` is never modified.
# NOTE(review): the default bias is sized by `size(x, 3)`; confirm that this is
# the length NNPACK expects when input and output channel counts differ.
@timeit_debug to function conv_nnpack!(y::A1, x::A1, w::A1, cdims::ConvDims;
                                       b::A2 = zeros(Float32, size(x, 3)),
                                       algo = UInt32(0)) where {A1<:Array{Float32, 4},
                                                                A2<:Array{Float32, 1}}
    check_dims(size(x), size(w), size(y), cdims)
    threadpool = select_threadpool(cdims, size(y, 4))

    # Use a flipped copy rather than `w .= flipweight(w)`: writing back into
    # `w` would leave the caller's kernel flipped after every call.
    # (Assumes `flipweight` returns a new array — TODO confirm.)
    kernel = flipkernel(cdims) == 0 ? flipweight(w) : w

    return nnp_convolution_output(y, x, kernel, b, algo = algo, padding = padding(cdims),
                                  stride = stride(cdims), threadpool = threadpool)
end
|
||
# Input gradient of the convolution through NNPACK: validates the dimensions,
# picks a threadpool from the size of `dy` along its fourth dimension, and
# returns the result of `nnp_convolution_input_gradient`. The caller's `w` is
# never modified.
@timeit_debug to function ∇conv_data_nnpack!(dx::A, dy::A, w::A, cdims::ConvDims;
                                             algo = UInt32(0)) where{A<:Array{Float32, 4}}
    check_dims(size(dx), size(w), size(dy), cdims)
    # The original referenced an undefined `y` here; `dy` is the array this
    # method actually receives.
    threadpool = select_threadpool(cdims, size(dy, 4))

    # Use a flipped copy rather than `w .= flipweight(w)`: writing back into
    # `w` would leave the caller's kernel flipped after every call.
    # (Assumes `flipweight` returns a new array — TODO confirm.)
    kernel = flipkernel(cdims) == 0 ? flipweight(w) : w

    return nnp_convolution_input_gradient(dx, dy, kernel, algo = algo,
                                          padding = padding(cdims),
                                          stride = stride(cdims),
                                          threadpool = threadpool)
end
|
||
# Kernel gradient of the convolution through NNPACK: validates the dimensions,
# picks a threadpool from the size of `dy` along its fourth dimension, lets
# `nnp_convolution_kernel_gradient` fill `dw`, and returns `dw`.
@timeit_debug to function ∇conv_filter_nnpack!(dw::A, x::A, dy::A, cdims::ConvDims;
                                               algo = UInt32(0)) where{A<:Array{Float32, 4}}
    check_dims(size(x), size(dw), size(dy), cdims)
    # The original referenced an undefined `y` here; `dy` is the array this
    # method actually receives.
    threadpool = select_threadpool(cdims, size(dy, 4))

    nnp_convolution_kernel_gradient(dw, x, dy, algo = algo, padding = padding(cdims),
                                    stride = stride(cdims), threadpool = threadpool)

    # `dw` is this method's output, so flipping it in place is intended.
    if flipkernel(cdims) == 0
        dw .= flipweight(dw)
    end

    return dw
end
|
Oops, something went wrong.