Skip to content

Commit

Permalink
hipRTC fixes (#1531)
Browse files Browse the repository at this point in the history
Added CMakeFlag for hipRTC. MIGRAPHX_USE_HIPRTC.
Added stages in Jenkins for hipRTC.
Fixes for some of the pending issues from hipRTC.
  • Loading branch information
umangyadav authored Jan 31, 2023
1 parent a4b8265 commit 91cc724
Show file tree
Hide file tree
Showing 15 changed files with 202 additions and 139 deletions.
6 changes: 6 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ def rocmtestnode(Map conf) {
def compiler = bconf.get("compiler", "/opt/rocm/llvm/bin/clang++")
def flags = bconf.get("flags", "")
def gpu_debug = bconf.get("gpu_debug", "0")
def hiprtc_workarounds = bconf.get("hiprtc_workarounds", "0")
def cmd = """
ulimit -c unlimited
echo "leak:dnnl::impl::malloc" > suppressions.txt
export LSAN_OPTIONS="suppressions=\$(pwd)/suppressions.txt"
export MIGRAPHX_GPU_DEBUG=${gpu_debug}
export MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=${hiprtc_workarounds}
export CXX=${compiler}
export CXXFLAGS='-Werror'
env
Expand Down Expand Up @@ -110,6 +112,10 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release")
stash includes: 'build/*.deb', name: 'migraphx-package'
}
}, hiprtc_gpu_debug: rocmnode('vega') { cmake_build ->
stage('HipRTC GPU Debug') {
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On", gpu_debug: true, hiprtc_workarounds: true)
}
}, mlir_debug: rocmnode('vega') { cmake_build ->
stage('MLIR Debug') {
def sanitizers = "undefined"
Expand Down
108 changes: 55 additions & 53 deletions src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#####################################################################################
# ####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
Expand All @@ -20,7 +20,7 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
# ####################################################################################

list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc)
find_package(miopen)
Expand All @@ -33,6 +33,8 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()

set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")

include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
Expand All @@ -46,9 +48,10 @@ add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)

if(HAS_HIP_LAMBDA_HOST_DEVICE)
message(STATUS "Enable -fhip-lambda-host-device")
target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device)
message(STATUS "Enable -fhip-lambda-host-device")
target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device)
endif()

set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
Expand All @@ -60,11 +63,13 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)

add_library(kernel_file_check EXCLUDE_FROM_ALL)

foreach(KERNEL_FILE ${KERNEL_FILES})
get_filename_component(KERNEL_BASE_FILE ${KERNEL_FILE} NAME_WE)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
endforeach()

target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256)
target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
target_link_libraries(kernel_file_check compile_for_gpu)
Expand Down Expand Up @@ -125,6 +130,7 @@ function(register_migraphx_gpu_ops PREFIX)
register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp)
endforeach()
endfunction()

register_migraphx_gpu_ops(hip_
argmax
argmin
Expand All @@ -146,47 +152,41 @@ register_migraphx_gpu_ops(miopen_
lrn
pooling
)
register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu)

# look for offload bundler
get_filename_component(CMAKE_CXX_COMPILER_PATH "${CMAKE_CXX_COMPILER}" PATH)
if(CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$")
find_program(MIGRAPHX_OFFLOADBUNDLER_BIN clang-offload-bundler
HINTS ${CMAKE_CXX_COMPILER_PATH}
PATH_SUFFIXES bin
PATHS /opt/rocm/llvm
)
else()

if(NOT CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$")
find_program(MIGRAPHX_EXTRACT_KERNEL extractkernel
PATH_SUFFIXES bin
HINTS ${CMAKE_CXX_COMPILER_PATH}
PATHS
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
)
endif()

message(STATUS "clang-offload-bundler: ${MIGRAPHX_OFFLOADBUNDLER_BIN}")
message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")

set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")

if(MIGRAPHX_ENABLE_MLIR)
# Find package rocMLIR
find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
Expand All @@ -195,36 +195,39 @@ if(MIGRAPHX_ENABLE_MLIR)
target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
endif()

set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
if(MIGRAPHX_USE_HIPRTC)
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
message(STATUS "MIGraphX is using hipRTC")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
else()
# Get flags needed to compile hip
include(TargetFlags)
target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip will incorrectly treat it as a source file
string(APPEND HIP_COMPILER_FLAGS " ")
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "MIGraphX is using HIP Clang")

message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
"-DMIGRAPHX_OFFLOADBUNDLER_BIN=${MIGRAPHX_OFFLOADBUNDLER_BIN}"
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
"-DMIGRAPHX_USE_HIPRTC=0"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
# Get flags needed to compile hip
include(TargetFlags)
target_flags(HIP_COMPILER_FLAGS hip::device)

# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")

# Skip library paths since hip will incorrectly treat it as a source file
string(APPEND HIP_COMPILER_FLAGS " ")

foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()

message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
)

if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif()

# Check miopen find mode api
Expand All @@ -236,7 +239,7 @@ check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_
# TODO: Set default to HAS_FIND_2_API
set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")

if(MIGRAPHX_USE_FIND_2_API)
if(MIGRAPHX_USE_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else()
Expand All @@ -258,8 +261,7 @@ target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
add_subdirectory(driver)

rocm_install_targets(
TARGETS migraphx_gpu migraphx_device compile_for_gpu
INCLUDE
TARGETS migraphx_gpu migraphx_device compile_for_gpu
INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include
)

45 changes: 30 additions & 15 deletions src/targets/gpu/compile_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@
#include <cassert>
#include <iostream>

#if MIGRAPHX_USE_HIPRTC
#ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/env.hpp>
#else
#include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp>
Expand All @@ -48,9 +47,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);

#if MIGRAPHX_USE_HIPRTC
#ifdef MIGRAPHX_USE_HIPRTC

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);

std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
Expand Down Expand Up @@ -143,25 +143,29 @@ struct hiprtc_program
options.end(),
std::back_inserter(c_options),
[](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
std::cerr << log() << std::endl;
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log();
if(not prog_log.empty())
{
std::cerr << prog_log << std::endl;
}
if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Compilation failed.");
}

std::string log()
std::string log() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
if(n < 2)
if(n == 0)
return {};
std::vector<char> buffer(n);
std::string buffer(n, '\0');
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() == 0);
return {buffer.begin(), buffer.end() - 1};
assert(buffer.back() != 0);
return buffer;
}

std::vector<char> get_code_obj()
std::vector<char> get_code_obj() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
Expand All @@ -176,14 +180,25 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
hiprtc_program prog(srcs);
auto options = split_string(params, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
{
options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}

if(enabled(MIGRAPHX_GPU_DEBUG{}))
options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
return starts_with(s, "--std=") or starts_with(s, "-std=");
}))
options.push_back("-std=c++17");
options.push_back("-fno-gpu-rdc");
options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-Wno-cuda-compat");
options.push_back("--offload-arch=" + arch);
prog.compile(options);
Expand Down Expand Up @@ -292,15 +307,15 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)};
}

#endif // MIGRAPHX_USE_HIPRTC

std::string enum_params(std::size_t count, std::string param)
{
std::vector<std::string> items(count);
transform(range(count), items.begin(), [&](auto i) { return param + std::to_string(i); });
return join_strings(items, ",");
}

#endif // MIGRAPHX_USE_HIPRTC

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
2 changes: 1 addition & 1 deletion src/targets/gpu/compile_hip_code_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <migraphx/context.hpp>
#include <migraphx_kernels.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -80,6 +79,7 @@ std::string generate_args_hpp(const std::vector<shape>& inputs)
#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/types.hpp>
namespace migraphx {
Expand Down
8 changes: 3 additions & 5 deletions src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace gpu {
namespace device {

#ifdef MIGRAPHX_NO_DPP

template <index_int N,
class Op,
class T,
Expand All @@ -62,6 +63,7 @@ __device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
}
return buffer[0];
}

#else
constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }

Expand Down Expand Up @@ -96,11 +98,7 @@ __device__ T dpp_mov(T& x)
input.data = x;
for(index_int i = 0; i < n; i++)
{
#if defined(__HCC__)
output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
#else
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
#endif
}
return output.data;
}
Expand Down Expand Up @@ -310,4 +308,4 @@ void reduce(hipStream_t stream,
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
#endif // MIGRAPHX_NO_DPP
2 changes: 1 addition & 1 deletion src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ constexpr auto array_for_each(T& x, Ts&... xs)
}
else
{
using vec_type = std::remove_reference_t<decltype(array2vec(x))>;
using vec_type = remove_reference_t<decltype(array2vec(x))>;
f(array2vec(x), __builtin_convertvector(array2vec(xs), vec_type)...);
}
}
Expand Down
Loading

0 comments on commit 91cc724

Please sign in to comment.