Skip to content

Commit

Permalink
Fix crash reported in microsoft#4070. (microsoft#4091)
Browse files Browse the repository at this point in the history
* Fix crash reported in microsoft#4070.

* Add newline to warning message

* Add comment for using cout instead of the logger
  • Loading branch information
pranavsharma authored Jun 1, 2020
1 parent 8813d20 commit 6c1b2f3
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 9 deletions.
14 changes: 14 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ endif()

set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/shared_lib")
set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${ONNXRUNTIME_ROOT}/test/global_thread_pools")
set (ONNXRUNTIME_API_TESTS_WITHOUT_ENV_SRC_DIR "${ONNXRUNTIME_ROOT}/test/api_tests_without_env")

set (onnxruntime_shared_lib_test_SRC
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
Expand All @@ -235,6 +236,9 @@ set (onnxruntime_global_thread_pools_test_SRC
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_main.cc
${ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR}/test_inference.cc)

set (onnxruntime_api_tests_without_env_SRC
${ONNXRUNTIME_API_TESTS_WITHOUT_ENV_SRC_DIR}/test_apis_without_env.cc)

# tests from lowest level library up.
# the order of libraries should be maintained, with higher libraries being added first in the list

Expand Down Expand Up @@ -766,6 +770,16 @@ if (onnxruntime_BUILD_SHARED_LIB)
DEPENDS ${all_dependencies}
)
endif()

# A separate test is needed to test the APIs that don't rely on the env being created first.
if (NOT CMAKE_SYSTEM_NAME STREQUAL "Android")
AddTest(DYN
TARGET onnxruntime_api_tests_without_env
SOURCES ${onnxruntime_api_tests_without_env_SRC}
LIBS ${onnxruntime_shared_lib_test_LIBS}
DEPENDS ${all_dependencies}
)
endif()
endif()

#some ETW tools
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/framework/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "core/framework/data_types.h"
#include "core/framework/allocatormgr.h"
#include "core/providers/dnnl/dnnl_provider_factory.h"
#include "core/session/inference_session.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/ort_apis.h"
#include "core/platform/env.h"
Expand Down
15 changes: 8 additions & 7 deletions onnxruntime/core/session/abi_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,12 @@ ORT_API_STATUS_IMPL(OrtApis::SetIntraOpNumThreads, _Inout_ OrtSessionOptions* op
#ifdef _OPENMP
ORT_UNUSED_PARAMETER(options);
ORT_UNUSED_PARAMETER(intra_op_num_threads);
LOGS_DEFAULT(WARNING) << "Since openmp is enabled in this build, this API cannot be used to configure"
" intra op num threads. Please use the openmp environment variables to control"
" the number of threads.";
// Can't use the default logger here since it's possible that the default logger has not been created
// at this point. The default logger gets created when the env is created and these APIs don't require
// the env to be created first.
std::cout << "WARNING: Since openmp is enabled in this build, this API cannot be used to configure"
" intra op num threads. Please use the openmp environment variables to control"
" the number of threads.\n";
#else
options->value.intra_op_param.thread_pool_size = intra_op_num_threads;
#endif
Expand All @@ -161,16 +164,14 @@ ORT_API_STATUS_IMPL(OrtApis::SetInterOpNumThreads, _Inout_ OrtSessionOptions* op
ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options,
_In_ const char* dim_denotation, _In_ int64_t dim_value) {
options->value.free_dimension_overrides.push_back(
onnxruntime::FreeDimensionOverride{dim_denotation, onnxruntime::FreeDimensionOverrideType::Denotation, dim_value}
);
onnxruntime::FreeDimensionOverride{dim_denotation, onnxruntime::FreeDimensionOverrideType::Denotation, dim_value});
return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options,
_In_ const char* dim_name, _In_ int64_t dim_value) {
options->value.free_dimension_overrides.push_back(
onnxruntime::FreeDimensionOverride{dim_name, onnxruntime::FreeDimensionOverrideType::Name, dim_value}
);
onnxruntime::FreeDimensionOverride{dim_name, onnxruntime::FreeDimensionOverrideType::Name, dim_value});
return nullptr;
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/abi_session_options_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <string>
#include <vector>
#include <atomic>
#include "core/session/inference_session.h"
#include "core/framework/session_options.h"
#include "core/session/onnxruntime_c_api.h"
#include "core/providers/providers.h"

Expand Down
69 changes: 69 additions & 0 deletions onnxruntime/test/api_tests_without_env/test_apis_without_env.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifndef USE_ONNXRUNTIME_DLL
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#else
#pragma warning(push)
#pragma warning(disable : 4018) /*'expression' : signed/unsigned mismatch */
#pragma warning(disable : 4065) /*switch statement contains 'default' but no 'case' labels*/
#pragma warning(disable : 4100)
#pragma warning(disable : 4146) /*unary minus operator applied to unsigned type, result still unsigned*/
#pragma warning(disable : 4127)
#pragma warning(disable : 4244) /*'conversion' conversion from 'type1' to 'type2', possible loss of data*/
#pragma warning(disable : 4251) /*'identifier' : class 'type' needs to have dll-interface to be used by clients of class 'type2'*/
#pragma warning(disable : 4267) /*'var' : conversion from 'size_t' to 'type', possible loss of data*/
#pragma warning(disable : 4305) /*'identifier' : truncation from 'type1' to 'type2'*/
#pragma warning(disable : 4307) /*'operator' : integral constant overflow*/
#pragma warning(disable : 4309) /*'conversion' : truncation of constant value*/
#pragma warning(disable : 4334) /*'operator' : result of 32-bit shift implicitly converted to 64 bits (was 64-bit shift intended?)*/
#pragma warning(disable : 4355) /*'this' : used in base member initializer list*/
#pragma warning(disable : 4506) /*no definition for inline function 'function'*/
#pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' (performance warning)*/
#pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/
#pragma warning(disable : 6011) /*Dereferencing NULL pointer*/
#pragma warning(disable : 6387) /*'value' could be '0'*/
#pragma warning(disable : 26495) /*Variable is uninitialized.*/
#endif
#include <google/protobuf/message_lite.h>
#ifdef __GNUC__
#pragma GCC diagnostic pop
#else
#pragma warning(pop)
#endif
#endif

#include "gtest/gtest.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/abi_session_options_impl.h"

TEST(TestSessionOptions, SetIntraOpNumThreadsWithoutEnv) {
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(48);
const auto* ort_session_options = (const OrtSessionOptions*)session_options;
#ifdef _OPENMP
ASSERT_EQ(ort_session_options->value.intra_op_param.thread_pool_size, 0);
#else
ASSERT_EQ(ort_session_options->value.intra_op_param.thread_pool_size, 48);
#endif
}

int main(int argc, char** argv) {
int status = 0;
try {
::testing::InitGoogleTest(&argc, argv);
status = RUN_ALL_TESTS();
} catch (const std::exception& ex) {
std::cerr << ex.what();
status = -1;
}

#ifndef USE_ONNXRUNTIME_DLL
//make memory leak checker happy
::google::protobuf::ShutdownProtobufLibrary();
#endif
return status;
}
1 change: 0 additions & 1 deletion onnxruntime/test/shared_lib/test_fixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ typedef const char* PATH_TYPE;
static inline void ORT_API_CALL MyLoggingFunction(void*, OrtLoggingLevel, const char*, const char*, const char*, const char*) {
}


0 comments on commit 6c1b2f3

Please sign in to comment.