Skip to content

Commit c8ac34d

Browse files
authored
Fix DEBUG_NODE_INPUTS_OUTPUTS test by putting it in a separate process, clean up unused test_main.cc files. (#5949)
Move the DEBUG_NODE_INPUTS_OUTPUTS test into its own process. The implementation uses static variables which do not interact well with other tests. Clean up old test_main.cc files which are no longer used.
1 parent a53f4dd commit c8ac34d

File tree

10 files changed

+87
-176
lines changed

10 files changed

+87
-176
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 68 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ function(AddTest)
139139
endif()
140140
endfunction(AddTest)
141141

142+
# general program entrypoint for C++ unit tests
143+
set(onnxruntime_unittest_main_src "${TEST_SRC_DIR}/unittest_main/test_main.cc")
144+
142145
#Do not add '${TEST_SRC_DIR}/util/include' to your include directories directly
143146
#Use onnxruntime_add_include_to_target or target_link_libraries, so that compile definitions
144147
#can propagate correctly.
@@ -571,65 +574,58 @@ endif()
571574

572575
set(all_dependencies ${onnxruntime_test_providers_dependencies} )
573576

574-
if (onnxruntime_ENABLE_TRAINING)
575-
list(APPEND all_tests ${onnxruntime_test_training_src})
576-
endif()
577+
if (onnxruntime_ENABLE_TRAINING)
578+
list(APPEND all_tests ${onnxruntime_test_training_src})
579+
endif()
577580

578-
if (onnxruntime_USE_TVM)
579-
list(APPEND all_tests ${onnxruntime_test_tvm_src})
580-
endif()
581-
if (onnxruntime_USE_OPENVINO)
582-
list(APPEND all_tests ${onnxruntime_test_openvino_src})
583-
endif()
584-
# we can only have one 'main', so remove them all and add back the providers test_main as it sets
585-
# up everything we need for all tests
586-
file(GLOB_RECURSE test_mains CONFIGURE_DEPENDS
587-
"${TEST_SRC_DIR}/*/test_main.cc"
588-
)
589-
list(REMOVE_ITEM all_tests ${test_mains})
590-
list(APPEND all_tests "${TEST_SRC_DIR}/providers/test_main.cc")
581+
if (onnxruntime_USE_TVM)
582+
list(APPEND all_tests ${onnxruntime_test_tvm_src})
583+
endif()
584+
if (onnxruntime_USE_OPENVINO)
585+
list(APPEND all_tests ${onnxruntime_test_openvino_src})
586+
endif()
591587

592-
# this is only added to onnxruntime_test_framework_libs above, but we use onnxruntime_test_providers_libs for the onnxruntime_test_all target.
593-
# for now, add it here. better is probably to have onnxruntime_test_providers_libs use the full onnxruntime_test_framework_libs
594-
# list given it's built on top of that library and needs all the same dependencies.
595-
if(WIN32)
596-
list(APPEND onnxruntime_test_providers_libs Advapi32)
597-
endif()
588+
# this is only added to onnxruntime_test_framework_libs above, but we use onnxruntime_test_providers_libs for the onnxruntime_test_all target.
589+
# for now, add it here. better is probably to have onnxruntime_test_providers_libs use the full onnxruntime_test_framework_libs
590+
# list given it's built on top of that library and needs all the same dependencies.
591+
if(WIN32)
592+
list(APPEND onnxruntime_test_providers_libs Advapi32)
593+
endif()
598594

599-
AddTest(
600-
TARGET onnxruntime_test_all
601-
SOURCES ${all_tests}
602-
LIBS onnx_test_runner_common ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} re2::re2 onnx_test_data_proto
603-
DEPENDS ${all_dependencies}
604-
)
595+
AddTest(
596+
TARGET onnxruntime_test_all
597+
SOURCES ${all_tests} ${onnxruntime_unittest_main_src}
598+
LIBS onnx_test_runner_common ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} re2::re2 onnx_test_data_proto
599+
DEPENDS ${all_dependencies}
600+
)
605601

606-
# the default logger tests conflict with the need to have an overall default logger
607-
# so skip in this type of
608-
target_compile_definitions(onnxruntime_test_all PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
609-
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
610-
target_compile_definitions(onnxruntime_test_all_xc PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
611-
endif()
612-
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
613-
target_compile_options(onnxruntime_test_all PUBLIC "-Wno-unused-const-variable")
614-
endif()
615-
if(onnxruntime_RUN_MODELTEST_IN_DEBUG_MODE)
616-
target_compile_definitions(onnxruntime_test_all PUBLIC -DRUN_MODELTEST_IN_DEBUG_MODE)
617-
endif()
618-
if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
619-
target_compile_definitions(onnxruntime_test_all PRIVATE DEBUG_NODE_INPUTS_OUTPUTS)
620-
endif()
621-
if (onnxruntime_USE_FEATURIZERS)
622-
target_include_directories(onnxruntime_test_all PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external/FeaturizersLibrary/src)
623-
endif()
624-
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
625-
target_link_libraries(onnxruntime_test_all PRIVATE onnxruntime_language_interop onnxruntime_pyop)
626-
endif()
602+
# the default logger tests conflict with the need to have an overall default logger
603+
# so skip in this type of
604+
target_compile_definitions(onnxruntime_test_all PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
605+
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
606+
target_compile_definitions(onnxruntime_test_all_xc PUBLIC -DSKIP_DEFAULT_LOGGER_TESTS)
607+
endif()
608+
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
609+
target_compile_options(onnxruntime_test_all PUBLIC "-Wno-unused-const-variable")
610+
endif()
611+
if(onnxruntime_RUN_MODELTEST_IN_DEBUG_MODE)
612+
target_compile_definitions(onnxruntime_test_all PUBLIC -DRUN_MODELTEST_IN_DEBUG_MODE)
613+
endif()
614+
if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
615+
target_compile_definitions(onnxruntime_test_all PRIVATE DEBUG_NODE_INPUTS_OUTPUTS)
616+
endif()
617+
if (onnxruntime_USE_FEATURIZERS)
618+
target_include_directories(onnxruntime_test_all PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/external/FeaturizersLibrary/src)
619+
endif()
620+
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
621+
target_link_libraries(onnxruntime_test_all PRIVATE onnxruntime_language_interop onnxruntime_pyop)
622+
endif()
627623

628-
if (onnxruntime_USE_ROCM)
629-
target_include_directories(onnxruntime_test_all PRIVATE ${onnxruntime_ROCM_HOME}/include/hiprand ${onnxruntime_ROCM_HOME}/include/rocrand)
630-
endif()
624+
if (onnxruntime_USE_ROCM)
625+
target_include_directories(onnxruntime_test_all PRIVATE ${onnxruntime_ROCM_HOME}/include/hiprand ${onnxruntime_ROCM_HOME}/include/rocrand)
626+
endif()
631627

632-
set(test_data_target onnxruntime_test_all)
628+
set(test_data_target onnxruntime_test_all)
633629

634630

635631
#
@@ -872,7 +868,7 @@ if (onnxruntime_BUILD_SHARED_LIB)
872868
endif()
873869
AddTest(DYN
874870
TARGET onnxruntime_shared_lib_test
875-
SOURCES ${onnxruntime_shared_lib_test_SRC} ${TEST_SRC_DIR}/providers/test_main.cc
871+
SOURCES ${onnxruntime_shared_lib_test_SRC} ${onnxruntime_unittest_main_src}
876872
LIBS ${onnxruntime_shared_lib_test_LIBS}
877873
DEPENDS ${all_dependencies}
878874
)
@@ -905,6 +901,24 @@ if (onnxruntime_BUILD_SHARED_LIB)
905901
endif()
906902
endif()
907903

904+
# the debug node IO functionality uses static variables, so it is best tested
905+
# in its own process
906+
if(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
907+
AddTest(
908+
TARGET onnxruntime_test_debug_node_inputs_outputs
909+
SOURCES
910+
"${TEST_SRC_DIR}/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc"
911+
"${TEST_SRC_DIR}/framework/TestAllocatorManager.cc"
912+
"${TEST_SRC_DIR}/providers/provider_test_utils.cc"
913+
${onnxruntime_unittest_main_src}
914+
LIBS ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs}
915+
DEPENDS ${all_dependencies}
916+
)
917+
918+
target_compile_definitions(onnxruntime_test_debug_node_inputs_outputs
919+
PRIVATE DEBUG_NODE_INPUTS_OUTPUTS)
920+
endif(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
921+
908922
#some ETW tools
909923
if(WIN32 AND onnxruntime_ENABLE_INSTRUMENT)
910924
add_executable(generate_perf_report_from_etl ${ONNXRUNTIME_ROOT}/tool/etw/main.cc

onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -176,41 +176,30 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() {
176176
static const NodeDumpOptions node_dump_options = []() {
177177
namespace env_vars = debug_node_inputs_outputs_env_vars;
178178

179-
auto get_bool_env_var = [](const char* env_var) {
180-
const auto val = Env::Default().GetEnvironmentVar(env_var);
181-
if (val.empty()) return false;
182-
std::istringstream s{val};
183-
int i;
184-
ORT_ENFORCE(
185-
s >> i && s.eof(),
186-
"Failed to parse environment variable ", env_var, ": ", val);
187-
return i != 0;
188-
};
189-
190179
NodeDumpOptions opts{};
191180

192181
// Preserve existing behavior of printing the shapes by default. Turn it off only if the user has requested so
193182
// explicitly by setting the value of the env variable to 0.
194183
opts.dump_flags = NodeDumpOptions::DumpFlags::None;
195-
if (ParseEnvironmentVariable<bool>(env_vars::kDumpShapeData, true)) {
184+
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpShapeData, true)) {
196185
opts.dump_flags |= NodeDumpOptions::DumpFlags::Shape;
197186
}
198187

199-
if (get_bool_env_var(env_vars::kDumpInputData)) {
188+
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpInputData, false)) {
200189
opts.dump_flags |= NodeDumpOptions::DumpFlags::InputData;
201190
}
202-
if (get_bool_env_var(env_vars::kDumpOutputData)) {
191+
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpOutputData, false)) {
203192
opts.dump_flags |= NodeDumpOptions::DumpFlags::OutputData;
204193
}
205194

206195
opts.filter.name_pattern = Env::Default().GetEnvironmentVar(env_vars::kNameFilter);
207196
opts.filter.op_type_pattern = Env::Default().GetEnvironmentVar(env_vars::kOpTypeFilter);
208197

209-
if (get_bool_env_var(env_vars::kDumpDataToFiles)) {
198+
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpDataToFiles, false)) {
210199
opts.data_destination = NodeDumpOptions::DataDestination::TensorProtoFiles;
211200
}
212201

213-
if (get_bool_env_var(env_vars::kAppendRankToFileName)) {
202+
if (ParseEnvironmentVariableWithDefault<bool>(env_vars::kAppendRankToFileName, false)) {
214203
std::string rank = Env::Default().GetEnvironmentVar("OMPI_COMM_WORLD_RANK");
215204
if (rank.empty()) {
216205
opts.file_suffix = "_default_rank_0";
@@ -229,7 +218,7 @@ const NodeDumpOptions& NodeDumpOptionsFromEnvironmentVariables() {
229218
opts.data_destination == NodeDumpOptions::DataDestination::TensorProtoFiles &&
230219
opts.filter.name_pattern.empty() && opts.filter.op_type_pattern.empty()) {
231220
ORT_ENFORCE(
232-
get_bool_env_var(env_vars::kDumpingDataToFilesForAllNodesIsOk),
221+
ParseEnvironmentVariableWithDefault<bool>(env_vars::kDumpingDataToFilesForAllNodesIsOk, false),
233222
"The current environment variable configuration will dump node input or output data to files for every node. "
234223
"This may cause a lot of files to be generated. Set the environment variable ",
235224
env_vars::kDumpingDataToFilesForAllNodesIsOk, " to confirm this is what you want.");

onnxruntime/core/platform/env_var_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ optional<T> ParseEnvironmentVariable(const std::string& name) {
3333
* Parses an environment variable value or returns the given default if unavailable.
3434
*/
3535
template <typename T>
36-
T ParseEnvironmentVariable(const std::string& name, const T& default_value) {
36+
T ParseEnvironmentVariableWithDefault(const std::string& name, const T& default_value) {
3737
const auto parsed = ParseEnvironmentVariable<T>(name);
3838
if (parsed.has_value()) {
3939
return parsed.value();

onnxruntime/test/common/test_main.cc

Lines changed: 0 additions & 24 deletions
This file was deleted.

onnxruntime/test/framework/debug_node_inputs_outputs_utils_test.cc renamed to onnxruntime/test/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
5-
64
#include "core/framework/debug_node_inputs_outputs_utils.h"
75

86
#include <fstream>
@@ -41,11 +39,14 @@ TEST(DebugNodeInputsOutputs, BasicFileOutput) {
4139
TemporaryDirectory temp_dir{ORT_TSTR("debug_node_inputs_outputs_utils_test")};
4240
ScopedEnvironmentVariables scoped_env_vars{
4341
EnvVarMap{
44-
{env_vars::kDumpInputData, {"1"}},
45-
{env_vars::kDumpOutputData, {"1"}},
46-
{env_vars::kDumpDataToFiles, {"1"}},
47-
{env_vars::kOutputDir, {ToMBString(temp_dir.Path())}},
48-
{env_vars::kDumpingDataToFilesForAllNodesIsOk, {"1"}},
42+
{env_vars::kDumpInputData, "1"},
43+
{env_vars::kDumpOutputData, "1"},
44+
{env_vars::kNameFilter, nullopt},
45+
{env_vars::kOpTypeFilter, nullopt},
46+
{env_vars::kDumpDataToFiles, "1"},
47+
{env_vars::kAppendRankToFileName, nullopt},
48+
{env_vars::kOutputDir, ToMBString(temp_dir.Path())},
49+
{env_vars::kDumpingDataToFilesForAllNodesIsOk, "1"},
4950
}};
5051

5152
OpTester tester{"Round", 11, kOnnxDomain};
@@ -56,8 +57,10 @@ TEST(DebugNodeInputsOutputs, BasicFileOutput) {
5657

5758
auto verify_file_data =
5859
[&temp_dir, &input, &output](
59-
const std::vector<OrtValue>& /*fetches*/,
60+
const std::vector<OrtValue>& fetches,
6061
const std::string& /*provider_type*/) {
62+
ASSERT_EQ(fetches.size(), 1u);
63+
FetchTensor(fetches[0]);
6164
VerifyTensorProtoFileData(
6265
temp_dir.Path() + ORT_TSTR("/x.tensorproto"),
6366
gsl::make_span(input));
@@ -73,5 +76,3 @@ TEST(DebugNodeInputsOutputs, BasicFileOutput) {
7376

7477
} // namespace test
7578
} // namespace onnxruntime
76-
77-
#endif

onnxruntime/test/framework/test_main.cc

Lines changed: 0 additions & 25 deletions
This file was deleted.

onnxruntime/test/ir/test_main.cc

Lines changed: 0 additions & 31 deletions
This file was deleted.

onnxruntime/test/shared_lib/test_main.cc

Lines changed: 0 additions & 13 deletions
This file was deleted.

onnxruntime/test/util/test_random_seed.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ RandomSeedType GetTestRandomSeed() {
2424
};
2525

2626
static const auto use_cached =
27-
!ParseEnvironmentVariable<bool>(test_random_seed_env_vars::kDoNotCache, false);
27+
!ParseEnvironmentVariableWithDefault<bool>(test_random_seed_env_vars::kDoNotCache, false);
2828
if (use_cached) {
2929
// initially generate from current time
3030
static const auto static_random_seed = generate_from_time();

0 commit comments

Comments
 (0)