Add deterministic path for AllReduceL2 (used to compute gradient norm) (#5027)

* add deterministic path for reduce l2

* add unit tests

* memset zero size off by one

* eliminate Windows warning treated as error

Co-authored-by: suffian khan <sukha@OrtTrainingDev1.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Suffian Khan and suffian khan authored Sep 3, 2020
1 parent 9ba2cfb commit 546965c
Showing 6 changed files with 103 additions and 29 deletions.
1 change: 1 addition & 0 deletions onnxruntime/test/providers/provider_test_utils.cc
@@ -654,6 +654,7 @@ void OpTester::Run(
so.session_logid = op_;
so.session_log_verbosity_level = 1;
so.execution_mode = execution_mode;
so.use_deterministic_compute = use_determinism_;
so.graph_optimization_level = TransformerLevel::Default; // 'Default' == off
Run(so, expect_result, expected_failure_string, excluded_provider_types,
run_options, execution_providers, custom_output_verifier, options);
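
The single added line above is the only plumbing needed on the test side: OpTester copies use_determinism_ into the internal SessionOptions before creating the session. A minimal sketch of the same wiring done directly, assuming the internal onnxruntime::SessionOptions / InferenceSession API; the environment, log id, and model path are placeholders, not part of this commit:

SessionOptions so;
so.session_logid = "deterministic_reduce_all_l2";  // placeholder log id
so.use_deterministic_compute = true;               // ask kernels to pick reproducible code paths
InferenceSession session(so, env);                 // 'env' assumed to be an existing Environment
ORT_THROW_IF_ERROR(session.Load("model.onnx"));    // placeholder model path
ORT_THROW_IF_ERROR(session.Initialize());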
6 changes: 6 additions & 0 deletions onnxruntime/test/providers/provider_test_utils.h
@@ -477,6 +477,10 @@ class OpTester {
return output_data_;
}

void SetDeterminism(bool use_determinism) {
use_determinism_ = use_determinism;
}

protected:
virtual void AddNodes(onnxruntime::Graph& graph, std::vector<onnxruntime::NodeArg*>& graph_input_defs,
std::vector<onnxruntime::NodeArg*>& graph_output_defs,
@@ -619,6 +623,8 @@ class OpTester {
std::vector<std::shared_ptr<CustomRegistry>> custom_session_registries_;

bool verify_output_;

bool use_determinism_ = false;
};

template <typename TException>
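
A minimal usage sketch for the new setter (hypothetical test; input/output names and values are illustrative): call SetDeterminism before Run so the flag reaches SessionOptions::use_deterministic_compute.

TEST(OpTesterExampleTest, DeterministicReduceAllL2) {
  OpTester test("ReduceAllL2", 1, onnxruntime::kMSDomain, true);
  test.SetDeterminism(true);  // exercises the deterministic CUDA path added in this commit
  test.AddInput<float>("data_0", {3}, {1.0f, 2.0f, 3.0f});
  test.AddOutput<float>("reduced", {}, {3.7416573f});  // sqrt(1 + 4 + 9)
  test.Run();
}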
@@ -85,8 +85,14 @@ TEST(AllOpTest, All_1d_large) {
}
}

TEST(ReductionOpTest, ReduceAllL2) {
class ReductionOpTest : public ::testing::TestWithParam<bool> {
protected:
bool use_determinism;
};

TEST_P(ReductionOpTest, ReduceAllL2) {
OpTester test("ReduceAllL2", 1, onnxruntime::kMSDomain, true);
test.SetDeterminism(GetParam());
std::vector<float> data0 = {1.0f, 2.0f, 3.0f};
std::vector<float> data1 = {-1.0f, -2.0f};

@@ -96,8 +102,9 @@ TEST(ReductionOpTest, ReduceAllL2) {
test.Run();
}

TEST(ReductionOpTest, ReduceAllL2HalfHalf) {
TEST_P(ReductionOpTest, ReduceAllL2HalfHalf) {
OpTester test("ReduceAllL2", 1, onnxruntime::kMSDomain, true);
test.SetDeterminism(GetParam());

std::vector<float> data0 = {1.0f, 2.0f, 3.0f};
std::vector<MLFloat16> data0_half(3);
@@ -118,8 +125,9 @@ TEST(ReductionOpTest, ReduceAllL2HalfHalf) {
test.Run();
}

TEST(ReductionOpTest, ReduceAllL2FloatHalf) {
TEST_P(ReductionOpTest, ReduceAllL2FloatHalf) {
OpTester test("ReduceAllL2", 1, onnxruntime::kMSDomain, true);
test.SetDeterminism(GetParam());

std::vector<float> data0 = {1.0f, 2.0f, 3.0f};
std::vector<float> data1 = {-1.0f, -2.0f};
@@ -135,8 +143,9 @@ TEST(ReductionOpTest, ReduceAllL2FloatHalf) {
test.Run();
}

TEST(ReductionOpTest, ReduceAllL2HalfFloat) {
TEST_P(ReductionOpTest, ReduceAllL2HalfFloat) {
OpTester test("ReduceAllL2", 1, onnxruntime::kMSDomain, true);
test.SetDeterminism(GetParam());

std::vector<float> data0 = {1.0f, 2.0f, 3.0f};
std::vector<MLFloat16> data0_half(3);
@@ -160,8 +169,10 @@ void TestMultiTensorReduce(
const int min_tensor_size,
const int max_tensor_size,
const float min,
const float max) {
const float max,
bool use_determinism) {
OpTester test("ReduceAllL2", 1, onnxruntime::kMSDomain, true);
test.SetDeterminism(use_determinism);

// Set up random number generator.
std::random_device random_device;
@@ -196,22 +207,25 @@ void TestMultiTensorReduce(
test.Run();
}

TEST(ReductionOpTest, ReduceAllL2LargeOne) {
TestMultiTensorReduce(16, 1, 131072, 1.f, 1.f);
TEST_P(ReductionOpTest, ReduceAllL2LargeOne) {
TestMultiTensorReduce(16, 1, 131072, 1.f, 1.f, GetParam());
}

TEST(ReductionOpTest, ReduceAllL2Large) {
TestMultiTensorReduce(16, 1, 131072, 1.2f, 1.3f);
TEST_P(ReductionOpTest, ReduceAllL2Large) {
TestMultiTensorReduce(16, 1, 131072, 1.2f, 1.3f, GetParam());
}

TEST(ReductionOpTest, ReduceAllL2ManyOne) {
TestMultiTensorReduce(4096, 1, 8, 1.f, 1.f);
TEST_P(ReductionOpTest, ReduceAllL2ManyOne) {
TestMultiTensorReduce(4096, 1, 8, 1.f, 1.f, GetParam());
}

TEST(ReductionOpTest, ReduceAllL2Many) {
TestMultiTensorReduce(4096, 1, 8, 1.2f, 1.3f);
TEST_P(ReductionOpTest, ReduceAllL2Many) {
TestMultiTensorReduce(4096, 1, 8, 1.2f, 1.3f, GetParam());
}

// invoke with and without use_determinism flag for session
INSTANTIATE_TEST_SUITE_P(ReductionOpTestWrapper, ReductionOpTest, ::testing::Bool());
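
With ::testing::Bool() as the parameter generator, gtest instantiates every TEST_P above twice: once with GetParam() == false (the original multi-tensor kernel) and once with true (the new deterministic path), so both code paths run under identical assertions. A reduced sketch of the pattern in isolation (fixture and test names are illustrative):

class FlagTest : public ::testing::TestWithParam<bool> {};

TEST_P(FlagTest, RunsOncePerFlagValue) {
  const bool use_determinism = GetParam();  // false on one instantiation, true on the other
  EXPECT_TRUE(use_determinism || !use_determinism);
}

INSTANTIATE_TEST_SUITE_P(FlagTestWrapper, FlagTest, ::testing::Bool());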

#endif

TEST(ReductionOpTest, ReduceSumTraining_int32) {
@@ -2,10 +2,23 @@
// Licensed under the MIT License.

#include "orttraining/training_ops/cuda/reduction/reduction_all.h"
#include "core/providers/cuda/reduction/reduction_functions.h"
#include "core/framework/op_kernel_context_internal.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
struct AccumulateType {};
template <>
struct AccumulateType<float> { using type = float; };
template <>
struct AccumulateType<half> { using type = float; };
template <>
struct AccumulateType<double> { using type = double; };
template <typename T>
using AccType = typename AccumulateType<T>::type;
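
The trait above pins the type used to accumulate squared sums for each output type: half results accumulate in float (double in double), so the fixed-order summation in the deterministic path does not lose precision to fp16 rounding. Illustrative compile-time checks, assuming <type_traits> is visible in this CUDA translation unit:

// Illustrative only; relies on the AccumulateType specializations defined above.
static_assert(std::is_same<AccType<half>, float>::value, "half accumulates in float");
static_assert(std::is_same<AccType<double>, double>::value, "double accumulates in double");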

#define REGISTER_REDUCE_ALL_KERNEL_TYPED(Name, TIn, TOut) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Name, \
@@ -42,17 +55,56 @@ Status ReduceAllL2<TIn, TOut>::ComputeInternal(OpKernelContext* ctx) const {
CudaTOut* p_output = reinterpret_cast<CudaTOut*>(output->template MutableData<TOut>());
ORT_ENFORCE(cudaMemset(p_output, 0, sizeof(CudaTOut)) == cudaSuccess);

typedef MultiTensorReduceL2<CudaTIn, CudaTOut> TFunctor;
TFunctor functor;
auto ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
bool deterministic = ctx_internal && ctx_internal->GetUseDeterministicCompute();

if (!deterministic) {

typedef MultiTensorReduceL2<CudaTIn, CudaTOut> TFunctor;
TFunctor functor;

// Check if all values are finite and write true to deviceOutput.
// Otherwise, false will be written.
launch_multi_tensor_functor<1, TFunctor, CudaTOut*>(
2048 * 32, tensor_sizes, grouped_tensor_pointers, functor, p_output);

// *p_output is the squared sum of all elements.
// Let's take a sqrt to get the actual L2-norm.
ScalarSqrt(p_output, p_output);
}
else {

// Check if all values are finite and write true to deviceOutput.
// Otherwise, false will be written.
launch_multi_tensor_functor<1, TFunctor, CudaTOut*>(
2048 * 32, tensor_sizes, grouped_tensor_pointers, functor, p_output);
// alternate path only for deterministic compute ..
typedef AccType<CudaTOut> CudaTAcc;

// *p_output is the squared sum of all elements.
// Let's take a sqrt to get the actual L2-norm.
ScalarSqrt(p_output, p_output);
// find scratch buffer size needed by 'reduce_square_sum' for each tensor
int scratch_size = 0;
for (int i = 0; i < total_tensor_count; ++i) {
scratch_size = std::max(scratch_size, compute_reduction_buffer_size(sizeof(CudaTAcc), tensor_sizes[i]));
}

// enlarge scratch buffer size for 'reduce_sum' over tensor square norms
scratch_size = std::max(scratch_size, compute_reduction_buffer_size(sizeof(CudaTAcc), total_tensor_count));

// add head room for final output and square norms of each tensor
scratch_size += (1 + total_tensor_count)*sizeof(CudaTAcc);

// create GPU scratch space and zero target for each tensor square norm
uint8_t* p_scratch = GetScratchBuffer<uint8_t>(scratch_size).get();
ORT_ENFORCE(cudaMemset(p_scratch, 0, sizeof(CudaTAcc)*(1 + total_tensor_count)) == cudaSuccess);

CudaTAcc* p_global_sqnorm = reinterpret_cast<CudaTAcc*>(p_scratch);
CudaTAcc* p_tensor_sqnorm = p_global_sqnorm + 1;
CudaTAcc* p_reduce_buffer = p_tensor_sqnorm + total_tensor_count;

// perform reduction l2norm = sqrt[sum(tensor[i][j]**2)] for i,j over all tensor elements
for (int i = 0; i < total_tensor_count; ++i) {
CudaTIn* p_tensor_i = reinterpret_cast<CudaTIn*>(grouped_tensor_pointers[i][0]);
reduce_square_sum(p_tensor_i, p_tensor_sqnorm + i, tensor_sizes[i], p_reduce_buffer);
}
reduce_sum(p_tensor_sqnorm, p_global_sqnorm, total_tensor_count, p_reduce_buffer);
ScalarSqrt(p_global_sqnorm, p_output);
}

return Status::OK();
}
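
The multi-tensor branch is faster, but its chunk scheduling and atomic adds let the floating-point summation order vary from run to run; the deterministic branch instead computes one square sum per tensor and then combines them in a fixed order. A minimal CPU reference of what the deterministic branch computes (illustrative sketch, not ORT code):

#include <cmath>
#include <vector>

// l2norm = sqrt( sum over tensors i and elements j of x[i][j]^2 ), with a fixed addition order.
float ReduceAllL2Reference(const std::vector<std::vector<float>>& tensors) {
  std::vector<float> square_sums(tensors.size(), 0.0f);
  for (size_t i = 0; i < tensors.size(); ++i)
    for (float v : tensors[i]) square_sums[i] += v * v;  // per-tensor reduce_square_sum
  float total = 0.0f;
  for (float s : square_sums) total += s;                // reduce_sum over per-tensor results
  return std::sqrt(total);                               // ScalarSqrt
}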
@@ -10,18 +10,19 @@
namespace onnxruntime {
namespace cuda {

template<typename T>
__global__ void _ScalarSqrtImpl(T* input, T* output) {
*output = _Sqrt(*input);
template<typename Tin, typename Tout>
__global__ void _ScalarSqrtImpl(Tin* input, Tout* output) {
*output = (Tout)_Sqrt(*input);
};

template<typename T>
void ScalarSqrt(T* input, T* output) {
template<typename Tin, typename Tout>
void ScalarSqrt(Tin* input, Tout* output) {
_ScalarSqrtImpl<<<1, 1, 0>>>(input, output);
};

template void ScalarSqrt(float* input, float* output);
template void ScalarSqrt(half* input, half* output);
template void ScalarSqrt(float* input, half* output);

template <typename TIn, typename TOut, typename TBuf, typename TInOp, typename TOutOp>
__global__ void _MultiTensorReduceImpl(ChunkGroup<1> chunk_group, TOut* output) {
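
Templating ScalarSqrt on separate input and output types lets the deterministic path keep its accumulator in float while the kernel's declared output is half; the new float-to-half instantiation covers the ReduceAllL2 kernels whose output type is MLFloat16. A sketch of the call (buffer names are illustrative):

// d_sqsum_f32: deterministic squared sum accumulated in float;
// d_out_f16: the kernel's half-typed output. The cast happens inside _ScalarSqrtImpl.
ScalarSqrt(d_sqsum_f32, d_out_f16);  // instantiates ScalarSqrt<float, half>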
@@ -21,8 +21,8 @@ struct MultiTensorReduceL2 {
void operator()(ChunkGroup<1> chunk_group, TOut* output);
};

template<typename T>
void ScalarSqrt(T* input, T* output);
template<typename Tin, typename Tout>
void ScalarSqrt(Tin* input, Tout* output);

} // namespace cuda
} // namespace onnxruntime
