Skip to content

Commit f2b3e84

Browse files
alekstheodtensorflower-gardener
authored andcommitted
PR tensorflow#24898: [ROCM] fix asan invalid memory access in redzone allocator kernel rocm cu
Imported from GitHub PR openxla/xla#24898 Fix issue reported by asan while running the tests on rocm ci: ``` ==1718600==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x5030001d97f8 at pc 0x5647cfdda211 bp 0x7ffc9eb7eac0 sp 0x7ffc9eb7eab8 READ of size 8 at 0x5030001d97f8 thread T0 #0 0x5647cfdda210 in absl::lts_20230802::container_internal::CommonFields::capacity() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:990:36 #1 0x5647cfdda210 in void absl::lts_20230802::container_internal::InitializeSlots<std::allocator<char>, 8ul, 8ul>(absl::lts_20230802::container_internal::CommonFields&, std::allocator<char>) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:1403:24 #2 0x7f066c2cfdde in absl::lts_20230802::container_internal::raw_hash_set<absl::lts_20230802::container_internal::NodeHashMapPolicy<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>, absl::lts_20230802::hash_internal::Hash<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::equal_to<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::allocator<std::pair<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*> const, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>>>::resize(unsigned long) (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x9dde) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #3 0x7f066c2cfd97 in absl::lts_20230802::container_internal::raw_hash_set<absl::lts_20230802::container_internal::NodeHashMapPolicy<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>, absl::lts_20230802::hash_internal::Hash<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::equal_to<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::allocator<std::pair<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*> const, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>>>::prepare_insert(unsigned long) (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x9d97) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #4 0x7f066c2cfcca in std::pair<unsigned long, bool> absl::lts_20230802::container_internal::raw_hash_set<absl::lts_20230802::container_internal::NodeHashMapPolicy<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>, absl::lts_20230802::hash_internal::Hash<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::equal_to<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::allocator<std::pair<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*> const, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>>>::find_or_prepare_insert<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>(std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*> const&) (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x9cca) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #5 0x7f066c2cf9c4 in std::pair<absl::lts_20230802::container_internal::raw_hash_set<absl::lts_20230802::container_internal::NodeHashMapPolicy<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>, absl::lts_20230802::hash_internal::Hash<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::equal_to<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::allocator<std::pair<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*> const, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>>>::iterator, bool> absl::lts_20230802::container_internal::raw_hash_set<absl::lts_20230802::container_internal::NodeHashMapPolicy<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>, absl::lts_20230802::hash_internal::Hash<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::equal_to<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>>, std::allocator<std::pair<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*> const, stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>>>>::EmplaceDecomposable::operator()<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>, std::piecewise_construct_t const&, std::tuple<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>&>, std::tuple<stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>&&>>(std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*> const&, std::piecewise_construct_t const&, std::tuple<std::tuple<stream_executor::StreamExecutor*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char>>, void*>&>&&, std::tuple<stream_executor::TypedKernel<stream_executor::DeviceMemory<unsigned char>, unsigned char, unsigned long, stream_executor::DeviceMemory<unsigned long>>&&>&&) const (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x99c4) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #6 0x7f066c2cf0ad in stream_executor::GetComparisonKernel(stream_executor::StreamExecutor*, stream_executor::GpuAsmOpts) (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/../../../_solib_local/libxla_Sstream_Uexecutor_Sgpu_Slibredzone_Uallocator_Ukernel_Urocm_Urocm.so+0x90ad) (BuildId: 3bd12bfb947fb25a2a780cc09bea1d9c) #7 0x7f066c37ba93 in stream_executor::RedzoneAllocator::CheckRedzones() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/stream_executor/gpu/redzone_allocator.cc:272:3 #8 0x7f06b31bb7e9 in absl::lts_20230802::StatusOr<xla::AutotuneResult> xla::gpu::(anonymous namespace)::GemmAutotuner::GetBestAlgorithm<long, xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&)::'lambda'(long const&)&>(xla::HloInstruction const*, absl::lts_20230802::Span<long const>, double, bool, xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&)::'lambda'(long const&)&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:328:7 #9 0x7f06b31bb7e9 in xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:256:12 #10 0x7f06b31bb7e9 in xla::gpu::(anonymous namespace)::GemmAutotuner::operator()(xla::HloInstruction const*, xla::gpu::AutotuneCacheKey const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:137:18 #11 0x7f06b31b6760 in xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0::operator()() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:418:3 #12 0x7f06b31b6760 in absl::lts_20230802::StatusOr<xla::AutotuneResult> std::__invoke_impl<absl::lts_20230802::StatusOr<xla::AutotuneResult>, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>(std::__invoke_other, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:61:14 #13 0x7f06b31b6760 in std::enable_if<is_invocable_r_v<absl::lts_20230802::StatusOr<xla::AutotuneResult>, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>, absl::lts_20230802::StatusOr<xla::AutotuneResult>>::type std::__invoke_r<absl::lts_20230802::StatusOr<xla::AutotuneResult>, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>(xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:114:9 #14 0x7f06b31b6760 in std::_Function_handler<absl::lts_20230802::StatusOr<xla::AutotuneResult> (), xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0>::_M_invoke(std::_Any_data const&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:290:9 #15 0x7f06b308670d in std::function<absl::lts_20230802::StatusOr<xla::AutotuneResult> ()>::operator()() const /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:590:9 #16 0x7f06b308670d in xla::gpu::AutotunerUtil::Autotune(xla::HloInstruction const*, xla::gpu::AutotuneConfig const&, std::function<absl::lts_20230802::StatusOr<xla::AutotuneResult> ()> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/autotuner_util.cc:460:3 #17 0x7f06b31b336e in xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:418:3 #18 0x7f06b31b336e in xla::gpu::(anonymous namespace)::RunOnComputation(xla::HloComputation*, xla::gpu::(anonymous namespace)::GemmAutotuner&, unsigned long*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:468:7 #19 0x7f06b31b336e in xla::gpu::GemmAlgorithmPicker::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char>>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:495:5 #20 0x7f06b30242f3 in xla::HloPassPipeline::RunHelper(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char>>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/hlo/pass/hlo_pass_pipeline.h:150:5 #21 0x7f06b3010bb9 in absl::lts_20230802::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char>>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/hlo/pass/hlo_pass_pipeline.cc:198:30 #22 0x7f06b300f786 in xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char>>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/hlo/pass/hlo_pass_pipeline.cc:338:10 #23 0x5647cfd66945 in xla::HloPassInterface::Run(xla::HloModule*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/hlo/pass/hlo_pass_interface.h:85:12 #24 0x7f06c2908be0 in xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1754:3 #25 0x7f06c2a000f3 in xla::gpu::AMDGPUCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/amdgpu_compiler.cc:197:3 #26 0x7f06c28f85e9 in xla::gpu::GpuCompiler::OptimizeHloModule(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1392:3 #27 0x7f06c291250d in xla::gpu::GpuCompiler::RunHloPasses(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule>>, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1824:3 #28 0x5647cfd63784 in xla::Compiler::RunHloPasses(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule>>, stream_executor::StreamExecutor*, stream_executor::DeviceMemoryAllocator*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/service/compiler.h:177:12 #29 0x7f06c339acba in xla::HloTestBase::GetOptimizedModule(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule>>) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/tests/hlo_test_base.cc:188:32 #30 0x5647cfd89516 in xla::gpu::(anonymous namespace)::GpuCompilerTest_CollectivePermuteDecompositionAndPipelining_Test::TestBody() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler_test.cc:879:3 #31 0x7f06c2c649dd in void testing::internal::HandleSehExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #32 0x7f06c2c649dd in void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #33 0x7f06c2c64708 in testing::Test::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2739:5 #34 0x7f06c2c6771b in testing::TestInfo::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2885:11 #35 0x7f06c2c6a5ab in testing::TestSuite::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:3063:30 #36 0x7f06c2c96eba in testing::internal::UnitTestImpl::RunAllTests() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:6054:44 #37 0x7f06c2c9579d in bool testing::internal::HandleSehExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #38 0x7f06c2c9579d in bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #39 0x7f06c2c95203 in testing::UnitTest::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:5594:10 #40 0x7f06c2d679b8 in RUN_ALL_TESTS() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/include/gtest/gtest.h:2334:73 #41 0x7f06c2d679b8 in main /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/tests/xla_internal_test_main.cc:65:10 #42 0x7f064c0b3d8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16 #43 0x7f064c0b3e3f in __libc_start_main csu/../csu/libc-start.c:392:3 #44 0x5647cfc7b044 in _start (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/gpu_compiler_test_gpu_amd_any+0xff044) (BuildId: ef1ac485eb61840d0e2233a2cca69eec) 0x5030001d97f8 is located 8 bytes before 32-byte region [0x5030001d9800,0x5030001d9820) allocated by thread T0 here: #0 0x5647cfd1527f in malloc (/root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/service/gpu/gpu_compiler_test_gpu_amd_any+0x19927f) (BuildId: ef1ac485eb61840d0e2233a2cca69eec) #1 0x7f064c39798b in operator new(unsigned long) (/lib/x86_64-linux-gnu/libstdc++.so.6+0xae98b) (BuildId: e37fe1a879783838de78cbc8c80621fa685d58a2) #2 0x7f06b31bb5b7 in google::protobuf::Duration* google::protobuf::MessageLite::CreateMaybeMessage<google::protobuf::Duration>(google::protobuf::Arena*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_protobuf/src/google/protobuf/message_lite.h:425:12 #3 0x7f06b31bb5b7 in xla::AutotuneResult::_internal_mutable_run_time() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/autotuning.pb.h:3079:15 #4 0x7f06b31bb5b7 in xla::AutotuneResult::mutable_run_time() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/bazel-out/k8-opt/bin/xla/autotuning.pb.h:3085:45 #5 0x7f06b31bb5b7 in absl::lts_20230802::StatusOr<xla::AutotuneResult> xla::gpu::(anonymous namespace)::GemmAutotuner::GetBestAlgorithm<long, xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&)::'lambda'(long const&)&>(xla::HloInstruction const*, absl::lts_20230802::Span<long const>, double, bool, xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&)::'lambda'(long const&)&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:321:15 #6 0x7f06b31bb5b7 in xla::gpu::(anonymous namespace)::GemmAutotuner::TuneGpuBlas(xla::HloInstruction const*, xla::gpu::GemmConfig const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:256:12 #7 0x7f06b31bb5b7 in xla::gpu::(anonymous namespace)::GemmAutotuner::operator()(xla::HloInstruction const*, xla::gpu::AutotuneCacheKey const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:137:18 #8 0x7f06b31b6760 in xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0::operator()() const /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:418:3 #9 0x7f06b31b6760 in absl::lts_20230802::StatusOr<xla::AutotuneResult> std::__invoke_impl<absl::lts_20230802::StatusOr<xla::AutotuneResult>, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>(std::__invoke_other, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:61:14 #10 0x7f06b31b6760 in std::enable_if<is_invocable_r_v<absl::lts_20230802::StatusOr<xla::AutotuneResult>, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>, absl::lts_20230802::StatusOr<xla::AutotuneResult>>::type std::__invoke_r<absl::lts_20230802::StatusOr<xla::AutotuneResult>, xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&>(xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:114:9 #11 0x7f06b31b6760 in std::_Function_handler<absl::lts_20230802::StatusOr<xla::AutotuneResult> (), xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&)::$_0>::_M_invoke(std::_Any_data const&) /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:290:9 #12 0x7f06b308670d in std::function<absl::lts_20230802::StatusOr<xla::AutotuneResult> ()>::operator()() const /usr/lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:590:9 #13 0x7f06b308670d in xla::gpu::AutotunerUtil::Autotune(xla::HloInstruction const*, xla::gpu::AutotuneConfig const&, std::function<absl::lts_20230802::StatusOr<xla::AutotuneResult> ()> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/autotuner_util.cc:460:3 #14 0x7f06b31b336e in xla::gpu::(anonymous namespace)::RunOnInstruction(xla::HloInstruction*, xla::gpu::(anonymous namespace)::GemmAutotuner&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:418:3 #15 0x7f06b31b336e in xla::gpu::(anonymous namespace)::RunOnComputation(xla::HloComputation*, xla::gpu::(anonymous namespace)::GemmAutotuner&, unsigned long*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:468:7 #16 0x7f06b31b336e in xla::gpu::GemmAlgorithmPicker::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char>>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc:495:5 #17 0x7f06b30242f3 in xla::HloPassPipeline::RunHelper(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char>>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/hlo/pass/hlo_pass_pipeline.h:150:5 #18 0x7f06b3010bb9 in absl::lts_20230802::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char>>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/hlo/pass/hlo_pass_pipeline.cc:198:30 #19 0x7f06b300f786 in xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char>>>> const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/hlo/pass/hlo_pass_pipeline.cc:338:10 #20 0x5647cfd66945 in xla::HloPassInterface::Run(xla::HloModule*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/hlo/pass/hlo_pass_interface.h:85:12 #21 0x7f06c2908be0 in xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1754:3 #22 0x7f06c2a000f3 in xla::gpu::AMDGPUCompiler::OptimizeHloPostLayoutAssignment(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&, tsl::thread::ThreadPool*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/amdgpu_compiler.cc:197:3 #23 0x7f06c28f85e9 in xla::gpu::GpuCompiler::OptimizeHloModule(xla::HloModule*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::Compiler::TargetConfig const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1392:3 #24 0x7f06c291250d in xla::gpu::GpuCompiler::RunHloPasses(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule>>, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler.cc:1824:3 #25 0x5647cfd63784 in xla::Compiler::RunHloPasses(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule>>, stream_executor::StreamExecutor*, stream_executor::DeviceMemoryAllocator*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/./xla/service/compiler.h:177:12 #26 0x7f06c339acba in xla::HloTestBase::GetOptimizedModule(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule>>) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/tests/hlo_test_base.cc:188:32 #27 0x5647cfd89516 in xla::gpu::(anonymous namespace)::GpuCompilerTest_CollectivePermuteDecompositionAndPipelining_Test::TestBody() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/service/gpu/gpu_compiler_test.cc:879:3 #28 0x7f06c2c649dd in void testing::internal::HandleSehExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #29 0x7f06c2c649dd in void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #30 0x7f06c2c64708 in testing::Test::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2739:5 #31 0x7f06c2c6771b in testing::TestInfo::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2885:11 #32 0x7f06c2c6a5ab in testing::TestSuite::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:3063:30 #33 0x7f06c2c96eba in testing::internal::UnitTestImpl::RunAllTests() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:6054:44 #34 0x7f06c2c9579d in bool testing::internal::HandleSehExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2664:10 #35 0x7f06c2c9579d in bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:2700:14 #36 0x7f06c2c95203 in testing::UnitTest::Run() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/src/gtest.cc:5594:10 #37 0x7f06c2d679b8 in RUN_ALL_TESTS() /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_googletest/googletest/include/gtest/gtest.h:2334:73 #38 0x7f06c2d679b8 in main /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/xla/tests/xla_internal_test_main.cc:65:10 #39 0x7f064c0b3d8f in __libc_start_call_main csu/../sysdeps/nptl/libc_start_call_main.h:58:16 SUMMARY: AddressSanitizer: heap-buffer-overflow /root/.cache/bazel/_bazel_root/f367074f9120c6f1a67d35844ac058a3/execroot/xla/external/com_google_absl/absl/container/internal/raw_hash_set.h:990:36 in absl::lts_20230802::container_internal::CommonFields::capacity() const Shadow bytes around the buggy address: 0x5030001d9500: fd fd fd fa fa fa fd fd fd fa fa fa fd fd fd fa 0x5030001d9580: fa fa fd fd fd fd fa fa fd fd fd fd fa fa fd fd 0x5030001d9600: fd fa fa fa fd fd fd fa fa fa fd fd fd fa fa fa 0x5030001d9680: fd fd fd fd fa fa fd fd fd fa fa fa fd fd fd fa 0x5030001d9700: fa fa fd fd fd fd fa fa fd fd fd fd fa fa fd fd =>0x5030001d9780: fd fa fa fa 00 00 00 fa fa fa 00 00 00 00 fa[fa] 0x5030001d9800: 00 00 00 00 fa fa 00 00 00 00 fa fa fd fd fd fd 0x5030001d9880: fa fa fd fd fd fd fa fa fd fd fd fa fa fa fd fd 0x5030001d9900: fd fd fa fa fd fd fd fd fa fa fd fd fd fd fa fa 0x5030001d9980: fd fd fd fa fa fa fd fd fd fa fa fa fd fd fd fa 0x5030001d9a00: fa fa fd fd fd fa fa fa fd fd fd fd fa fa fd fd Shadow byte legend (one shadow byte represents 8 application bytes): Addressable: 00 Partially addressable: 01 02 03 04 05 06 07 Heap left redzone: fa Freed heap region: fd Stack left redzone: f1 Stack mid redzone: f2 Stack right redzone: f3 Stack after return: f5 Stack use after scope: f8 Global redzone: f9 Global init order: f6 Poisoned by user: f7 Container overflow: fc Array cookie: ac Intra object redzone: bb ASan internal: fe Left alloca redzone: ca Right alloca redzone: cb ==1718600==ABORTING ``` Copybara import of the project: -- 9a75d26eb9aab4226a690658d254a057fc59f22c by alekstheod <atheodor@amd.com>: Fix access memory asan issue in redzone_allocator_kernel_rocm.cu Merging this change closes tensorflow#24898 PiperOrigin-RevId: 745563669
1 parent 5cba819 commit f2b3e84

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,11 @@ namespace stream_executor {
3939
template <typename... Args>
4040
static absl::StatusOr<TypedKernel<Args...>*> LoadKernelOrGetPtr(
4141
StreamExecutor* executor, absl::string_view kernel_name, void* kernel_ptr) {
42-
using KernelPtrCacheKey =
43-
std::tuple<StreamExecutor*, absl::string_view, void*>;
42+
using KernelPtrCacheKey = std::tuple<StreamExecutor*, std::string, void*>;
4443

4544
static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit);
4645
static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) =
47-
*new absl::node_hash_map<KernelPtrCacheKey, TypedKernel<Args...>>();
46+
*new std::map<KernelPtrCacheKey, TypedKernel<Args...>>;
4847
KernelPtrCacheKey kernel_ptr_cache_key{executor, kernel_name, kernel_ptr};
4948
absl::MutexLock lock(&kernel_ptr_cache_mutex);
5049

0 commit comments

Comments
 (0)