-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Optimize FastGelu with float2 and float4 vectorized kernels on ROCm (#…
…11491) * Using vectorized loads (float2) for fp16 to improve performance * Fix a few warnings from cpplint * Fix a few warnings from cpplint * Use __float2half2_rn and fix some cpplint warnings * Move some computations to LaunchFastGeluKernel * Fix some Lint C++ warning * Using vectorized loads (float4) for fp16 to improve performance * Switch whether to optimize FastGelu with float4 vectorization * Switch to float4 memory access based on input_length in FastGelu * Comment how to set the threshold of float2 and float4 vectorized kernels * Add FastGelu fp16 unit tests for bias_length = 2 and 8 * Make vectorized kernels generic with aligned_vector * Unify the vectorized kernels with/without bias * Refactor the code to suppress cpplint warnings * Solve formatting issues * Remove cudaDeviceProp from FastGeluKernel and LaunchFastGeluKernel * Move fast_gelu_impl.h to rocm/bert * Fix some Lint C++ warnings and code alignment
- Loading branch information
1 parent
088bc74
commit f4ba199
Showing
6 changed files
with
321 additions
and
144 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
// Modifications: Remove GetDeviceProp in LaunchFastGeluKernel. | ||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include <limits>

#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/miopen_common.h"
#include "contrib_ops/rocm/bert/fast_gelu.h"
#include "contrib_ops/rocm/bert/fast_gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
#include "contrib_ops/rocm/bert/transformer_common.h"
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace rocm { | ||
|
||
// Registers FastGelu<ElemType> as the kRocmExecutionProvider kernel for the
// FastGelu op in kMSDomain, constraining "T" to tensors of ElemType.
// The expansion already ends with ';', so call sites are written without one.
#define REGISTER_KERNEL_TYPED(ElemType)                                      \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                             \
      FastGelu, kMSDomain, 1, ElemType, kRocmExecutionProvider,              \
      (*KernelDefBuilder::Create())                                          \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<ElemType>()),     \
      FastGelu<ElemType>);

// Element types for which the kernel is registered.
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
|
||
using namespace ONNX_NAMESPACE; | ||
|
||
template <typename T> | ||
FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { | ||
const TransformerOptions* options = TransformerOptions::GetInstance(); | ||
use_half2_ = !options->DisableHalf2(); | ||
} | ||
|
||
template <typename T> | ||
Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const { | ||
ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); | ||
|
||
const Tensor* input = context->Input<Tensor>(0); | ||
const Tensor* bias = context->Input<Tensor>(1); | ||
Tensor* output = context->Output(0, input->Shape()); | ||
|
||
int64_t input_length = input->Shape().Size(); | ||
if (input_length == 0) { | ||
return Status::OK(); | ||
} | ||
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); | ||
typedef typename ToHipType<T>::MappedType HipT; | ||
|
||
if (!LaunchFastGeluKernel<HipT>(Stream(), | ||
static_cast<int>(input_length), | ||
static_cast<int>(bias_length), | ||
reinterpret_cast<const HipT*>(input->template Data<T>()), | ||
(nullptr != bias) ? reinterpret_cast<const HipT*>(bias->template Data<T>()) : nullptr, | ||
reinterpret_cast<HipT*>(output->template MutableData<T>()), | ||
use_half2_)) { | ||
HIP_CALL(hipGetLastError()); | ||
return Status(common::ONNXRUNTIME, common::FAIL); | ||
} | ||
|
||
return Status::OK(); | ||
} | ||
|
||
} // namespace rocm | ||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "core/common/common.h" | ||
#include "core/providers/rocm/rocm_kernel.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace rocm { | ||
|
||
using namespace onnxruntime::rocm; | ||
|
||
template <typename T> | ||
class FastGelu final : public RocmKernel { | ||
public: | ||
FastGelu(const OpKernelInfo& op_kernel_info); | ||
Status ComputeInternal(OpKernelContext* ctx) const override; | ||
|
||
private: | ||
bool use_half2_; | ||
}; | ||
|
||
} // namespace rocm | ||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.