fix hip support (alibaba#939)
* add dcu support

* resolve the conflict

* add dcu support

* fix hip support
zhangxiao-stack authored Jan 6, 2023
1 parent 3d80282 commit 4987cd1
Showing 3 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pytorch_blade/scripts/build_pytorch_blade_dcu.sh
@@ -33,7 +33,7 @@ function ci_build() {
   elif [ "$TORCH_BLADE_BUILD_WITH_DCU_ROCM_SUPPORT" = "ON" ]; then
     export TORCH_BLADE_BUILD_TENSORRT=OFF
     export TORCH_BLADE_BUILD_TENSORRT_STATIC=${TORCH_BLADE_BUILD_TENSORRT_STATIC:-OFF}
-    python3 ../scripts/python/common_setup.py
+    python3 ../scripts/python/common_setup.py --rocm_path=/opt/dtk
   else
     python3 ../scripts/python/common_setup.py --cpu_only
   fi
11 changes: 9 additions & 2 deletions tao_compiler/mlir/xla/ral/BUILD
@@ -16,7 +16,7 @@ load(
     "if_cuda_is_configured",
     "cuda_library",
 )
-load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
+load("@local_config_rocm//rocm:build_defs.bzl","if_dcu", "if_rocm_is_configured")
 load(
     "@com_google_protobuf//:protobuf.bzl",
     "cc_proto_library",
@@ -323,7 +323,8 @@ cc_library(
     copts = [
         # "-DTF_1_12",
         "-fopenmp",
-    ] + if_cuda_or_rocm(["-DTAO_RAL_USE_STREAM_EXECUTOR"]),
+    ] + if_dcu(["-DTENSORFLOW_USE_DCU=1"])
+    + if_cuda_or_rocm(["-DTAO_RAL_USE_STREAM_EXECUTOR"]),
     linkopts = [
         "-fopenmp",
         "-ldl",
@@ -508,6 +509,8 @@ cc_library(
     ]),
     copts = if_cuda_is_configured([
         "-DGOOGLE_CUDA=1"
+    ]) + if_dcu([
+        "-DTENSORFLOW_USE_DCU=1"
     ]) + if_cuda_or_rocm(["-DTAO_RAL_USE_STREAM_EXECUTOR"]),
 
     alwayslink = 1,
@@ -695,6 +698,8 @@ cc_library(
         "-DDISC_BUILD_FROM_TF_BRIDGE"
     ] + if_cuda_or_rocm([
         "-DTAO_RAL_USE_STREAM_EXECUTOR"
+    ]) + if_dcu([
+        "-DTENSORFLOW_USE_DCU=1"
     ]) + if_rocm_is_configured([
         "-DTENSORFLOW_USE_ROCM=1",
         "-x rocm",
@@ -908,6 +913,8 @@ cc_library(
     ] + if_rocm_is_configured([
         "-DTENSORFLOW_USE_ROCM=1",
         "-x rocm",
+    ]) + if_dcu([
+        "-DTENSORFLOW_USE_DCU=1"
     ]) + if_cuda_or_rocm([
         "-DTAO_RAL_USE_STREAM_EXECUTOR"
     ]) + cuda_default_copts(),
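Note on `if_dcu`: the macro is loaded from the generated `@local_config_rocm//rocm:build_defs.bzl` repository and is not defined anywhere in this commit. A minimal sketch of how such a macro is typically written, assuming a hypothetical `using_dcu` config_setting (the label below is an illustration, not taken from the commit):

def if_dcu(if_true, if_false = []):
    # Sketch only: expands to a select() that yields `if_true` when the
    # ROCm toolchain is configured for a DCU, and `if_false` otherwise.
    # The `using_dcu` config_setting label is an assumed name.
    return select({
        "@local_config_rocm//rocm:using_dcu": if_true,
        "//conditions:default": if_false,
    })

Used this way, `-DTENSORFLOW_USE_DCU=1` lands in `copts` only for DCU builds, mirroring the existing `if_rocm_is_configured` pattern.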
18 changes: 12 additions & 6 deletions tao_compiler/mlir/xla/ral/context/stream_executor_based_impl.cc
@@ -202,7 +202,8 @@ static bool DoGemmWithAlgorithm(
       /*leading dim of LHS=*/lhs_matrix.num_cols,
       /*beta=*/static_cast<OutT>(beta), &output_data,
       /*leading dim of output=*/n, computation_type, *algorithm,
-#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM
+#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM && \
+    (!TENSORFLOW_USE_DCU)
       output_profile_result, se::blas::CallContext::kNone)
 #elif (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 11)
       se::blas::kDefaultComputePrecision, output_profile_result)
@@ -225,7 +226,8 @@ static bool DoGemmWithAlgorithm(
       /*leading dim of LHS=*/lhs_matrix.num_cols, lhs_stride,
       /*beta=*/static_cast<AlphaBeta>(beta), &output_data,
       /*leading dim of output=*/n, output_stride,
-#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM
+#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM && \
+    (!TENSORFLOW_USE_DCU)
       batch_size, se::blas::CallContext::kNone)
 #elif (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 11)
       batch_size, se::blas::kDefaultComputePrecision)
@@ -242,7 +244,8 @@ static bool DoGemmWithAlgorithm(
       /*leading dim of RHS=*/rhs_matrix.num_cols, lhs_data,
       /*leading dim of LHS=*/lhs_matrix.num_cols,
       /*beta=*/static_cast<AlphaBeta>(beta), &output_data,
-#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM
+#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM && \
+    (!TENSORFLOW_USE_DCU)
       /*leading dim of output=*/n, se::blas::CallContext::kNone)
 #else
       /*leading dim of output=*/n)
@@ -738,7 +741,8 @@ std::vector<ProfileResult> GetMIOpenAlgorithms(
           params.input_descriptor, operand_buffers[0], params.filter_descriptor,
           operand_buffers[1], params.output_descriptor, result_buffer,
           params.convolution_descriptor, scratch_allocator,
-#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM
+#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM && \
+    (!TENSORFLOW_USE_DCU)
           se::dnn::CallContext::kNone, &algorithms)) {
 #else
           &algorithms)) {
@@ -1340,13 +1344,15 @@ Status RunCudnnConvolution(CudnnConvParams& params,
   }
 
   Status status = Status::OK();
-#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM
+#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM && \
+    (!TENSORFLOW_USE_DCU)
   se::dnn::CallContext call_context = se::dnn::CallContext::kNone;
 #endif
 
   switch (kind) {
     case ConvolutionKind::FORWARD:
-#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM
+#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM && \
+    (!TENSORFLOW_USE_DCU)
       // TF2.9 and ROCM
       call_context = se::dnn::CallContext::kForward;
       status = stream->ConvolveWithAlgorithm(
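All of the `.cc` hunks apply one recurring guard: TensorFlow 2.9+ ROCm builds pass an extra `CallContext` argument to the StreamExecutor BLAS/DNN entry points, and the added `(!TENSORFLOW_USE_DCU)` term routes DCU builds to the legacy branch instead, evidently because the DCU (HIP/DTK) toolchain still ships the older signatures. A distilled, self-contained C++ illustration of the pattern; the function and messages are hypothetical, only the preprocessor condition mirrors the diff:

#include <cstdio>

// Stand-ins for the build flags normally injected via Bazel copts.
#define TF_MAJOR_VERSION 2
#define TF_MINOR_VERSION 9
#define TENSORFLOW_USE_ROCM 1
#define TENSORFLOW_USE_DCU 1  // a DCU build is a ROCm build plus this flag

void LaunchGemm() {
#if (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 8) && TENSORFLOW_USE_ROCM && \
    (!TENSORFLOW_USE_DCU)
  // Plain ROCm on TF >= 2.9: newer entry points take a CallContext argument.
  std::puts("DoBlasGemm(..., se::blas::CallContext::kNone)");
#else
  // CUDA, pre-2.9 TF, or DCU: legacy signature without CallContext.
  std::puts("DoBlasGemm(...)");
#endif
}

int main() { LaunchGemm(); }

With `TENSORFLOW_USE_DCU` set to 1 as above, the program takes the legacy branch, which is exactly what this commit arranges for DCU builds.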
