Skip to content

Commit

Permalink
[Mobile GPU][Integration] Vulkan backend integration (pytorch#36491)
Browse files Browse the repository at this point in the history
Summary:
This PR contains the initial version of Vulkan (GPU) Backend integration.
The primary target environment is Android, but the desktop build is also supported.

## CMake
Introducing three cmake options:
USE_VULKAN:
The main switch, if it is off, all other options do not affect.
USE_VULKAN_WRAPPER:
ON - Vulkan will be used loading it at runtime as "libvulkan.so" using libdl, every function call is wrapped in vulkan_wrapper.h.
OFF - linking with libvulkan.so directly
USE_VULKAN_SHADERC_RUNTIME:
ON - Shader compilation library will be linked, and shaders will be compiled runtime.
OFF - Shaders will be precompiled and shader compilation library is not included.

## Codegen
if `USE_VULKAN_SHADERC_RUNTIME` is ON:
Shaders precompilation () starts in cmake/VulkanCodegen.cmake, which calls `aten/src/ATen/native/vulkan/gen_glsl.py` or `aten/src/ATen/native/vulkan/gen_spv.py` to include shaders source or SPIR-V bytecode inside binary as uint32_t array in spv.h,spv.cpp.
if `USE_VULKAN_SHADERC_RUNTIME` is OFF:
The source of shaders is included as `glsl.h`,`glsl.cpp`.

All codegen results happen in the build directory.

## Build dependencies
cmake/Dependencies.cmake
If the target platform is Android - vulkan library, headers, Vulkan wrapper will be used from ANDROID_NDK.
Desktop build requires the VULKAN_SDK environment variable, and all vulkan dependencies will be used from it.
(Desktop build was tested only on Linux).

## Pytorch integration:
Adding 'Vulkan" as new Backend, DispatchKey, DeviceType.
We are using Strided layout without supporting strides at the moment, but we plan to support them in the future.
Using OpaqueTensorImpl where OpaqueHandle is copyable VulkanTensor,
more details in comments in `aten/src/ATen/native/vulkan/Vulkan.h`

Main code location: `aten/src/ATen/native/vulkan`
`aten/src/ATen/native/vulkan/VulkanAten.cpp` - connection link between ATen and Vulkan api (Vulkan.h) that converts at::Tensor to VulkanTensor.

`aten/src/ATen/native/Vulkan/Vulkan.h` - Vulkan API that contains VulkanTensor representation and functions to work with it. Plan to expose it for clients to be able to write their own Vulkan Ops.

`aten/src/ATen/native/vulkan/VulkanOps.cpp` - Vulkan Operations Implementations that uses Vulkan.h API

## GLSL shaders
Located in `aten/src/ATen/native/vulkan/glsl` as *.glsl files.
All shaders use Vulkan specialized constants for workgroup sizes with ids 1, 2, 3

## Supported operations
Code point:
conv2d no-groups
conv2d depthwise
addmm
upsample nearest 2d
clamp
hardtanh

## Testing
`aten/src/ATen/test/vulkan_test.cpp` - contains tests for
copy from CPU to Vulkan and back
all supported operations
Desktop builds supported, and testing can be done on a desktop that has Vulkan supported GPU or with installed software implementation of Vulkan, like https://github.com/google/swiftshader

## Vulkan execution
The initial implementation is trivial and waits every operator's execution.
Pull Request resolved: pytorch#36491

Differential Revision: D21696709

Pulled By: IvanKobzarev

fbshipit-source-id: da3e5a770b1a1995e9465d7e81963e7de56217fa
  • Loading branch information
IvanKobzarev authored and facebook-github-bot committed May 26, 2020
1 parent 1fa0bb6 commit b460465
Show file tree
Hide file tree
Showing 53 changed files with 4,923 additions and 10 deletions.
15 changes: 15 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
option(USE_SYSTEM_EIGEN_INSTALL
"Use system Eigen instead of the one under third_party" OFF)
option(USE_TENSORRT "Using Nvidia TensorRT library" OFF)
option(USE_VULKAN "Use Vulkan GPU backend" OFF)
option(USE_VULKAN_WRAPPER "Use Vulkan wrapper" ON)
option(USE_VULKAN_SHADERC_RUNTIME "Use Vulkan Shader compilation runtime(Needs shaderc lib)" OFF)
option(USE_XNNPACK "Use XNNPACK" ON)
option(USE_ZMQ "Use ZMQ" OFF)
option(USE_ZSTD "Use ZSTD" OFF)
Expand Down Expand Up @@ -475,6 +478,18 @@ if(USE_XNNPACK)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL")
endif()

if(USE_VULKAN)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_VULKAN")
endif()

if(USE_VULKAN_WRAPPER)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_VULKAN_WRAPPER")
endif()

if(USE_VULKAN_SHADERC_RUNTIME)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_VULKAN_SHADERC_RUNTIME")
endif()

# ---[ Whitelist file if whitelist is specified
include(cmake/Whitelist.cmake)

Expand Down
5 changes: 5 additions & 0 deletions aten/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(ATen_HIP_SRCS)
set(ATen_HIP_SRCS_W_SORT_BY_KEY)
set(ATen_HIP_TEST_SRCS)
set(ATen_HIP_INCLUDE)
set(ATen_VULKAN_TEST_SRCS)
set(ATen_CPU_DEPENDENCY_LIBS)
set(ATen_CUDA_DEPENDENCY_LIBS)
set(ATen_HIP_DEPENDENCY_LIBS)
Expand All @@ -51,6 +52,9 @@ set(TH_CPU_INCLUDE
${CMAKE_BINARY_DIR}/aten/src)
list(APPEND ATen_CPU_INCLUDE ${TH_CPU_INCLUDE})

if(USE_VULKAN)
list(APPEND ATen_CPU_INCLUDE ${CMAKE_BINARY_DIR}/vulkan)
endif()

# Find the HIP package, set the HIP paths, load the HIP CMake.
if(USE_ROCM)
Expand Down Expand Up @@ -113,6 +117,7 @@ set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE)
set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE)
set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE)
set(ATen_VULKAN_TEST_SRCS ${ATen_VULKAN_TEST_SRCS} PARENT_SCOPE)
set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE)
set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE)
set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE)
Expand Down
10 changes: 10 additions & 0 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ file(GLOB mkldnn_cpp "mkldnn/*.cpp")
file(GLOB native_cpp "native/*.cpp")
file(GLOB native_mkl_cpp "native/mkl/*.cpp")
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
file(GLOB native_vulkan_cpp "native/vulkan/*.cpp")
file(GLOB native_vulkan_stub_cpp "native/vulkan/stub/*.cpp")
file(GLOB native_sparse_cpp "native/sparse/*.cpp")
file(GLOB native_quantized_cpp
"native/quantized/*.cpp"
Expand Down Expand Up @@ -105,6 +107,11 @@ endif()
if(AT_MKLDNN_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp})
endif()
if(USE_VULKAN)
set(all_cpu_cpp ${all_cpu_cpp} ${native_vulkan_cpp} ${vulkan_generated_cpp})
else()
set(all_cpu_cpp ${all_cpu_cpp} ${native_vulkan_stub_cpp})
endif()

if(USE_CUDA AND USE_ROCM)
message(FATAL_ERROR "ATen doesn't not currently support simultaneously building with CUDA and ROCM")
Expand Down Expand Up @@ -324,6 +331,7 @@ endif()
# Include CPU paths for CUDA/HIP as well
list(APPEND ATen_CUDA_INCLUDE ${ATen_CPU_INCLUDE})
list(APPEND ATen_HIP_INCLUDE ${ATen_CPU_INCLUDE})
list(APPEND ATen_VULKAN_INCLUDE ${ATen_CPU_INCLUDE})

# We have two libraries: libATen_cpu.so and libATen_cuda.so,
# with libATen_cuda.so depending on libATen_cpu.so. The CPU library
Expand Down Expand Up @@ -402,11 +410,13 @@ set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE)
set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)
set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE)
set(ATen_VULKAN_TEST_SRCS ${ATen_VULKAN_TEST_SRCS} PARENT_SCOPE)
set(ATen_QUANTIZED_TEST_SRCS ${ATen_QUANTIZED_TEST_SRCS} PARENT_SCOPE)
set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE)
set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE)
set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE)
set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE)
set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE)
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
31 changes: 26 additions & 5 deletions aten/src/ATen/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def TypedDict(name, attrs, total=True): # type: ignore
break;
""")

IFDEF_BLOCK = CodeTemplate("""\
#ifdef ${ifdef_guard}
${content}
#endif
""")

# add a native declaration for a native function
NATIVE_DECLARATION = CodeTemplate("""\
CAFFE2_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults});
Expand Down Expand Up @@ -221,7 +227,8 @@ def TypedDict(name, attrs, total=True): # type: ignore
('ComplexDouble', 'ComplexDouble', 'ComplexDouble', False),
]

static_dispatch_backends = ['CPU', 'QuantizedCPU']
static_dispatch_backends = ['CPU', 'QuantizedCPU', 'Vulkan']
static_dispatch_backends_ifdef_guard = {'Vulkan' : 'USE_VULKAN'}


class NYIError(Exception):
Expand Down Expand Up @@ -1059,11 +1066,18 @@ def swizzle_self(f): # blegh
# calling code.
for backend in static_dispatch_backends:
if backend in type_method_dispatch:
static_dispatch_function_cases.append(STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
static_dispatch_function_case = STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
option,
backend=backend,
backend_function=type_method_dispatch[backend],
actuals=option['method_actuals']))
actuals=option['method_actuals'])
if (backend in static_dispatch_backends_ifdef_guard):
static_dispatch_function_cases.append(IFDEF_BLOCK.substitute(
option,
ifdef_guard=static_dispatch_backends_ifdef_guard[backend],
content=static_dispatch_function_case))
else:
static_dispatch_function_cases.append(static_dispatch_function_case)

static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
option,
Expand Down Expand Up @@ -1094,11 +1108,18 @@ def gen_namespace_function(option, multidispatch_formals):
static_dispatch_function_cases = []
for backend in static_dispatch_backends:
if backend in type_method_dispatch:
static_dispatch_function_cases.append(STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
static_dispatch_function_case = STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
option,
backend=backend,
backend_function=type_method_dispatch[backend],
actuals=option['actuals']))
actuals=option['actuals'])
if (backend in static_dispatch_backends_ifdef_guard):
static_dispatch_function_cases.append(IFDEF_BLOCK.substitute(
option,
ifdef_guard=static_dispatch_backends_ifdef_guard[backend],
content=static_dispatch_function_case))
else:
static_dispatch_function_cases.append(static_dispatch_function_case)
static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
option,
dispatch_key_var_name=dispatch_key_var_name,
Expand Down
9 changes: 8 additions & 1 deletion aten/src/ATen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
'--rocm',
action='store_true',
help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly')
parser.add_argument(
'--vulkan',
action='store_true',
help='Generate Vulkan backend functions')
parser.add_argument(
'--op_registration_whitelist',
nargs='*',
Expand All @@ -67,6 +71,7 @@
help='force it to generate schema-only registrations for all ops, including'
'those that are not listed on --op_registration_whitelist')
options = parser.parse_args()

# NB: It is mandatory to NOT use os.path.join here, as the install directory
# will eventually be ingested by cmake, which does not respect Windows style
# path slashes. If you switch this to use os.path.join, you'll get an error
Expand Down Expand Up @@ -365,7 +370,7 @@ def generate_storage_type_and_tensor(backend, density, declarations, per_op_regi
fm.write(env['Type'] + ".cpp", SPARSE_TYPE_DERIVED_CPP, env)
fm.write(env['Type'] + ".h", TYPE_DERIVED_H, env)

if env['DeviceType'] == 'CPU':
if env['DeviceType'] == 'CPU' or env['DeviceType'] == 'Vulkan':
top_env['cpu_type_headers'].append(
'#include <ATen/{}.h>'.format(env['Type']))
else:
Expand All @@ -384,6 +389,8 @@ def iterate_types():
yield (backend, density)
for backend in quantized_backends:
yield (backend, 'Dense')
if options.vulkan:
yield('Vulkan', 'Dense')


def gen_per_op_registration_filename(opname):
Expand Down
30 changes: 30 additions & 0 deletions aten/src/ATen/native/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#if AT_NNPACK_ENABLED()
#include <nnpack.h>
#endif
#ifdef USE_VULKAN
#include <ATen/native/vulkan/VulkanAten.h>
#endif


constexpr int MIOPEN_DIM_MAX = 5;
Expand Down Expand Up @@ -47,6 +50,7 @@ struct ConvParams {
bool use_mkldnn(const at::Tensor& input) const;
bool use_nnpack(const at::Tensor& input) const;
bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) const;
bool use_vulkan(const at::Tensor& input, const at::Tensor& weight) const;
bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
};

Expand Down Expand Up @@ -274,6 +278,20 @@ auto ConvParams::use_xnnpack(
return false;
}

auto ConvParams::use_vulkan(
const at::Tensor &input, const at::Tensor& weight) const -> bool {
#ifdef USE_VULKAN
if (!(input.is_vulkan() && input.scalar_type() == kFloat &&
!transposed && input.ndimension() == 4)) {
return false;
}
return (groups == 1) || (input.size(1) == groups && groups > 1 &&
weight.size(0) % input.size(1) == 0);
#else
return false;
#endif
}

// We currently only have depthwise support for the case where groups ==
// nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
// a depthwise multiplier)
Expand Down Expand Up @@ -669,6 +687,12 @@ at::Tensor _convolution(
output = at::miopen_depthwise_convolution(
input.contiguous(), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
#ifdef USE_VULKAN
} else if (params.use_vulkan(input, weight)) {
output = at::native::vulkan_convolution(
input, weight, bias,
params.padding, params.stride, params.dilation, params.groups);
#endif
} else {
output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
}
Expand Down Expand Up @@ -761,6 +785,12 @@ at::Tensor _convolution(
bias,
params.stride,
params.padding);
#ifdef USE_VULKAN
} else if (params.use_vulkan(input, weight)) {
output = at::native::vulkan_convolution(
input, weight, bias,
params.padding, params.stride, params.dilation, params.groups);
#endif
} else if (input.device().type() == c10::DeviceType::CPU || input.device().type() == c10::DeviceType::CUDA) {
if (params.groups == 1) {
output = at::_convolution_nogroup(
Expand Down
11 changes: 10 additions & 1 deletion aten/src/ATen/native/Copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
#include <ATen/NamedTensorUtils.h>
#include <torch/library.h>

#ifdef USE_VULKAN
#include <ATen/native/vulkan/VulkanAten.h>
#endif
namespace {

using namespace at;
Expand Down Expand Up @@ -78,7 +81,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
// (e.g. XLA) may be supported by overriding copy_ and _copy_from.
bool is_supported_device(Device device) {
DeviceType device_type = device.type();
return device_type == kCPU || device_type == kCUDA || device_type == kHIP;
return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan;
}

} // namespace
Expand Down Expand Up @@ -126,6 +129,12 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
TORCH_CHECK(false, "Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor");
}

#ifdef USE_VULKAN
if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) {
return vulkan_copy_(self, src);
}
#endif

auto iter = TensorIterator();
iter.set_check_mem_overlap(true);
iter.add_output(self);
Expand Down
14 changes: 14 additions & 0 deletions aten/src/ATen/native/TensorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b
return self;
}

if (options.device().type() == DeviceType::Vulkan
|| self.device().type() == DeviceType::Vulkan) {
auto r = at::empty(self.sizes(), options, c10::nullopt);
r.copy_(self, non_blocking);
return r;
}

if (memory_format == MemoryFormat::Preserve) {
if (self.is_non_overlapping_and_dense()) {
// Copy all strides
Expand Down Expand Up @@ -62,6 +69,13 @@ Tensor to(
"to(options) expects unset requires_grad flag, but got "
"options.requires_grad set as ", options.requires_grad());

if (options.device().type() == DeviceType::Vulkan
|| self.device().type() == DeviceType::Vulkan) {
auto r = at::empty(self.sizes(), options, c10::nullopt);
r.copy_(self, non_blocking);
return r;
}

TORCH_CHECK(!options.has_layout() || self.layout() == options.layout(),
"to(options) doesn't support converting to a different layout, "
"but got self.layout being ", self.layout(),
Expand Down
10 changes: 10 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@
SparseCPU: add_sparse
SparseCUDA: add_sparse
MkldnnCPU: mkldnn_add
Vulkan: vulkan_add
supports_named_tensor: True

- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
Expand Down Expand Up @@ -764,6 +765,7 @@
CPU: clamp
CUDA: clamp
QuantizedCPU: quantized_clamp
Vulkan: vulkan_clamp

- func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
supports_named_tensor: True
Expand Down Expand Up @@ -1183,6 +1185,7 @@
MkldnnCPU: empty_mkldnn
SparseCPU: empty_sparse
SparseCUDA: empty_sparse
Vulkan: empty_vulkan

- func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: method
Expand Down Expand Up @@ -1923,6 +1926,7 @@
CPU: mean_cpu_gpu
CUDA: mean_cpu_gpu
QuantizedCPU: quantized_mean_cpu
Vulkan: mean_vulkan

- func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
supports_named_tensor: True
Expand Down Expand Up @@ -2231,6 +2235,9 @@
CPU: batch_norm_update_stats_cpu
CUDA: batch_norm_update_stats_cuda

- func: is_vulkan_available() -> bool
use_c10_dispatcher: full

- func: _nnpack_available() -> bool
use_c10_dispatcher: full

Expand Down Expand Up @@ -3476,6 +3483,7 @@
CUDA: addmm_cuda
SparseCPU: addmm_sparse_dense_cpu
SparseCUDA: addmm_sparse_dense_cuda
Vulkan: vulkan_addmm
supports_named_tensor: True

- func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
Expand Down Expand Up @@ -5962,6 +5970,7 @@
CPU: hardtanh_
CUDA: hardtanh_
QuantizedCPU: quantized_hardtanh_
Vulkan: vulkan_hardtanh_

- func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
Expand Down Expand Up @@ -6705,6 +6714,7 @@
CPU: upsample_nearest2d_cpu
CUDA: upsample_nearest2d_cuda
QuantizedCPU: quantized_upsample_nearest2d_cpu
Vulkan: upsample_nearest2d_vulkan

- func: upsample_nearest2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
Expand Down
Loading

0 comments on commit b460465

Please sign in to comment.