diff --git a/clang-tools-extra/clangd/DumpAST.cpp b/clang-tools-extra/clangd/DumpAST.cpp
index f9371ecf63741e..9ea17e876de7a4 100644
--- a/clang-tools-extra/clangd/DumpAST.cpp
+++ b/clang-tools-extra/clangd/DumpAST.cpp
@@ -145,6 +145,7 @@ class DumpVisitor : public RecursiveASTVisitor<DumpVisitor> {
       TEMPLATE_ARGUMENT_KIND(TemplateExpansion);
 #undef TEMPLATE_ARGUMENT_KIND
     }
+    llvm_unreachable("Unhandled ArgKind enum");
   }
   std::string getKind(const NestedNameSpecifierLoc &NNSL) {
     assert(NNSL.getNestedNameSpecifier());
@@ -161,6 +162,7 @@
       NNS_KIND(NamespaceAlias);
 #undef NNS_KIND
     }
+    llvm_unreachable("Unhandled SpecifierKind enum");
   }
   std::string getKind(const CXXCtorInitializer *CCI) {
     if (CCI->isBaseInitializer())
@@ -185,6 +187,7 @@
       TEMPLATE_KIND(SubstTemplateTemplateParmPack);
 #undef TEMPLATE_KIND
     }
+    llvm_unreachable("Unhandled NameKind enum");
   }
   std::string getKind(const Attr *A) {
     switch (A->getKind()) {
@@ -194,6 +197,7 @@
 #include "clang/Basic/AttrList.inc"
 #undef ATTR
     }
+    llvm_unreachable("Unhandled attr::Kind enum");
  }
  std::string getKind(const CXXBaseSpecifier &CBS) {
    // There aren't really any variants of CXXBaseSpecifier.
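All four hunks above apply the same LLVM idiom: after a switch that handles every enumerator, the compiler still thinks control can fall off the end, so an llvm_unreachable() marker follows the switch. A minimal standalone sketch of the pattern (hypothetical Color enum, not part of this patch):

#include "llvm/Support/ErrorHandling.h"

enum class Color { Red, Green };

static const char *getName(Color C) {
  switch (C) { // fully covered, deliberately no default:
  case Color::Red:
    return "Red";
  case Color::Green:
    return "Green";
  }
  // Aborts with this message in assertion-enabled builds and compiles to an
  // unreachable hint otherwise. It silences -Wreturn-type without adding a
  // default case, so a newly added enumerator still triggers -Wswitch.
  llvm_unreachable("Unhandled Color enum");
}
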
diff --git a/clang/test/CodeGen/riscv-atomics.c b/clang/test/CodeGen/RISCV/riscv-atomics.c
similarity index 100%
rename from clang/test/CodeGen/riscv-atomics.c
rename to clang/test/CodeGen/RISCV/riscv-atomics.c
diff --git a/clang/test/CodeGen/riscv-inline-asm.c b/clang/test/CodeGen/RISCV/riscv-inline-asm.c
similarity index 100%
rename from clang/test/CodeGen/riscv-inline-asm.c
rename to clang/test/CodeGen/RISCV/riscv-inline-asm.c
diff --git a/clang/test/CodeGen/riscv-metadata.c b/clang/test/CodeGen/RISCV/riscv-metadata.c
similarity index 100%
rename from clang/test/CodeGen/riscv-metadata.c
rename to clang/test/CodeGen/RISCV/riscv-metadata.c
diff --git a/clang/test/CodeGen/riscv-sdata-module-flag.c b/clang/test/CodeGen/RISCV/riscv-sdata-module-flag.c
similarity index 100%
rename from clang/test/CodeGen/riscv-sdata-module-flag.c
rename to clang/test/CodeGen/RISCV/riscv-sdata-module-flag.c
diff --git a/clang/test/CodeGen/riscv32-ilp32-abi.c b/clang/test/CodeGen/RISCV/riscv32-ilp32-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv32-ilp32-abi.c
rename to clang/test/CodeGen/RISCV/riscv32-ilp32-abi.c
diff --git a/clang/test/CodeGen/riscv32-ilp32-ilp32f-abi.c b/clang/test/CodeGen/RISCV/riscv32-ilp32-ilp32f-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv32-ilp32-ilp32f-abi.c
rename to clang/test/CodeGen/RISCV/riscv32-ilp32-ilp32f-abi.c
diff --git a/clang/test/CodeGen/riscv32-ilp32-ilp32f-ilp32d-abi.c b/clang/test/CodeGen/RISCV/riscv32-ilp32-ilp32f-ilp32d-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv32-ilp32-ilp32f-ilp32d-abi.c
rename to clang/test/CodeGen/RISCV/riscv32-ilp32-ilp32f-ilp32d-abi.c
diff --git a/clang/test/CodeGen/riscv32-ilp32d-abi.c b/clang/test/CodeGen/RISCV/riscv32-ilp32d-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv32-ilp32d-abi.c
rename to clang/test/CodeGen/RISCV/riscv32-ilp32d-abi.c
diff --git a/clang/test/CodeGen/riscv32-ilp32f-abi.c b/clang/test/CodeGen/RISCV/riscv32-ilp32f-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv32-ilp32f-abi.c
rename to clang/test/CodeGen/RISCV/riscv32-ilp32f-abi.c
diff --git a/clang/test/CodeGen/riscv32-ilp32f-ilp32d-abi.c b/clang/test/CodeGen/RISCV/riscv32-ilp32f-ilp32d-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv32-ilp32f-ilp32d-abi.c
rename to clang/test/CodeGen/RISCV/riscv32-ilp32f-ilp32d-abi.c
diff --git a/clang/test/CodeGen/riscv64-lp64-abi.c b/clang/test/CodeGen/RISCV/riscv64-lp64-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv64-lp64-abi.c
rename to clang/test/CodeGen/RISCV/riscv64-lp64-abi.c
diff --git a/clang/test/CodeGen/riscv64-lp64-lp64f-abi.c b/clang/test/CodeGen/RISCV/riscv64-lp64-lp64f-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv64-lp64-lp64f-abi.c
rename to clang/test/CodeGen/RISCV/riscv64-lp64-lp64f-abi.c
diff --git a/clang/test/CodeGen/riscv64-lp64-lp64f-lp64d-abi.c b/clang/test/CodeGen/RISCV/riscv64-lp64-lp64f-lp64d-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv64-lp64-lp64f-lp64d-abi.c
rename to clang/test/CodeGen/RISCV/riscv64-lp64-lp64f-lp64d-abi.c
diff --git a/clang/test/CodeGen/riscv64-lp64d-abi.c b/clang/test/CodeGen/RISCV/riscv64-lp64d-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv64-lp64d-abi.c
rename to clang/test/CodeGen/RISCV/riscv64-lp64d-abi.c
diff --git a/clang/test/CodeGen/riscv64-lp64f-lp64d-abi.c b/clang/test/CodeGen/RISCV/riscv64-lp64f-lp64d-abi.c
similarity index 100%
rename from clang/test/CodeGen/riscv64-lp64f-lp64d-abi.c
rename to clang/test/CodeGen/RISCV/riscv64-lp64f-lp64d-abi.c
diff --git a/clang/test/CodeGen/wasm-arguments.c b/clang/test/CodeGen/WebAssembly/wasm-arguments.c
similarity index 100%
rename from clang/test/CodeGen/wasm-arguments.c
rename to clang/test/CodeGen/WebAssembly/wasm-arguments.c
diff --git a/clang/test/CodeGen/wasm-call-main.c b/clang/test/CodeGen/WebAssembly/wasm-call-main.c
similarity index 100%
rename from clang/test/CodeGen/wasm-call-main.c
rename to clang/test/CodeGen/WebAssembly/wasm-call-main.c
diff --git a/clang/test/CodeGen/wasm-export-name.c b/clang/test/CodeGen/WebAssembly/wasm-export-name.c
similarity index 100%
rename from clang/test/CodeGen/wasm-export-name.c
rename to clang/test/CodeGen/WebAssembly/wasm-export-name.c
diff --git a/clang/test/CodeGen/wasm-import-module.c b/clang/test/CodeGen/WebAssembly/wasm-import-module.c
similarity index 100%
rename from clang/test/CodeGen/wasm-import-module.c
rename to clang/test/CodeGen/WebAssembly/wasm-import-module.c
diff --git a/clang/test/CodeGen/wasm-import-name.c b/clang/test/CodeGen/WebAssembly/wasm-import-name.c
similarity index 100%
rename from clang/test/CodeGen/wasm-import-name.c
rename to clang/test/CodeGen/WebAssembly/wasm-import-name.c
diff --git a/clang/test/CodeGen/wasm-main.c b/clang/test/CodeGen/WebAssembly/wasm-main.c
similarity index 100%
rename from clang/test/CodeGen/wasm-main.c
rename to clang/test/CodeGen/WebAssembly/wasm-main.c
diff --git a/clang/test/CodeGen/wasm-main_argc_argv.c b/clang/test/CodeGen/WebAssembly/wasm-main_argc_argv.c
similarity index 100%
rename from clang/test/CodeGen/wasm-main_argc_argv.c
rename to clang/test/CodeGen/WebAssembly/wasm-main_argc_argv.c
diff --git a/clang/test/CodeGen/wasm-regparm.c b/clang/test/CodeGen/WebAssembly/wasm-regparm.c
similarity index 100%
rename from clang/test/CodeGen/wasm-regparm.c
rename to clang/test/CodeGen/WebAssembly/wasm-regparm.c
diff --git a/clang/test/CodeGen/wasm-varargs.c b/clang/test/CodeGen/WebAssembly/wasm-varargs.c
similarity index 100%
rename from clang/test/CodeGen/wasm-varargs.c
rename to clang/test/CodeGen/WebAssembly/wasm-varargs.c
diff --git a/compiler-rt/lib/sanitizer_common/tests/sanitizer_allocator_test.cpp b/compiler-rt/lib/sanitizer_common/tests/sanitizer_allocator_test.cpp
index baf9b37fb95560..26593c0c2f4960 100644
--- a/compiler-rt/lib/sanitizer_common/tests/sanitizer_allocator_test.cpp
+++ b/compiler-rt/lib/sanitizer_common/tests/sanitizer_allocator_test.cpp
@@ -28,6 +28,14 @@
 using namespace __sanitizer;
 
+#if SANITIZER_SOLARIS && defined(__sparcv9)
+// FIXME: These tests probably fail because Solaris/sparcv9 uses the full
+// 64-bit address space. Needs more investigation.
+#define SKIP_ON_SOLARIS_SPARCV9(x) DISABLED_##x
+#else
+#define SKIP_ON_SOLARIS_SPARCV9(x) x
+#endif
+
 // Too slow for debug build
 #if !SANITIZER_DEBUG
 
@@ -701,7 +709,7 @@ TEST(SanitizerCommon, CombinedAllocator64VeryCompact) {
 }
 #endif
 
-TEST(SanitizerCommon, CombinedAllocator32Compact) {
+TEST(SanitizerCommon, SKIP_ON_SOLARIS_SPARCV9(CombinedAllocator32Compact)) {
   TestCombinedAllocator<Allocator32Compact>();
 }
 
@@ -937,7 +945,7 @@ TEST(SanitizerCommon, SizeClassAllocator64DynamicIteration) {
 #endif
 #endif
 
-TEST(SanitizerCommon, SizeClassAllocator32Iteration) {
+TEST(SanitizerCommon, SKIP_ON_SOLARIS_SPARCV9(SizeClassAllocator32Iteration)) {
   TestSizeClassAllocatorIteration<Allocator32Compact>();
 }
diff --git a/compiler-rt/lib/sanitizer_common/tests/sanitizer_stacktrace_test.cpp b/compiler-rt/lib/sanitizer_common/tests/sanitizer_stacktrace_test.cpp
index afd4a0eca622c8..9a47b4e113846c 100644
--- a/compiler-rt/lib/sanitizer_common/tests/sanitizer_stacktrace_test.cpp
+++ b/compiler-rt/lib/sanitizer_common/tests/sanitizer_stacktrace_test.cpp
@@ -70,11 +70,18 @@ void FastUnwindTest::TearDown() {
 
 #if SANITIZER_CAN_FAST_UNWIND
 
+#ifdef __sparc__
+// Fake stacks don't meet SPARC UnwindFast requirements.
+#define SKIP_ON_SPARC(x) DISABLED_##x
+#else
+#define SKIP_ON_SPARC(x) x
+#endif
+
 void FastUnwindTest::UnwindFast() {
   trace.UnwindFast(start_pc, fake_bp, fake_top, fake_bottom, kStackTraceMax);
 }
 
-TEST_F(FastUnwindTest, Basic) {
+TEST_F(FastUnwindTest, SKIP_ON_SPARC(Basic)) {
   UnwindFast();
   // Should get all on-stack retaddrs and start_pc.
   EXPECT_EQ(6U, trace.size);
@@ -85,7 +92,7 @@
 }
 
 // From: https://github.com/google/sanitizers/issues/162
-TEST_F(FastUnwindTest, FramePointerLoop) {
+TEST_F(FastUnwindTest, SKIP_ON_SPARC(FramePointerLoop)) {
   // Make one fp point to itself.
   fake_stack[4] = (uhwptr)&fake_stack[4];
   UnwindFast();
@@ -97,7 +104,7 @@
   }
 }
 
-TEST_F(FastUnwindTest, MisalignedFramePointer) {
+TEST_F(FastUnwindTest, SKIP_ON_SPARC(MisalignedFramePointer)) {
   // Make one fp misaligned.
   fake_stack[4] += 3;
   UnwindFast();
@@ -122,7 +129,7 @@ TEST_F(FastUnwindTest, ZeroFramesStackTrace) {
   EXPECT_EQ(0U, trace.top_frame_bp);
 }
 
-TEST_F(FastUnwindTest, FPBelowPrevFP) {
+TEST_F(FastUnwindTest, SKIP_ON_SPARC(FPBelowPrevFP)) {
   // The next FP points to unreadable memory inside the stack limits, but below
   // current FP.
   fake_stack[0] = (uhwptr)&fake_stack[-50];
@@ -133,7 +140,7 @@
   EXPECT_EQ(PC(1), trace.trace[1]);
 }
 
-TEST_F(FastUnwindTest, CloseToZeroFrame) {
+TEST_F(FastUnwindTest, SKIP_ON_SPARC(CloseToZeroFrame)) {
   // Make one pc a NULL pointer.
   fake_stack[5] = 0x0;
   UnwindFast();
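Both SKIP_ON_* wrappers rely on plain googletest behaviour rather than anything sanitizer-specific. A minimal sketch of the token-pasting trick (standalone, hypothetical test name):

#include "gtest/gtest.h"

#if defined(__sparc__)
#define SKIP_ON_SPARC(x) DISABLED_##x // gtest skips tests named DISABLED_*
#else
#define SKIP_ON_SPARC(x) x
#endif

// On SPARC this expands to TEST(Demo, DISABLED_Example): the test is still
// compiled and registered, but only runs with --gtest_also_run_disabled_tests.
TEST(Demo, SKIP_ON_SPARC(Example)) { EXPECT_TRUE(true); }
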
diff --git a/lldb/docs/lldb-gdb-remote.txt b/lldb/docs/lldb-gdb-remote.txt
index 91f6a4d12c2e55..7d333c8c8e6e87 100644
--- a/lldb/docs/lldb-gdb-remote.txt
+++ b/lldb/docs/lldb-gdb-remote.txt
@@ -1123,6 +1123,11 @@ tuples to return are:
                        // the file while for anonymous regions it have to be
                        // the name associated to the region if that is available.
 
+     flags:<flags-string>; // where <flags-string> is a space separated string
+                           // of flag names. Currently the only supported flag
+                           // is "mt" for AArch64 memory tagging. lldb will
+                           // ignore any other flags in this field.
+
      error:<error-string>; // where <error-string> is
                            // a hex encoded string value that
                            // contains an error string
diff --git a/lldb/docs/use/qemu-testing.rst b/lldb/docs/use/qemu-testing.rst
index a82dfb23a16a48..a523137c8710be 100644
--- a/lldb/docs/use/qemu-testing.rst
+++ b/lldb/docs/use/qemu-testing.rst
@@ -93,6 +93,9 @@ run-qemu.sh has following dependencies:
 
 * --sve option will enable AArch64 SVE mode.
 
+* --mte option will enable AArch64 MTE (memory tagging) mode.
+  (can be used on its own or in addition to --sve)
+
 **Example:** Run QEMU Arm or AArch64 system emulation using run-qemu.sh
 ::
diff --git a/lldb/include/lldb/Target/MemoryRegionInfo.h b/lldb/include/lldb/Target/MemoryRegionInfo.h
index a22da8d72b83e7..19c6c17ef90129 100644
--- a/lldb/include/lldb/Target/MemoryRegionInfo.h
+++ b/lldb/include/lldb/Target/MemoryRegionInfo.h
@@ -24,16 +24,17 @@ class MemoryRegionInfo {
   MemoryRegionInfo() = default;
   MemoryRegionInfo(RangeType range, OptionalBool read, OptionalBool write,
                    OptionalBool execute, OptionalBool mapped, ConstString name,
-                   OptionalBool flash, lldb::offset_t blocksize)
+                   OptionalBool flash, lldb::offset_t blocksize,
+                   OptionalBool memory_tagged)
       : m_range(range), m_read(read), m_write(write), m_execute(execute),
-        m_mapped(mapped), m_name(name), m_flash(flash), m_blocksize(blocksize) {
-  }
+        m_mapped(mapped), m_name(name), m_flash(flash), m_blocksize(blocksize),
+        m_memory_tagged(memory_tagged) {}
 
   RangeType &GetRange() { return m_range; }
 
   void Clear() {
     m_range.Clear();
-    m_read = m_write = m_execute = eDontKnow;
+    m_read = m_write = m_execute = m_memory_tagged = eDontKnow;
   }
 
   const RangeType &GetRange() const { return m_range; }
@@ -48,6 +49,8 @@ class MemoryRegionInfo {
 
   ConstString GetName() const { return m_name; }
 
+  OptionalBool GetMemoryTagged() const { return m_memory_tagged; }
+
   void SetReadable(OptionalBool val) { m_read = val; }
 
   void SetWritable(OptionalBool val) { m_write = val; }
@@ -66,6 +69,8 @@ class MemoryRegionInfo {
 
   void SetBlocksize(lldb::offset_t blocksize) { m_blocksize = blocksize; }
 
+  void SetMemoryTagged(OptionalBool val) { m_memory_tagged = val; }
+
   // Get permissions as a uint32_t that is a mask of one or more bits from the
   // lldb::Permissions
   uint32_t GetLLDBPermissions() const {
@@ -91,7 +96,8 @@ class MemoryRegionInfo {
     return m_range == rhs.m_range && m_read == rhs.m_read &&
            m_write == rhs.m_write && m_execute == rhs.m_execute &&
            m_mapped == rhs.m_mapped && m_name == rhs.m_name &&
-           m_flash == rhs.m_flash && m_blocksize == rhs.m_blocksize;
+           m_flash == rhs.m_flash && m_blocksize == rhs.m_blocksize &&
+           m_memory_tagged == rhs.m_memory_tagged;
   }
 
   bool operator!=(const MemoryRegionInfo &rhs) const { return !(*this == rhs); }
@@ -105,6 +111,7 @@ class MemoryRegionInfo {
   ConstString m_name;
   OptionalBool m_flash = eDontKnow;
   lldb::offset_t m_blocksize = 0;
+  OptionalBool m_memory_tagged = eDontKnow;
 };
 
 inline bool operator<(const MemoryRegionInfo &lhs,
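A quick sketch of how the new field behaves through the API declared in the hunk above (standalone toy, not from the patch; uses only members shown there):

#include "lldb/Target/MemoryRegionInfo.h"
#include <cassert>

using lldb_private::MemoryRegionInfo;

void MemoryTaggedExample() {
  MemoryRegionInfo info;
  // Default-constructed regions report eDontKnow, so callers that never set
  // the field (e.g. non-Linux backends) keep their old behaviour.
  assert(info.GetMemoryTagged() == MemoryRegionInfo::eDontKnow);

  info.SetMemoryTagged(MemoryRegionInfo::eYes);
  assert(info.GetMemoryTagged() == MemoryRegionInfo::eYes);

  info.Clear(); // also resets m_memory_tagged, per the Clear() change above
  assert(info.GetMemoryTagged() == MemoryRegionInfo::eDontKnow);
}
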
diff --git a/lldb/packages/Python/lldbsuite/test/lldbtest.py b/lldb/packages/Python/lldbsuite/test/lldbtest.py
index a02c445a937a6f..7ba3a154db4a2d 100644
--- a/lldb/packages/Python/lldbsuite/test/lldbtest.py
+++ b/lldb/packages/Python/lldbsuite/test/lldbtest.py
@@ -1318,6 +1318,30 @@ def isAArch64SVE(self):
         return " sve " in cpuinfo
 
+    def hasLinuxVmFlags(self):
+        """ Check that the target machine has "VmFlags" lines in
+        its /proc/{pid}/smaps files."""
+
+        triple = self.dbg.GetSelectedPlatform().GetTriple()
+        if not re.match(".*-.*-linux", triple):
+            return False
+
+        self.runCmd('platform process list')
+        pid = None
+        for line in self.res.GetOutput().splitlines():
+            if 'lldb-server' in line:
+                pid = line.split(' ')[0]
+                break
+
+        if pid is None:
+            return False
+
+        smaps_path = self.getBuildArtifact('smaps')
+        self.runCmd('platform get-file "/proc/{}/smaps" {}'.format(pid, smaps_path))
+
+        with open(smaps_path, 'r') as f:
+            return "VmFlags" in f.read()
+
     def getArchitecture(self):
         """Returns the architecture in effect the test suite is running with."""
         module = builder_module()
diff --git a/lldb/packages/Python/lldbsuite/test/tools/lldb-server/gdbremote_testcase.py b/lldb/packages/Python/lldbsuite/test/tools/lldb-server/gdbremote_testcase.py
index a0e3cb36294428..ce700d9403f48b 100644
--- a/lldb/packages/Python/lldbsuite/test/tools/lldb-server/gdbremote_testcase.py
+++ b/lldb/packages/Python/lldbsuite/test/tools/lldb-server/gdbremote_testcase.py
@@ -727,13 +727,13 @@ def parse_memory_region_packet(self, context):
 
         # Validate keys are known.
         for (key, val) in list(mem_region_dict.items()):
-            self.assertTrue(
-                key in [
-                    "start",
-                    "size",
-                    "permissions",
-                    "name",
-                    "error"])
+            self.assertIn(key,
+                          ["start",
+                           "size",
+                           "permissions",
+                           "flags",
+                           "name",
+                           "error"])
             self.assertIsNotNone(val)
 
         mem_region_dict["name"] = seven.unhexlify(mem_region_dict.get("name", ""))
diff --git a/lldb/scripts/lldb-test-qemu/run-qemu.sh b/lldb/scripts/lldb-test-qemu/run-qemu.sh
index cb28b7aaf6420e..339b8d955e6134 100644
--- a/lldb/scripts/lldb-test-qemu/run-qemu.sh
+++ b/lldb/scripts/lldb-test-qemu/run-qemu.sh
@@ -5,7 +5,8 @@ print_usage() {
   echo -e "Starts QEMU system mode emulation for the architecture.\n"
   echo -e "  --help\t\t\tDisplay this information."
   echo -e "  --arch {arm|arm64}\t\tSelects architecture QEMU system emulation."
-  echo -e "  --sve {path}\t\t\tEnables AArch64 SVE mode.\n"
+  echo -e "  --sve\t\t\t\tEnables AArch64 SVE mode."
+  echo -e "  --mte\t\t\t\tEnables AArch64 MTE mode.\n"
   echo -e "  --rootfs {path}\t\tPath of root file system image."
   echo -e "  --qemu {path}\t\t\tPath of pre-installed qemu-system-* executable."
   echo -e "  --kernel {path}\t\tPath of Linux kernel prebuilt image.\n"
@@ -48,6 +49,7 @@ while [[ $# -gt 0 ]]; do
     --kernel) KERNEL_IMG=$2; shift;;
     --qemu) QEMU_BIN=$2; shift;;
     --sve) SVE=1;;
+    --mte) MTE=1;;
     --help) print_usage 0 ;;
     *) invalid_arg "$1" ;;
   esac
@@ -99,6 +101,9 @@ if [[ "$ARCH" == "arm" ]]; then
   if [[ $SVE ]]; then
     echo "warning: --sve is supported by AArch64 targets only"
   fi
+  if [[ $MTE ]]; then
+    echo "warning: --mte is supported by AArch64 targets only"
+  fi
 elif [[ "$ARCH" == "arm64" ]]; then
   QEMU_MACHINE=virt
   QEMU_SVE_MAX_VQ=4
@@ -107,6 +112,9 @@ elif [[ "$ARCH" == "arm64" ]]; then
   if [[ $SVE ]]; then
     QEMU_CPU="max,sve-max-vq=$QEMU_SVE_MAX_VQ"
   fi
+  if [[ $MTE ]]; then
+    QEMU_MACHINE="$QEMU_MACHINE,mte=on"
+  fi
 fi
 
 run_qemu
diff --git a/lldb/source/Commands/CommandObjectMemory.cpp b/lldb/source/Commands/CommandObjectMemory.cpp
index 20a1fbb0f1b243..7d5c642d0131c0 100644
--- a/lldb/source/Commands/CommandObjectMemory.cpp
+++ b/lldb/source/Commands/CommandObjectMemory.cpp
@@ -1709,12 +1709,18 @@ class CommandObjectMemoryRegion : public CommandObjectParsed {
             section_name = section_sp->GetName();
           }
         }
+
         result.AppendMessageWithFormatv(
-            "[{0:x16}-{1:x16}) {2:r}{3:w}{4:x}{5}{6}{7}{8}\n",
+            "[{0:x16}-{1:x16}) {2:r}{3:w}{4:x}{5}{6}{7}{8}",
             range_info.GetRange().GetRangeBase(),
             range_info.GetRange().GetRangeEnd(), range_info.GetReadable(),
             range_info.GetWritable(), range_info.GetExecutable(),
             name ? " " : "", name, section_name ? " " : "", section_name);
+        MemoryRegionInfo::OptionalBool memory_tagged =
+            range_info.GetMemoryTagged();
+        if (memory_tagged == MemoryRegionInfo::OptionalBool::eYes)
+          result.AppendMessage("memory tagging: enabled");
+
         m_prev_end_addr = range_info.GetRange().GetRangeEnd();
         result.SetStatus(eReturnStatusSuccessFinishResult);
         return true;
diff --git a/lldb/source/Plugins/ObjectFile/Mach-O/ObjectFileMachO.cpp b/lldb/source/Plugins/ObjectFile/Mach-O/ObjectFileMachO.cpp
index 2653290ea8c7ff..aafd5ab746b3c8 100644
--- a/lldb/source/Plugins/ObjectFile/Mach-O/ObjectFileMachO.cpp
+++ b/lldb/source/Plugins/ObjectFile/Mach-O/ObjectFileMachO.cpp
@@ -3467,10 +3467,11 @@ size_t ObjectFileMachO::ParseSymtab() {
                 sym[sym_idx].GetMangled().SetValue(
                     const_symbol_name, symbol_name_is_mangled);
                 if (is_gsym && is_debug) {
-                  const char *gsym_name = sym[sym_idx]
-                                              .GetMangled()
-                                              .GetName()
-                                              .GetCString();
+                  const char *gsym_name =
+                      sym[sym_idx]
+                          .GetMangled()
+                          .GetName(Mangled::ePreferMangled)
+                          .GetCString();
                   if (gsym_name)
                     N_GSYM_name_to_sym_idx[gsym_name] = sym_idx;
                 }
@@ -3550,8 +3551,10 @@
                       bool found_it = false;
                       for (auto pos = range.first; pos != range.second; ++pos) {
-                        if (sym[sym_idx].GetMangled().GetName() ==
-                            sym[pos->second].GetMangled().GetName()) {
+                        if (sym[sym_idx].GetMangled().GetName(
+                                Mangled::ePreferMangled) ==
+                            sym[pos->second].GetMangled().GetName(
+                                Mangled::ePreferMangled)) {
                           m_nlist_idx_to_sym_idx[nlist_idx] = pos->second;
                           // We just need the flags from the linker
                           // symbol, so put these flags
@@ -3591,8 +3594,10 @@
                       bool found_it = false;
                       for (auto pos = range.first; pos != range.second; ++pos) {
-                        if (sym[sym_idx].GetMangled().GetName() ==
-                            sym[pos->second].GetMangled().GetName()) {
+                        if (sym[sym_idx].GetMangled().GetName(
+                                Mangled::ePreferMangled) ==
+                            sym[pos->second].GetMangled().GetName(
+                                Mangled::ePreferMangled)) {
                           m_nlist_idx_to_sym_idx[nlist_idx] = pos->second;
                           // We just need the flags from the linker
                           // symbol, so put these flags
@@ -3610,10 +3615,11 @@ size_t ObjectFileMachO::ParseSymtab() {
                       if (found_it)
                         continue;
                     } else {
-                      const char *gsym_name = sym[sym_idx]
-                                                  .GetMangled()
-                                                  .GetName()
-                                                  .GetCString();
+                      const char *gsym_name =
+                          sym[sym_idx]
+                              .GetMangled()
+                              .GetName(Mangled::ePreferMangled)
+                              .GetCString();
                       if (gsym_name) {
                         // Combine N_GSYM stab entries with the non
                         // stab symbol
@@ -4334,8 +4340,10 @@ size_t ObjectFileMachO::ParseSymtab() {
         }
 
         if (is_gsym) {
-          const char *gsym_name =
-              sym[sym_idx].GetMangled().GetName().GetCString();
+          const char *gsym_name = sym[sym_idx]
+                                      .GetMangled()
+                                      .GetName(Mangled::ePreferMangled)
+                                      .GetCString();
           if (gsym_name)
             N_GSYM_name_to_sym_idx[gsym_name] = sym_idx;
         }
@@ -4399,8 +4407,9 @@ size_t ObjectFileMachO::ParseSymtab() {
         if (range.first != range.second) {
           for (ValueToSymbolIndexMap::const_iterator pos = range.first;
                pos != range.second; ++pos) {
-            if (sym[sym_idx].GetMangled().GetName() ==
-                sym[pos->second].GetMangled().GetName()) {
+            if (sym[sym_idx].GetMangled().GetName(Mangled::ePreferMangled) ==
+                sym[pos->second].GetMangled().GetName(
+                    Mangled::ePreferMangled)) {
               m_nlist_idx_to_sym_idx[nlist_idx] = pos->second;
               // We just need the flags from the linker symbol, so put these
               // flags into the N_FUN flags to avoid duplicate symbols in the
@@ -4433,8 +4442,9 @@ size_t ObjectFileMachO::ParseSymtab() {
         if (range.first != range.second) {
           for (ValueToSymbolIndexMap::const_iterator pos = range.first;
                pos != range.second; ++pos) {
-            if (sym[sym_idx].GetMangled().GetName() ==
-                sym[pos->second].GetMangled().GetName()) {
+            if (sym[sym_idx].GetMangled().GetName(Mangled::ePreferMangled) ==
+                sym[pos->second].GetMangled().GetName(
+                    Mangled::ePreferMangled)) {
               m_nlist_idx_to_sym_idx[nlist_idx] = pos->second;
               // We just need the flags from the linker symbol, so put these
               // flags into the N_STSYM flags to avoid duplicate symbols in
@@ -4447,8 +4457,10 @@ size_t ObjectFileMachO::ParseSymtab() {
           }
         } else {
           // Combine N_GSYM stab entries with the non stab symbol.
-          const char *gsym_name =
-              sym[sym_idx].GetMangled().GetName().GetCString();
+          const char *gsym_name = sym[sym_idx]
+                                      .GetMangled()
+                                      .GetName(Mangled::ePreferMangled)
+                                      .GetCString();
           if (gsym_name) {
             ConstNameToSymbolIndexMap::const_iterator pos =
                 N_GSYM_name_to_sym_idx.find(gsym_name);
diff --git a/lldb/source/Plugins/Process/Linux/NativeProcessLinux.cpp b/lldb/source/Plugins/Process/Linux/NativeProcessLinux.cpp
index 9883e1cfc5327f..e07d763c2de773 100644
--- a/lldb/source/Plugins/Process/Linux/NativeProcessLinux.cpp
+++ b/lldb/source/Plugins/Process/Linux/NativeProcessLinux.cpp
@@ -1297,26 +1297,36 @@ Status NativeProcessLinux::PopulateMemoryRegionCache() {
     return Status();
   }
 
-  auto BufferOrError = getProcFile(GetID(), "maps");
-  if (!BufferOrError) {
+  Status Result;
+  LinuxMapCallback callback = [&](llvm::Expected<MemoryRegionInfo> Info) {
+    if (Info) {
+      FileSpec file_spec(Info->GetName().GetCString());
+      FileSystem::Instance().Resolve(file_spec);
+      m_mem_region_cache.emplace_back(*Info, file_spec);
+      return true;
+    }
+
+    Result = Info.takeError();
     m_supports_mem_region = LazyBool::eLazyBoolNo;
-    return BufferOrError.getError();
+    LLDB_LOG(log, "failed to parse proc maps: {0}", Result);
+    return false;
+  };
+
+  // Linux kernel since 2.6.14 has /proc/{pid}/smaps
+  // if CONFIG_PROC_PAGE_MONITOR is enabled
+  auto BufferOrError = getProcFile(GetID(), "smaps");
+  if (BufferOrError)
+    ParseLinuxSMapRegions(BufferOrError.get()->getBuffer(), callback);
+  else {
+    BufferOrError = getProcFile(GetID(), "maps");
+    if (!BufferOrError) {
+      m_supports_mem_region = LazyBool::eLazyBoolNo;
+      return BufferOrError.getError();
+    }
+
+    ParseLinuxMapRegions(BufferOrError.get()->getBuffer(), callback);
   }
-  Status Result;
-  ParseLinuxMapRegions(BufferOrError.get()->getBuffer(),
-                       [&](const MemoryRegionInfo &Info, const Status &ST) {
-                         if (ST.Success()) {
-                           FileSpec file_spec(Info.GetName().GetCString());
-                           FileSystem::Instance().Resolve(file_spec);
-                           m_mem_region_cache.emplace_back(Info, file_spec);
-                           return true;
-                         } else {
-                           m_supports_mem_region = LazyBool::eLazyBoolNo;
-                           LLDB_LOG(log, "failed to parse proc maps: {0}", ST);
-                           Result = ST;
-                           return false;
-                         }
-                       });
+
   if (Result.Fail())
     return Result;
diff --git a/lldb/source/Plugins/Process/Utility/LinuxProcMaps.cpp b/lldb/source/Plugins/Process/Utility/LinuxProcMaps.cpp
index 0c7d9ddc5ac6bd..947b970edf6cc6 100644
--- a/lldb/source/Plugins/Process/Utility/LinuxProcMaps.cpp
+++ b/lldb/source/Plugins/Process/Utility/LinuxProcMaps.cpp
@@ -7,80 +7,93 @@
 //===----------------------------------------------------------------------===//
 
 #include "LinuxProcMaps.h"
-#include "llvm/ADT/StringRef.h"
 #include "lldb/Target/MemoryRegionInfo.h"
 #include "lldb/Utility/Status.h"
 #include "lldb/Utility/StringExtractor.h"
+#include "llvm/ADT/StringRef.h"
 
 using namespace lldb_private;
 
-static Status
+enum class MapsKind { Maps, SMaps };
+
+static llvm::Expected<MemoryRegionInfo> ProcMapError(const char *msg,
+                                                     MapsKind kind) {
+  return llvm::createStringError(llvm::inconvertibleErrorCode(), msg,
+                                 kind == MapsKind::Maps ? "maps" : "smaps");
+}
+
+static llvm::Expected<MemoryRegionInfo>
 ParseMemoryRegionInfoFromProcMapsLine(llvm::StringRef maps_line,
-                                      MemoryRegionInfo &memory_region_info) {
-  memory_region_info.Clear();
-
+                                      MapsKind maps_kind) {
+  MemoryRegionInfo region;
   StringExtractor line_extractor(maps_line);
-
+
   // Format: {address_start_hex}-{address_end_hex} perms offset  dev   inode
   // pathname perms: rwxp   (letter is present if set, '-' if not, final
   // character is p=private, s=shared).
-
+
   // Parse out the starting address
   lldb::addr_t start_address = line_extractor.GetHexMaxU64(false, 0);
-
+
   // Parse out hyphen separating start and end address from range.
   if (!line_extractor.GetBytesLeft() || (line_extractor.GetChar() != '-'))
-    return Status(
-        "malformed /proc/{pid}/maps entry, missing dash between address range");
-
+    return ProcMapError(
+        "malformed /proc/{pid}/%s entry, missing dash between address range",
+        maps_kind);
+
   // Parse out the ending address
   lldb::addr_t end_address = line_extractor.GetHexMaxU64(false, start_address);
-
+
   // Parse out the space after the address.
   if (!line_extractor.GetBytesLeft() || (line_extractor.GetChar() != ' '))
-    return Status(
-        "malformed /proc/{pid}/maps entry, missing space after range");
-
+    return ProcMapError(
+        "malformed /proc/{pid}/%s entry, missing space after range", maps_kind);
+
   // Save the range.
-  memory_region_info.GetRange().SetRangeBase(start_address);
-  memory_region_info.GetRange().SetRangeEnd(end_address);
-
-  // Any memory region in /proc/{pid}/maps is by definition mapped into the
-  // process.
-  memory_region_info.SetMapped(MemoryRegionInfo::OptionalBool::eYes);
-
+  region.GetRange().SetRangeBase(start_address);
+  region.GetRange().SetRangeEnd(end_address);
+
+  // Any memory region in /proc/{pid}/(maps|smaps) is by definition mapped
+  // into the process.
+  region.SetMapped(MemoryRegionInfo::OptionalBool::eYes);
+
   // Parse out each permission entry.
   if (line_extractor.GetBytesLeft() < 4)
-    return Status("malformed /proc/{pid}/maps entry, missing some portion of "
-                  "permissions");
-
+    return ProcMapError(
+        "malformed /proc/{pid}/%s entry, missing some portion of "
+        "permissions",
+        maps_kind);
+
   // Handle read permission.
   const char read_perm_char = line_extractor.GetChar();
   if (read_perm_char == 'r')
-    memory_region_info.SetReadable(MemoryRegionInfo::OptionalBool::eYes);
+    region.SetReadable(MemoryRegionInfo::OptionalBool::eYes);
   else if (read_perm_char == '-')
-    memory_region_info.SetReadable(MemoryRegionInfo::OptionalBool::eNo);
+    region.SetReadable(MemoryRegionInfo::OptionalBool::eNo);
   else
-    return Status("unexpected /proc/{pid}/maps read permission char");
-
+    return ProcMapError("unexpected /proc/{pid}/%s read permission char",
+                        maps_kind);
+
   // Handle write permission.
   const char write_perm_char = line_extractor.GetChar();
   if (write_perm_char == 'w')
-    memory_region_info.SetWritable(MemoryRegionInfo::OptionalBool::eYes);
+    region.SetWritable(MemoryRegionInfo::OptionalBool::eYes);
   else if (write_perm_char == '-')
-    memory_region_info.SetWritable(MemoryRegionInfo::OptionalBool::eNo);
+    region.SetWritable(MemoryRegionInfo::OptionalBool::eNo);
   else
-    return Status("unexpected /proc/{pid}/maps write permission char");
-
+    return ProcMapError("unexpected /proc/{pid}/%s write permission char",
+                        maps_kind);
+
   // Handle execute permission.
   const char exec_perm_char = line_extractor.GetChar();
   if (exec_perm_char == 'x')
-    memory_region_info.SetExecutable(MemoryRegionInfo::OptionalBool::eYes);
+    region.SetExecutable(MemoryRegionInfo::OptionalBool::eYes);
   else if (exec_perm_char == '-')
-    memory_region_info.SetExecutable(MemoryRegionInfo::OptionalBool::eNo);
+    region.SetExecutable(MemoryRegionInfo::OptionalBool::eNo);
   else
-    return Status("unexpected /proc/{pid}/maps exec permission char");
-
+    return ProcMapError("unexpected /proc/{pid}/%s exec permission char",
+                        maps_kind);
+
   line_extractor.GetChar();              // Read the private bit
   line_extractor.SkipSpaces();           // Skip the separator
   line_extractor.GetHexMaxU64(false, 0); // Read the offset
@@ -89,13 +102,13 @@ ParseMemoryRegionInfoFromProcMapsLine(llvm::StringRef maps_line,
   line_extractor.GetHexMaxU64(false, 0); // Read the major device number
   line_extractor.SkipSpaces();           // Skip the separator
   line_extractor.GetU64(0, 10);          // Read the inode number
-
+
   line_extractor.SkipSpaces();
   const char *name = line_extractor.Peek();
   if (name)
-    memory_region_info.SetName(name);
-
-  return Status();
+    region.SetName(name);
+
+  return region;
 }
 
 void lldb_private::ParseLinuxMapRegions(llvm::StringRef linux_map,
@@ -104,9 +117,80 @@ void lldb_private::ParseLinuxMapRegions(llvm::StringRef linux_map,
   llvm::StringRef line;
   while (!lines.empty()) {
     std::tie(line, lines) = lines.split('\n');
-    MemoryRegionInfo region;
-    Status error = ParseMemoryRegionInfoFromProcMapsLine(line, region);
-    if (!callback(region, error))
+    if (!callback(ParseMemoryRegionInfoFromProcMapsLine(line, MapsKind::Maps)))
       break;
   }
 }
+
+void lldb_private::ParseLinuxSMapRegions(llvm::StringRef linux_smap,
+                                         LinuxMapCallback const &callback) {
+  // Entries in /smaps look like:
+  // 00400000-0048a000 r-xp 00000000 fd:03 960637
+  // Size:                552 kB
+  // Rss:                 460 kB
+  // <...>
+  // VmFlags: rd ex mr mw me dw
+  // 00500000-0058a000 rwxp 00000000 fd:03 960637
+  // <...>
+  //
+  // Where the first line is identical to the /maps format
+  // and VmFlags is only printed for kernels >= 3.8.
+
+  llvm::StringRef lines(linux_smap);
+  llvm::StringRef line;
+  llvm::Optional<MemoryRegionInfo> region;
+
+  while (lines.size()) {
+    std::tie(line, lines) = lines.split('\n');
+
+    // A property line looks like:
+    // <name>: <value>
+    // (no spaces on the left hand side)
+    // A header will have a ':' but the LHS will contain spaces
+    llvm::StringRef name;
+    llvm::StringRef value;
+    std::tie(name, value) = line.split(':');
+
+    // If this line is a property line
+    if (!name.contains(' ')) {
+      if (region) {
+        if (name == "VmFlags") {
+          if (value.contains("mt"))
+            region->SetMemoryTagged(MemoryRegionInfo::eYes);
+          else
+            region->SetMemoryTagged(MemoryRegionInfo::eNo);
+        }
+        // Ignore anything else
+      } else {
+        // Orphaned settings line
+        callback(ProcMapError(
+            "Found a property line without a corresponding mapping "
+            "in /proc/{pid}/%s",
+            MapsKind::SMaps));
+        return;
+      }
+    } else {
+      // Must be a new region header
+      if (region) {
+        // Save current region
+        callback(*region);
+        region.reset();
+      }
+
+      // Try to start a new region
+      llvm::Expected<MemoryRegionInfo> new_region =
+          ParseMemoryRegionInfoFromProcMapsLine(line, MapsKind::SMaps);
+      if (new_region) {
+        region = *new_region;
+      } else {
+        // Stop at first invalid region header
+        callback(new_region.takeError());
+        return;
+      }
+    }
+  }
+
+  // Catch last region
+  if (region)
+    callback(*region);
+}
diff --git a/lldb/source/Plugins/Process/Utility/LinuxProcMaps.h b/lldb/source/Plugins/Process/Utility/LinuxProcMaps.h
index 363f248fd416e2..02f78d55c29053 100644
--- a/lldb/source/Plugins/Process/Utility/LinuxProcMaps.h
+++ b/lldb/source/Plugins/Process/Utility/LinuxProcMaps.h
@@ -11,16 +11,16 @@
 
 #include "lldb/lldb-forward.h"
 #include "llvm/ADT/StringRef.h"
-#include <functional>
-
+#include "llvm/Support/Error.h"
 
 namespace lldb_private {
 
-typedef std::function<bool(const MemoryRegionInfo &, const Status &)>
-    LinuxMapCallback;
+typedef std::function<bool(llvm::Expected<MemoryRegionInfo>)> LinuxMapCallback;
 
 void ParseLinuxMapRegions(llvm::StringRef linux_map,
                           LinuxMapCallback const &callback);
 
+void ParseLinuxSMapRegions(llvm::StringRef linux_smap,
+                           LinuxMapCallback const &callback);
+
 } // namespace lldb_private
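Putting the parser and the new Expected-based callback typedef together, a sketch of a caller (hypothetical helper, assumes lldb's internal headers; only the declarations shown above are used):

#include "Plugins/Process/Utility/LinuxProcMaps.h"
#include "lldb/Target/MemoryRegionInfo.h"

using namespace lldb_private;

// Collect only the MTE-tagged regions from smaps text. Returning false from
// the callback stops parsing at the first malformed entry, the same contract
// NativeProcessLinux relies on above.
static void CollectTaggedRegions(llvm::StringRef smaps,
                                 MemoryRegionInfos &out) {
  ParseLinuxSMapRegions(smaps, [&](llvm::Expected<MemoryRegionInfo> region) {
    if (!region) {
      llvm::consumeError(region.takeError());
      return false; // stop on the first parse error
    }
    if (region->GetMemoryTagged() == MemoryRegionInfo::eYes)
      out.push_back(*region);
    return true;
  });
}
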
diff --git a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationClient.cpp b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationClient.cpp
index b1552a3a43adaf..d375a312ae2ce9 100644
--- a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationClient.cpp
+++ b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationClient.cpp
@@ -1529,6 +1529,22 @@ Status GDBRemoteCommunicationClient::GetMemoryRegionInfo(
           std::string name;
           name_extractor.GetHexByteString(name);
           region_info.SetName(name.c_str());
+        } else if (name.equals("flags")) {
+          region_info.SetMemoryTagged(MemoryRegionInfo::eNo);
+
+          llvm::StringRef flags = value;
+          llvm::StringRef flag;
+          while (flags.size()) {
+            flags = flags.ltrim();
+            std::tie(flag, flags) = flags.split(' ');
+            // To account for trailing whitespace
+            if (flag.size()) {
+              if (flag == "mt") {
+                region_info.SetMemoryTagged(MemoryRegionInfo::eYes);
+                break;
+              }
+            }
+          }
         } else if (name.equals("error")) {
           StringExtractorGDBRemote error_extractor(value);
           std::string error_string;
diff --git a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationServerLLGS.cpp b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationServerLLGS.cpp
index 2e57d7e3ecae63..2cf88c0d9f7067 100644
--- a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationServerLLGS.cpp
+++ b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationServerLLGS.cpp
@@ -2605,6 +2605,17 @@ GDBRemoteCommunicationServerLLGS::Handle_qMemoryRegionInfo(
       response.PutChar(';');
     }
 
+    // Flags
+    MemoryRegionInfo::OptionalBool memory_tagged =
+        region_info.GetMemoryTagged();
+    if (memory_tagged != MemoryRegionInfo::eDontKnow) {
+      response.PutCString("flags:");
+      if (memory_tagged == MemoryRegionInfo::eYes) {
+        response.PutCString("mt");
+      }
+      response.PutChar(';');
+    }
+
     // Name
     ConstString name = region_info.GetName();
     if (name) {
diff --git a/lldb/source/Plugins/Process/minidump/MinidumpParser.cpp b/lldb/source/Plugins/Process/minidump/MinidumpParser.cpp
index 9108d4d18aa5c5..e16f86cca1c21d 100644
--- a/lldb/source/Plugins/Process/minidump/MinidumpParser.cpp
+++ b/lldb/source/Plugins/Process/minidump/MinidumpParser.cpp
@@ -263,13 +263,18 @@ CreateRegionsCacheFromLinuxMaps(MinidumpParser &parser,
   auto data = parser.GetStream(StreamType::LinuxMaps);
   if (data.empty())
     return false;
-  ParseLinuxMapRegions(llvm::toStringRef(data),
-                       [&](const lldb_private::MemoryRegionInfo &region,
-                           const lldb_private::Status &status) -> bool {
-                         if (status.Success())
-                           regions.push_back(region);
-                         return true;
-                       });
+
+  Log *log = lldb_private::GetLogIfAllCategoriesSet(LIBLLDB_LOG_EXPRESSIONS);
+  ParseLinuxMapRegions(
+      llvm::toStringRef(data),
+      [&regions, &log](llvm::Expected<MemoryRegionInfo> region) -> bool {
+        if (region)
+          regions.push_back(*region);
+        else
+          LLDB_LOG_ERROR(log, region.takeError(),
+                         "Reading memory region from minidump failed: {0}");
+        return true;
+      });
   return !regions.empty();
 }
diff --git a/lldb/source/Target/MemoryRegionInfo.cpp b/lldb/source/Target/MemoryRegionInfo.cpp
index c7fb349ee1cc41..0d5ebbdbe23800 100644
--- a/lldb/source/Target/MemoryRegionInfo.cpp
+++ b/lldb/source/Target/MemoryRegionInfo.cpp
@@ -13,12 +13,12 @@ using namespace lldb_private;
 llvm::raw_ostream &lldb_private::operator<<(llvm::raw_ostream &OS,
                                             const MemoryRegionInfo &Info) {
   return OS << llvm::formatv("MemoryRegionInfo([{0}, {1}), {2:r}{3:w}{4:x}, "
-                             "{5}, `{6}`, {7}, {8})",
+                             "{5}, `{6}`, {7}, {8}, {9})",
                              Info.GetRange().GetRangeBase(),
                              Info.GetRange().GetRangeEnd(), Info.GetReadable(),
                              Info.GetWritable(), Info.GetExecutable(),
                              Info.GetMapped(), Info.GetName(), Info.GetFlash(),
-                             Info.GetBlocksize());
+                             Info.GetBlocksize(), Info.GetMemoryTagged());
 }
 
 void llvm::format_provider<MemoryRegionInfo::OptionalBool>::format(
diff --git a/lldb/test/API/linux/aarch64/mte_memory_region/Makefile b/lldb/test/API/linux/aarch64/mte_memory_region/Makefile
new file mode 100644
index 00000000000000..10495940055b63
--- /dev/null
+++ b/lldb/test/API/linux/aarch64/mte_memory_region/Makefile
@@ -0,0 +1,3 @@
+C_SOURCES := main.c
+
+include Makefile.rules
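Both ends of the protocol now agree on the flags field: the server hunk above emits "flags:" followed by a space separated list ("mt" when the region reports eYes), and the client hunk scans that list. A standalone sketch of the same matching rule (hypothetical helper, plain LLVM ADT; tolerant of the padded input the unit test later in this patch exercises):

#include "llvm/ADT/StringRef.h"
#include <tuple>

// Returns true if the space separated list contains the "mt" flag, tolerating
// leading, trailing, and repeated whitespace, e.g. " mt  zz mt  ".
static bool HasMemoryTaggingFlag(llvm::StringRef flags) {
  while (!flags.empty()) {
    flags = flags.ltrim();
    llvm::StringRef flag;
    std::tie(flag, flags) = flags.split(' ');
    if (flag == "mt")
      return true;
  }
  return false;
}
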
diff --git a/lldb/test/API/linux/aarch64/mte_memory_region/TestAArch64LinuxMTEMemoryRegion.py b/lldb/test/API/linux/aarch64/mte_memory_region/TestAArch64LinuxMTEMemoryRegion.py
new file mode 100644
index 00000000000000..ff8e01cb28c8c3
--- /dev/null
+++ b/lldb/test/API/linux/aarch64/mte_memory_region/TestAArch64LinuxMTEMemoryRegion.py
@@ -0,0 +1,55 @@
+"""
+Test that "memory region" command can show memory tagged regions
+on AArch64 Linux.
+"""
+
+
+
+import lldb
+from lldbsuite.test.decorators import *
+from lldbsuite.test.lldbtest import *
+from lldbsuite.test import lldbutil
+
+
+class AArch64LinuxMTEMemoryRegionTestCase(TestBase):
+
+    mydir = TestBase.compute_mydir(__file__)
+
+    NO_DEBUG_INFO_TESTCASE = True
+
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipUnlessPlatform(["linux"])
+    def test_mte_regions(self):
+        if not self.hasLinuxVmFlags():
+            self.skipTest('/proc/{pid}/smaps VmFlags must be present')
+
+        self.build()
+        self.runCmd("file " + self.getBuildArtifact("a.out"), CURRENT_EXECUTABLE_SET)
+
+        lldbutil.run_break_set_by_file_and_line(self, "main.c",
+            line_number('main.c', '// Set break point at this line.'),
+            num_expected_locations=1)
+
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        if self.process().GetState() == lldb.eStateExited:
+            # 47 = non MTE toolchain
+            # 48 = non MTE target
+            exit_status = self.process().GetExitStatus()
+            if exit_status == 47:
+                self.skipTest("MTE must be available in toolchain")
+            elif exit_status == 48:
+                self.skipTest("target must have MTE enabled")
+
+            # Otherwise we have MTE but another problem occurred
+            self.fail("Test program failed to run.")
+
+        self.expect("thread list", STOPPED_DUE_TO_BREAKPOINT,
+                    substrs=['stopped',
+                             'stop reason = breakpoint'])
+
+        substrs = ["memory tagging: enabled"]
+        # The new page will be tagged
+        self.expect("memory region the_page", substrs=substrs)
+        # Code page will not be
+        self.expect("memory region main", substrs=substrs, matching=False)
diff --git a/lldb/test/API/linux/aarch64/mte_memory_region/main.c b/lldb/test/API/linux/aarch64/mte_memory_region/main.c
new file mode 100644
index 00000000000000..17c135dc3344b6
--- /dev/null
+++ b/lldb/test/API/linux/aarch64/mte_memory_region/main.c
@@ -0,0 +1,44 @@
+#include <asm/hwcap.h>
+#include <stddef.h>
+#include <sys/auxv.h>
+#include <sys/mman.h>
+#include <sys/prctl.h>
+#include <unistd.h>
+
+#define INCOMPATIBLE_TOOLCHAIN 47
+#define INCOMPATIBLE_TARGET 48
+
+// This is in a separate non static function
+// so that we can always breakpoint the return 0 here.
+// Even if main never reaches it because HWCAP2_MTE
+// is not defined.
+// If it were in main then you would effectively have:
+// return TEST_INCOMPATIBLE;
+// return 0;
+// So the two returns would have the same breakpoint location
+// and we couldn't tell them apart.
+int setup_mte_page(void) {
+#ifdef HWCAP2_MTE
+  if (!(getauxval(AT_HWCAP2) & HWCAP2_MTE))
+    return INCOMPATIBLE_TARGET;
+
+  int got = prctl(PR_SET_TAGGED_ADDR_CTRL, PR_TAGGED_ADDR_ENABLE, 0, 0, 0);
+  if (got)
+    return 1;
+
+  void *the_page = mmap(0, sysconf(_SC_PAGESIZE), PROT_MTE,
+                        MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+  if (the_page == MAP_FAILED)
+    return 1;
+#endif
+
+  return 0; // Set break point at this line.
+}
+
+int main(int argc, char const *argv[]) {
+#ifdef HWCAP2_MTE
+  return setup_mte_page();
+#else
+  return INCOMPATIBLE_TOOLCHAIN;
+#endif
+}
diff --git a/lldb/unittests/Process/Utility/CMakeLists.txt b/lldb/unittests/Process/Utility/CMakeLists.txt
index 0041a94a79a3bc..9d827582b3cf99 100644
--- a/lldb/unittests/Process/Utility/CMakeLists.txt
+++ b/lldb/unittests/Process/Utility/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_lldb_unittest(ProcessUtilityTests
   RegisterContextFreeBSDTest.cpp
+  LinuxProcMapsTest.cpp
 
   LINK_LIBS
     lldbPluginProcessUtility)
diff --git a/lldb/unittests/Process/Utility/LinuxProcMapsTest.cpp b/lldb/unittests/Process/Utility/LinuxProcMapsTest.cpp
new file mode 100644
index 00000000000000..203875533d93a7
--- /dev/null
+++ b/lldb/unittests/Process/Utility/LinuxProcMapsTest.cpp
@@ -0,0 +1,262 @@
+//===-- LinuxProcMapsTest.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "Plugins/Process/Utility/LinuxProcMaps.h"
+#include "lldb/Target/MemoryRegionInfo.h"
+#include "lldb/Utility/Status.h"
+#include <tuple>
+
+using namespace lldb_private;
+
+typedef std::tuple<const char *, MemoryRegionInfos, const char *>
+    LinuxProcMapsTestParams;
+
+// Wrapper for convenience because Range is usually begin, size
+static MemoryRegionInfo::RangeType make_range(lldb::addr_t begin,
+                                              lldb::addr_t end) {
+  MemoryRegionInfo::RangeType range(begin, 0);
+  range.SetRangeEnd(end);
+  return range;
+}
+
+class LinuxProcMapsTestFixture
+    : public ::testing::TestWithParam<LinuxProcMapsTestParams> {
+protected:
+  Status error;
+  std::string err_str;
+  MemoryRegionInfos regions;
+  LinuxMapCallback callback;
+
+  void SetUp() override {
+    callback = [this](llvm::Expected<MemoryRegionInfo> Info) {
+      if (Info) {
+        err_str.clear();
+        regions.push_back(*Info);
+        return true;
+      }
+
+      err_str = toString(Info.takeError());
+      return false;
+    };
+  }
+
+  void check_regions(LinuxProcMapsTestParams params) {
+    EXPECT_THAT(std::get<1>(params), testing::ContainerEq(regions));
+    ASSERT_EQ(std::get<2>(params), err_str);
+  }
+};
+
+TEST_P(LinuxProcMapsTestFixture, ParseMapRegions) {
+  auto params = GetParam();
+  ParseLinuxMapRegions(std::get<0>(params), callback);
+  check_regions(params);
+}
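Each parameter tuple fed to these fixtures packs three fields: the /proc text to parse, the regions the callback should have collected, and the expected error string (empty on success). A sketch of one well-formed case (hypothetical values, same shape as the real entries below; types as declared at the top of this file):

LinuxProcMapsTestParams example_case = std::make_tuple(
    "0-1 r--p 00000000 00:00 0", // input: one well-formed maps line
    MemoryRegionInfos{MemoryRegionInfo(
        make_range(0, 1), MemoryRegionInfo::eYes, MemoryRegionInfo::eNo,
        MemoryRegionInfo::eNo, MemoryRegionInfo::eYes, ConstString(nullptr),
        MemoryRegionInfo::eDontKnow, 0, MemoryRegionInfo::eDontKnow)},
    ""); // empty error string: parsing succeeded
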
+// Note: ConstString("") != ConstString(nullptr)
+// When a region has no name, it will have the latter in the MemoryRegionInfo
+INSTANTIATE_TEST_CASE_P(
+    ProcMapTests, LinuxProcMapsTestFixture,
+    ::testing::Values(
+        // Nothing in nothing out
+        std::make_tuple("", MemoryRegionInfos{}, ""),
+        // Various formatting error conditions
+        std::make_tuple("55a4512f7000/55a451b68000 rw-p 00000000 00:00 0",
+                        MemoryRegionInfos{},
+                        "malformed /proc/{pid}/maps entry, missing dash "
+                        "between address range"),
+        std::make_tuple("0-0 rw", MemoryRegionInfos{},
+                        "malformed /proc/{pid}/maps entry, missing some "
+                        "portion of permissions"),
+        std::make_tuple("0-0 z--p 00000000 00:00 0", MemoryRegionInfos{},
+                        "unexpected /proc/{pid}/maps read permission char"),
+        std::make_tuple("0-0 rz-p 00000000 00:00 0", MemoryRegionInfos{},
+                        "unexpected /proc/{pid}/maps write permission char"),
+        std::make_tuple("0-0 rwzp 00000000 00:00 0", MemoryRegionInfos{},
+                        "unexpected /proc/{pid}/maps exec permission char"),
+        // Stops at first parsing error
+        std::make_tuple(
+            "0-1 rw-p 00000000 00:00 0 [abc]\n"
+            "0-0 rwzp 00000000 00:00 0\n"
+            "2-3 r-xp 00000000 00:00 0 [def]\n",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0, 1), MemoryRegionInfo::eYes,
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eNo,
+                                 MemoryRegionInfo::eYes, ConstString("[abc]"),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eDontKnow),
+            },
+            "unexpected /proc/{pid}/maps exec permission char"),
+        // Single entry
+        std::make_tuple(
+            "55a4512f7000-55a451b68000 rw-p 00000000 00:00 0 [heap]",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0x55a4512f7000, 0x55a451b68000),
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eYes,
+                                 MemoryRegionInfo::eNo, MemoryRegionInfo::eYes,
+                                 ConstString("[heap]"),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eDontKnow),
+            },
+            ""),
+        // Multiple entries
+        std::make_tuple(
+            "7fc090021000-7fc094000000 ---p 00000000 00:00 0\n"
+            "ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0 "
+            "[vsyscall]",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0x7fc090021000, 0x7fc094000000),
+                                 MemoryRegionInfo::eNo, MemoryRegionInfo::eNo,
+                                 MemoryRegionInfo::eNo, MemoryRegionInfo::eYes,
+                                 ConstString(nullptr),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eDontKnow),
+                MemoryRegionInfo(
+                    make_range(0xffffffffff600000, 0xffffffffff601000),
+                    MemoryRegionInfo::eYes, MemoryRegionInfo::eNo,
+                    MemoryRegionInfo::eYes, MemoryRegionInfo::eYes,
+                    ConstString("[vsyscall]"), MemoryRegionInfo::eDontKnow, 0,
+                    MemoryRegionInfo::eDontKnow),
+            },
+            "")), );
+
+class LinuxProcSMapsTestFixture : public LinuxProcMapsTestFixture {};
+
+INSTANTIATE_TEST_CASE_P(
+    ProcSMapTests, LinuxProcSMapsTestFixture,
+    ::testing::Values(
+        // Nothing in nothing out
+        std::make_tuple("", MemoryRegionInfos{}, ""),
+        // Uses the same parsing for first line, so same errors but referring
+        // to smaps
+        std::make_tuple("0/0 rw-p 00000000 00:00 0", MemoryRegionInfos{},
+                        "malformed /proc/{pid}/smaps entry, missing dash "
+                        "between address range"),
+        // Stop parsing at first error
+        std::make_tuple(
+            "1111-2222 rw-p 00000000 00:00 0 [foo]\n"
+            "0/0 rw-p 00000000 00:00 0",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0x1111, 0x2222),
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eYes,
+                                 MemoryRegionInfo::eNo, MemoryRegionInfo::eYes,
+                                 ConstString("[foo]"),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eDontKnow),
+            },
+            "malformed /proc/{pid}/smaps entry, missing dash between address "
+            "range"),
+        // Property line without a region is an error
+        std::make_tuple("Referenced: 2188 kB\n"
+                        "1111-2222 rw-p 00000000 00:00 0 [foo]\n"
+                        "3333-4444 rw-p 00000000 00:00 0 [bar]\n",
+                        MemoryRegionInfos{},
+                        "Found a property line without a corresponding mapping "
+                        "in /proc/{pid}/smaps"),
+        // Single region parses, has no flags
+        std::make_tuple(
+            "1111-2222 rw-p 00000000 00:00 0 [foo]",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0x1111, 0x2222),
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eYes,
+                                 MemoryRegionInfo::eNo, MemoryRegionInfo::eYes,
+                                 ConstString("[foo]"),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eDontKnow),
+            },
+            ""),
+        // Single region with flags, other lines ignored
+        std::make_tuple(
+            "1111-2222 rw-p 00000000 00:00 0 [foo]\n"
+            "Referenced: 2188 kB\n"
+            "AnonHugePages: 0 kB\n"
+            "VmFlags: mt",
+            MemoryRegionInfos{
+                MemoryRegionInfo(
+                    make_range(0x1111, 0x2222), MemoryRegionInfo::eYes,
+                    MemoryRegionInfo::eYes, MemoryRegionInfo::eNo,
+                    MemoryRegionInfo::eYes, ConstString("[foo]"),
+                    MemoryRegionInfo::eDontKnow, 0, MemoryRegionInfo::eYes),
+            },
+            ""),
+        // Whitespace ignored
+        std::make_tuple(
+            "0-0 rw-p 00000000 00:00 0\n"
+            "VmFlags:     mt      ",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0, 0), MemoryRegionInfo::eYes,
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eNo,
+                                 MemoryRegionInfo::eYes, ConstString(nullptr),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eYes),
+            },
+            ""),
+        // VmFlags line means it has flag info, but nothing is set
+        std::make_tuple(
+            "0-0 rw-p 00000000 00:00 0\n"
+            "VmFlags:       ",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0, 0), MemoryRegionInfo::eYes,
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eNo,
+                                 MemoryRegionInfo::eYes, ConstString(nullptr),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eNo),
+            },
+            ""),
+        // Handle some pages not having a flags line
+        std::make_tuple(
+            "1111-2222 rw-p 00000000 00:00 0 [foo]\n"
+            "Referenced: 2188 kB\n"
+            "AnonHugePages: 0 kB\n"
+            "3333-4444 r-xp 00000000 00:00 0 [bar]\n"
+            "VmFlags: mt",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0x1111, 0x2222),
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eYes,
+                                 MemoryRegionInfo::eNo, MemoryRegionInfo::eYes,
+                                 ConstString("[foo]"),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eDontKnow),
+                MemoryRegionInfo(
+                    make_range(0x3333, 0x4444), MemoryRegionInfo::eYes,
+                    MemoryRegionInfo::eNo, MemoryRegionInfo::eYes,
+                    MemoryRegionInfo::eYes, ConstString("[bar]"),
+                    MemoryRegionInfo::eDontKnow, 0, MemoryRegionInfo::eYes),
+            },
+            ""),
+        // Handle no pages having a flags line (older kernels)
+        std::make_tuple(
+            "1111-2222 rw-p 00000000 00:00 0\n"
+            "Referenced: 2188 kB\n"
+            "AnonHugePages: 0 kB\n"
+            "3333-4444 r-xp 00000000 00:00 0\n"
+            "KernelPageSize: 4 kB\n"
+            "MMUPageSize: 4 kB\n",
+            MemoryRegionInfos{
+                MemoryRegionInfo(make_range(0x1111, 0x2222),
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eYes,
+                                 MemoryRegionInfo::eNo, MemoryRegionInfo::eYes,
+                                 ConstString(nullptr),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eDontKnow),
+                MemoryRegionInfo(make_range(0x3333, 0x4444),
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eNo,
+                                 MemoryRegionInfo::eYes, MemoryRegionInfo::eYes,
+                                 ConstString(nullptr),
+                                 MemoryRegionInfo::eDontKnow, 0,
+                                 MemoryRegionInfo::eDontKnow),
+            },
+            "")), );
+
+TEST_P(LinuxProcSMapsTestFixture, ParseSMapRegions) {
+  auto params = GetParam();
+  ParseLinuxSMapRegions(std::get<0>(params), callback);
+  check_regions(params);
+}
diff --git a/lldb/unittests/Process/gdb-remote/GDBRemoteCommunicationClientTest.cpp b/lldb/unittests/Process/gdb-remote/GDBRemoteCommunicationClientTest.cpp
index adfead6aed98bc..2cca197f63a1ea 100644
--- a/lldb/unittests/Process/gdb-remote/GDBRemoteCommunicationClientTest.cpp
+++ b/lldb/unittests/Process/gdb-remote/GDBRemoteCommunicationClientTest.cpp
@@ -343,6 +343,25 @@ TEST_F(GDBRemoteCommunicationClientTest, GetMemoryRegionInfo) {
   EXPECT_EQ(MemoryRegionInfo::eNo, region_info.GetWritable());
   EXPECT_EQ(MemoryRegionInfo::eYes, region_info.GetExecutable());
   EXPECT_EQ("/foo/bar.so", region_info.GetName().GetStringRef());
+  EXPECT_EQ(MemoryRegionInfo::eDontKnow, region_info.GetMemoryTagged());
+
+  result = std::async(std::launch::async, [&] {
+    return client.GetMemoryRegionInfo(addr, region_info);
+  });
+
+  HandlePacket(server, "qMemoryRegionInfo:a000",
+               "start:a000;size:2000;flags:;");
+  EXPECT_TRUE(result.get().Success());
+  EXPECT_EQ(MemoryRegionInfo::eNo, region_info.GetMemoryTagged());
+
+  result = std::async(std::launch::async, [&] {
+    return client.GetMemoryRegionInfo(addr, region_info);
+  });
+
+  HandlePacket(server, "qMemoryRegionInfo:a000",
+               "start:a000;size:2000;flags: mt zz mt ;");
+  EXPECT_TRUE(result.get().Success());
+  EXPECT_EQ(MemoryRegionInfo::eYes, region_info.GetMemoryTagged());
 }
 
 TEST_F(GDBRemoteCommunicationClientTest, GetMemoryRegionInfoInvalidResponse) {
diff --git a/lldb/unittests/Process/minidump/MinidumpParserTest.cpp b/lldb/unittests/Process/minidump/MinidumpParserTest.cpp
index 25d7e237bd2047..69046af283eba0 100644
--- a/lldb/unittests/Process/minidump/MinidumpParserTest.cpp
+++ b/lldb/unittests/Process/minidump/MinidumpParserTest.cpp
@@ -378,15 +378,15 @@ TEST_F(MinidumpParserTest, GetMemoryRegionInfo) {
       parser->BuildMemoryRegions(),
       testing::Pair(testing::ElementsAre(
                         MemoryRegionInfo({0x0, 0x10000}, no, no, no, no,
-                                         ConstString(), unknown, 0),
+                                         ConstString(), unknown, 0, unknown),
                        MemoryRegionInfo({0x10000, 0x21000}, yes, yes, no, yes,
-                                         ConstString(), unknown, 0),
+                                         ConstString(), unknown, 0, unknown),
                         MemoryRegionInfo({0x40000, 0x1000}, yes, no, no, yes,
-                                         ConstString(), unknown, 0),
+                                         ConstString(), unknown, 0, unknown),
                        MemoryRegionInfo({0x7ffe0000, 0x1000}, yes, no, no, yes,
-                                         ConstString(), unknown, 0),
+                                         ConstString(), unknown, 0, unknown),
                        MemoryRegionInfo({0x7ffe1000, 0xf000}, no, no, no, yes,
-                                         ConstString(), unknown, 0)),
+                                         ConstString(), unknown, 0, unknown)),
                     true));
 }
 
@@ -409,12 +409,13 @@ TEST_F(MinidumpParserTest, GetMemoryRegionInfoFromMemoryList) {
 
   EXPECT_THAT(
       parser->BuildMemoryRegions(),
-      testing::Pair(testing::ElementsAre(
-                        MemoryRegionInfo({0x1000, 0x10}, yes, unknown, unknown,
-                                         yes, ConstString(), unknown, 0),
-                        MemoryRegionInfo({0x2000, 0x20}, yes, unknown, unknown,
-                                         yes, ConstString(), unknown, 0)),
-                    false));
+      testing::Pair(
+          testing::ElementsAre(
+              MemoryRegionInfo({0x1000, 0x10}, yes, unknown, unknown, yes,
+                               ConstString(), unknown, 0, unknown),
+              MemoryRegionInfo({0x2000, 0x20}, yes, unknown, unknown, yes,
+                               ConstString(), unknown, 0, unknown)),
+          false));
 }
 
 TEST_F(MinidumpParserTest, GetMemoryRegionInfoFromMemory64List) {
@@ -424,12 +425,13 @@ TEST_F(MinidumpParserTest, GetMemoryRegionInfoFromMemory64List) {
   // we don't have a MemoryInfoListStream.
   EXPECT_THAT(
       parser->BuildMemoryRegions(),
-      testing::Pair(testing::ElementsAre(
-                        MemoryRegionInfo({0x1000, 0x10}, yes, unknown, unknown,
-                                         yes, ConstString(), unknown, 0),
-                        MemoryRegionInfo({0x2000, 0x20}, yes, unknown, unknown,
-                                         yes, ConstString(), unknown, 0)),
-                    false));
+      testing::Pair(
+          testing::ElementsAre(
+              MemoryRegionInfo({0x1000, 0x10}, yes, unknown, unknown, yes,
+                               ConstString(), unknown, 0, unknown),
+              MemoryRegionInfo({0x2000, 0x20}, yes, unknown, unknown, yes,
+                               ConstString(), unknown, 0, unknown)),
+          false));
 }
 
 TEST_F(MinidumpParserTest, GetMemoryRegionInfoLinuxMaps) {
@@ -453,22 +455,42 @@ TEST_F(MinidumpParserTest, GetMemoryRegionInfoLinuxMaps) {
   ConstString app_process("/system/bin/app_process");
   ConstString linker("/system/bin/linker");
   ConstString liblog("/system/lib/liblog.so");
-  EXPECT_THAT(
-      parser->BuildMemoryRegions(),
-      testing::Pair(testing::ElementsAre(
-                        MemoryRegionInfo({0x400d9000, 0x2000}, yes, no, yes,
-                                         yes, app_process, unknown, 0),
-                        MemoryRegionInfo({0x400db000, 0x1000}, yes, no, no, yes,
-                                         app_process, unknown, 0),
-                        MemoryRegionInfo({0x400dc000, 0x1000}, yes, yes, no,
-                                         yes, ConstString(), unknown, 0),
-                        MemoryRegionInfo({0x400ec000, 0x1000}, yes, no, no, yes,
-                                         ConstString(), unknown, 0),
-                        MemoryRegionInfo({0x400ee000, 0x1000}, yes, yes, no,
-                                         yes, linker, unknown, 0),
-                        MemoryRegionInfo({0x400fc000, 0x1000}, yes, yes, yes,
-                                         yes, liblog, unknown, 0)),
-                    true));
+  EXPECT_THAT(parser->BuildMemoryRegions(),
+              testing::Pair(
+                  testing::ElementsAre(
+                      MemoryRegionInfo({0x400d9000, 0x2000}, yes, no, yes, yes,
+                                       app_process, unknown, 0, unknown),
+                      MemoryRegionInfo({0x400db000, 0x1000}, yes, no, no, yes,
+                                       app_process, unknown, 0, unknown),
+                      MemoryRegionInfo({0x400dc000, 0x1000}, yes, yes, no, yes,
+                                       ConstString(), unknown, 0, unknown),
+                      MemoryRegionInfo({0x400ec000, 0x1000}, yes, no, no, yes,
+                                       ConstString(), unknown, 0, unknown),
+                      MemoryRegionInfo({0x400ee000, 0x1000}, yes, yes, no, yes,
+                                       linker, unknown, 0, unknown),
+                      MemoryRegionInfo({0x400fc000, 0x1000}, yes, yes, yes, yes,
+                                       liblog, unknown, 0, unknown)),
+                  true));
+}
+
+TEST_F(MinidumpParserTest, GetMemoryRegionInfoLinuxMapsError) {
+  ASSERT_THAT_ERROR(SetUpFromYaml(R"(
+--- !minidump
+Streams:
+  - Type:            LinuxMaps
+    Text: |
+      400d9000-400db000 r?xp 00000000 b3:04 227
+      400fc000-400fd000 rwxp 00001000 b3:04 1096
+...
+)"),
+                    llvm::Succeeded());
+  // Test that when a /proc/maps region fails to parse
+  // we handle the error and continue with the rest.
+  EXPECT_THAT(parser->BuildMemoryRegions(),
+              testing::Pair(testing::ElementsAre(MemoryRegionInfo(
+                                {0x400fc000, 0x1000}, yes, yes, yes, yes,
+                                ConstString(nullptr), unknown, 0, unknown)),
+                            true));
+}
 
 // Windows Minidump tests
diff --git a/llvm/lib/Target/VE/VEISelLowering.cpp b/llvm/lib/Target/VE/VEISelLowering.cpp
index b95229c94f6698..c41d0a416eaae2 100644
--- a/llvm/lib/Target/VE/VEISelLowering.cpp
+++ b/llvm/lib/Target/VE/VEISelLowering.cpp
@@ -1654,3 +1654,15 @@ VETargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
 
   return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
 }
+
+//===----------------------------------------------------------------------===//
+// VE Target Optimization Support
+//===----------------------------------------------------------------------===//
+
+unsigned VETargetLowering::getMinimumJumpTableEntries() const {
+  // Specify 8 for PIC model to relieve the impact of PIC load instructions.
+  if (isJumpTableRelative())
+    return 8;
+
+  return TargetLowering::getMinimumJumpTableEntries();
+}
diff --git a/llvm/lib/Target/VE/VEISelLowering.h b/llvm/lib/Target/VE/VEISelLowering.h
index f42aba40d6cd10..e12bef882d8ada 100644
--- a/llvm/lib/Target/VE/VEISelLowering.h
+++ b/llvm/lib/Target/VE/VEISelLowering.h
@@ -151,6 +151,9 @@ class VETargetLowering : public TargetLowering {
 
   /// Target Optimization {
 
+  // Return lower limit for number of blocks in a jump table.
+  unsigned getMinimumJumpTableEntries() const override;
+
   // SX-Aurora VE's s/udiv is 5-9 times slower than multiply.
   bool isIntDivCheap(EVT, AttributeList) const override { return false; }
   // VE doesn't have rem.
diff --git a/llvm/test/CodeGen/VE/Scalar/br_jt.ll b/llvm/test/CodeGen/VE/Scalar/br_jt.ll
index a7218965c467f8..d84e830299ffca 100644
--- a/llvm/test/CodeGen/VE/Scalar/br_jt.ll
+++ b/llvm/test/CodeGen/VE/Scalar/br_jt.ll
@@ -2,23 +2,370 @@
 ; RUN: llc < %s -mtriple=ve -relocation-model=pic \
 ; RUN:     | FileCheck %s -check-prefix=PIC
 
+@switch.table.br_jt4 = private unnamed_addr constant [4 x i32] [i32 3, i32 0, i32 4, i32 7], align 4
+@switch.table.br_jt7 = private unnamed_addr constant [9 x i32] [i32 3, i32 0, i32 4, i32 7, i32 3, i32 3, i32 5, i32 11, i32 10], align 4
+@switch.table.br_jt8 = private unnamed_addr constant [9 x i32] [i32 3, i32 0, i32 4, i32 7, i32 3, i32 1, i32 5, i32 11, i32 10], align 4
+
+; Function Attrs: norecurse nounwind readnone
+define signext i32 @br_jt3(i32 signext %0) {
+; CHECK-LABEL: br_jt3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    adds.w.sx %s0, %s0, (0)1
+; CHECK-NEXT:    breq.w 1, %s0, .LBB{{[0-9]+}}_1
+; CHECK-NEXT:  # %bb.2:
+; CHECK-NEXT:    breq.w 4, %s0, .LBB{{[0-9]+}}_5
+; CHECK-NEXT:  # %bb.3:
+; CHECK-NEXT:    brne.w 2, %s0, .LBB{{[0-9]+}}_6
+; CHECK-NEXT:  # %bb.4:
+; CHECK-NEXT:    or %s0, 0, (0)1
+; CHECK-NEXT:    adds.w.sx %s0, %s0, (0)1
+; CHECK-NEXT:    b.l.t (, %s10)
+; CHECK-NEXT:  .LBB{{[0-9]+}}_1:
+; CHECK-NEXT:    or %s0, 3, (0)1
+; CHECK-NEXT:    adds.w.sx %s0, %s0, (0)1
+; CHECK-NEXT:    b.l.t (, %s10)
+; CHECK-NEXT:  .LBB{{[0-9]+}}_5:
+; CHECK-NEXT:    or %s0, 7, (0)1
+; CHECK-NEXT:  .LBB{{[0-9]+}}_6:
+; CHECK-NEXT:    adds.w.sx %s0, %s0, (0)1
+; CHECK-NEXT:    b.l.t (, %s10)
+;
+; PIC-LABEL: br_jt3:
+; PIC:       # %bb.0:
+; PIC-NEXT:    adds.w.sx %s0, %s0, (0)1
+; PIC-NEXT:    breq.w 1, %s0, .LBB0_1
+; PIC-NEXT:  # %bb.2:
+; PIC-NEXT:    breq.w 4, %s0, .LBB0_5
+; PIC-NEXT:  # %bb.3:
+; PIC-NEXT:    brne.w 2, %s0, .LBB0_6
+; PIC-NEXT:  # %bb.4:
+; PIC-NEXT:    or %s0, 0, (0)1
+; PIC-NEXT:    adds.w.sx %s0, %s0, (0)1
+; PIC-NEXT:    b.l.t (, %s10)
+; PIC-NEXT:  .LBB0_1:
+; PIC-NEXT:    or %s0, 3, (0)1
+; PIC-NEXT:    adds.w.sx %s0, %s0, (0)1
+; PIC-NEXT:    b.l.t (, %s10)
+; PIC-NEXT:  .LBB0_5:
+; PIC-NEXT:    or %s0, 7, (0)1
+; PIC-NEXT:  .LBB0_6:
+; PIC-NEXT:    adds.w.sx %s0, %s0, (0)1
+; PIC-NEXT:    b.l.t (, %s10)
+  switch i32 %0, label %4 [
+    i32 1, label %5
+    i32 2, label %2
+    i32 4, label %3
+  ]
+
+2:                                                ; preds = %1
+  br label %5
+
+3:                                                ; preds = %1
+  br label %5
+
+4:                                                ; preds = %1
+  br label %5
+
+5:                                                ; preds = %1, %4, %3, %2
+  %6 = phi i32 [ %0, %4 ], [ 7, %3 ], [ 0, %2 ], [ 3, %1 ]
+  ret i32 %6
+}
+
 ; Function Attrs: norecurse nounwind readnone
-define signext i32 @br_jt(i32 signext %0) {
-; CHECK-LABEL: br_jt:
+define signext i32 @br_jt4(i32 signext %0) {
+; CHECK-LABEL: br_jt4:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    adds.w.sx %s0, %s0, (0)1
 ; CHECK-NEXT:    adds.w.sx %s1, -1, %s0
 ; CHECK-NEXT:    cmpu.w %s2, 3, %s1
-; CHECK-NEXT:    brgt.w 0, %s2, .LBB{{[0-9]+}}_5
+; CHECK-NEXT:    brgt.w 0, %s2, .LBB{{[0-9]+}}_2
 ; CHECK-NEXT:  # %bb.1:
(0)1 -; CHECK-NEXT: sll %s0, %s0, 3 -; CHECK-NEXT: lea %s1, .LJTI0_0@lo +; CHECK-NEXT: adds.w.sx %s0, %s1, (0)1 +; CHECK-NEXT: sll %s0, %s0, 2 +; CHECK-NEXT: lea %s1, .Lswitch.table.br_jt4@lo +; CHECK-NEXT: and %s1, %s1, (32)0 +; CHECK-NEXT: lea.sl %s1, .Lswitch.table.br_jt4@hi(, %s1) +; CHECK-NEXT: ldl.sx %s0, (%s0, %s1) +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_2: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; +; PIC-LABEL: br_jt4: +; PIC: .LBB{{[0-9]+}}_5: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: adds.w.sx %s1, -1, %s0 +; PIC-NEXT: cmpu.w %s2, 3, %s1 +; PIC-NEXT: lea %s15, _GLOBAL_OFFSET_TABLE_@pc_lo(-24) +; PIC-NEXT: and %s15, %s15, (32)0 +; PIC-NEXT: sic %s16 +; PIC-NEXT: lea.sl %s15, _GLOBAL_OFFSET_TABLE_@pc_hi(%s16, %s15) +; PIC-NEXT: brgt.w 0, %s2, .LBB1_2 +; PIC-NEXT: # %bb.1: +; PIC-NEXT: adds.w.sx %s0, %s1, (0)1 +; PIC-NEXT: sll %s0, %s0, 2 +; PIC-NEXT: lea %s1, .Lswitch.table.br_jt4@gotoff_lo +; PIC-NEXT: and %s1, %s1, (32)0 +; PIC-NEXT: lea.sl %s1, .Lswitch.table.br_jt4@gotoff_hi(%s1, %s15) +; PIC-NEXT: ldl.sx %s0, (%s0, %s1) +; PIC-NEXT: br.l.t .LBB1_3 +; PIC-NEXT: .LBB1_2: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: .LBB1_3: +; PIC-NEXT: or %s11, 0, %s9 + %2 = add i32 %0, -1 + %3 = icmp ult i32 %2, 4 + br i1 %3, label %4, label %8 + +4: ; preds = %1 + %5 = sext i32 %2 to i64 + %6 = getelementptr inbounds [4 x i32], [4 x i32]* @switch.table.br_jt4, i64 0, i64 %5 + %7 = load i32, i32* %6, align 4 + ret i32 %7 + +8: ; preds = %1 + ret i32 %0 +} + +; Function Attrs: norecurse nounwind readnone +define signext i32 @br_jt7(i32 signext %0) { +; CHECK-LABEL: br_jt7: +; CHECK: # %bb.0: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: adds.w.sx %s1, -1, %s0 +; CHECK-NEXT: cmpu.w %s2, 8, %s1 +; CHECK-NEXT: brgt.w 0, %s2, .LBB{{[0-9]+}}_3 +; CHECK-NEXT: # %bb.1: +; CHECK-NEXT: and %s2, %s1, (48)0 +; CHECK-NEXT: lea %s3, 463 +; CHECK-NEXT: and %s3, %s3, (32)0 +; CHECK-NEXT: srl %s2, %s3, %s2 +; CHECK-NEXT: and %s2, 1, %s2 +; CHECK-NEXT: brne.w 0, %s2, .LBB{{[0-9]+}}_2 +; CHECK-NEXT: .LBB{{[0-9]+}}_3: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_2: +; CHECK-NEXT: adds.w.sx %s0, %s1, (0)1 +; CHECK-NEXT: sll %s0, %s0, 2 +; CHECK-NEXT: lea %s1, .Lswitch.table.br_jt7@lo ; CHECK-NEXT: and %s1, %s1, (32)0 -; CHECK-NEXT: lea.sl %s1, .LJTI0_0@hi(, %s1) -; CHECK-NEXT: ld %s1, (%s1, %s0) +; CHECK-NEXT: lea.sl %s1, .Lswitch.table.br_jt7@hi(, %s1) +; CHECK-NEXT: ldl.sx %s0, (%s0, %s1) +; CHECK-NEXT: b.l.t (, %s10) +; +; PIC-LABEL: br_jt7: +; PIC: .LBB{{[0-9]+}}_6: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: adds.w.sx %s1, -1, %s0 +; PIC-NEXT: cmpu.w %s2, 8, %s1 +; PIC-NEXT: lea %s15, _GLOBAL_OFFSET_TABLE_@pc_lo(-24) +; PIC-NEXT: and %s15, %s15, (32)0 +; PIC-NEXT: sic %s16 +; PIC-NEXT: lea.sl %s15, _GLOBAL_OFFSET_TABLE_@pc_hi(%s16, %s15) +; PIC-NEXT: brgt.w 0, %s2, .LBB2_3 +; PIC-NEXT: # %bb.1: +; PIC-NEXT: and %s2, %s1, (48)0 +; PIC-NEXT: lea %s3, 463 +; PIC-NEXT: and %s3, %s3, (32)0 +; PIC-NEXT: srl %s2, %s3, %s2 +; PIC-NEXT: and %s2, 1, %s2 +; PIC-NEXT: brne.w 0, %s2, .LBB2_2 +; PIC-NEXT: .LBB2_3: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: br.l.t .LBB2_4 +; PIC-NEXT: .LBB2_2: +; PIC-NEXT: adds.w.sx %s0, %s1, (0)1 +; PIC-NEXT: sll %s0, %s0, 2 +; PIC-NEXT: lea %s1, .Lswitch.table.br_jt7@gotoff_lo +; PIC-NEXT: and %s1, %s1, (32)0 +; PIC-NEXT: lea.sl %s1, .Lswitch.table.br_jt7@gotoff_hi(%s1, %s15) +; PIC-NEXT: ldl.sx %s0, (%s0, %s1) +; PIC-NEXT: .LBB2_4: +; PIC-NEXT: 
or %s11, 0, %s9 + %2 = add i32 %0, -1 + %3 = icmp ult i32 %2, 9 + br i1 %3, label %4, label %13 + +4: ; preds = %1 + %5 = trunc i32 %2 to i16 + %6 = lshr i16 463, %5 + %7 = and i16 %6, 1 + %8 = icmp eq i16 %7, 0 + br i1 %8, label %13, label %9 + +9: ; preds = %4 + %10 = sext i32 %2 to i64 + %11 = getelementptr inbounds [9 x i32], [9 x i32]* @switch.table.br_jt7, i64 0, i64 %10 + %12 = load i32, i32* %11, align 4 + ret i32 %12 + +13: ; preds = %1, %4 + ret i32 %0 +} + +; Function Attrs: norecurse nounwind readnone +define signext i32 @br_jt8(i32 signext %0) { +; CHECK-LABEL: br_jt8: +; CHECK: # %bb.0: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: adds.w.sx %s1, -1, %s0 +; CHECK-NEXT: cmpu.w %s2, 8, %s1 +; CHECK-NEXT: brgt.w 0, %s2, .LBB{{[0-9]+}}_3 +; CHECK-NEXT: # %bb.1: +; CHECK-NEXT: and %s2, %s1, (48)0 +; CHECK-NEXT: lea %s3, 495 +; CHECK-NEXT: and %s3, %s3, (32)0 +; CHECK-NEXT: srl %s2, %s3, %s2 +; CHECK-NEXT: and %s2, 1, %s2 +; CHECK-NEXT: brne.w 0, %s2, .LBB{{[0-9]+}}_2 +; CHECK-NEXT: .LBB{{[0-9]+}}_3: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_2: +; CHECK-NEXT: adds.w.sx %s0, %s1, (0)1 +; CHECK-NEXT: sll %s0, %s0, 2 +; CHECK-NEXT: lea %s1, .Lswitch.table.br_jt8@lo +; CHECK-NEXT: and %s1, %s1, (32)0 +; CHECK-NEXT: lea.sl %s1, .Lswitch.table.br_jt8@hi(, %s1) +; CHECK-NEXT: ldl.sx %s0, (%s0, %s1) +; CHECK-NEXT: b.l.t (, %s10) +; +; PIC-LABEL: br_jt8: +; PIC: .LBB{{[0-9]+}}_6: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: adds.w.sx %s1, -1, %s0 +; PIC-NEXT: cmpu.w %s2, 8, %s1 +; PIC-NEXT: lea %s15, _GLOBAL_OFFSET_TABLE_@pc_lo(-24) +; PIC-NEXT: and %s15, %s15, (32)0 +; PIC-NEXT: sic %s16 +; PIC-NEXT: lea.sl %s15, _GLOBAL_OFFSET_TABLE_@pc_hi(%s16, %s15) +; PIC-NEXT: brgt.w 0, %s2, .LBB3_3 +; PIC-NEXT: # %bb.1: +; PIC-NEXT: and %s2, %s1, (48)0 +; PIC-NEXT: lea %s3, 495 +; PIC-NEXT: and %s3, %s3, (32)0 +; PIC-NEXT: srl %s2, %s3, %s2 +; PIC-NEXT: and %s2, 1, %s2 +; PIC-NEXT: brne.w 0, %s2, .LBB3_2 +; PIC-NEXT: .LBB3_3: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: br.l.t .LBB3_4 +; PIC-NEXT: .LBB3_2: +; PIC-NEXT: adds.w.sx %s0, %s1, (0)1 +; PIC-NEXT: sll %s0, %s0, 2 +; PIC-NEXT: lea %s1, .Lswitch.table.br_jt8@gotoff_lo +; PIC-NEXT: and %s1, %s1, (32)0 +; PIC-NEXT: lea.sl %s1, .Lswitch.table.br_jt8@gotoff_hi(%s1, %s15) +; PIC-NEXT: ldl.sx %s0, (%s0, %s1) +; PIC-NEXT: .LBB3_4: +; PIC-NEXT: or %s11, 0, %s9 + %2 = add i32 %0, -1 + %3 = icmp ult i32 %2, 9 + br i1 %3, label %4, label %13 + +4: ; preds = %1 + %5 = trunc i32 %2 to i16 + %6 = lshr i16 495, %5 + %7 = and i16 %6, 1 + %8 = icmp eq i16 %7, 0 + br i1 %8, label %13, label %9 + +9: ; preds = %4 + %10 = sext i32 %2 to i64 + %11 = getelementptr inbounds [9 x i32], [9 x i32]* @switch.table.br_jt8, i64 0, i64 %10 + %12 = load i32, i32* %11, align 4 + ret i32 %12 + +13: ; preds = %1, %4 + ret i32 %0 +} + +; Function Attrs: norecurse nounwind readnone +define signext i32 @br_jt3_m(i32 signext %0, i32 signext %1) { +; CHECK-LABEL: br_jt3_m: +; CHECK: # %bb.0: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: breq.w 1, %s0, .LBB{{[0-9]+}}_1 +; CHECK-NEXT: # %bb.2: +; CHECK-NEXT: breq.w 4, %s0, .LBB{{[0-9]+}}_5 +; CHECK-NEXT: # %bb.3: +; CHECK-NEXT: brne.w 2, %s0, .LBB{{[0-9]+}}_6 +; CHECK-NEXT: # %bb.4: +; CHECK-NEXT: or %s0, 0, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_1: ; CHECK-NEXT: or %s0, 3, (0)1 -; CHECK-NEXT: b.l.t (, %s1) +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; 
CHECK-NEXT: .LBB{{[0-9]+}}_5: +; CHECK-NEXT: adds.w.sx %s0, %s1, (0)1 +; CHECK-NEXT: adds.w.sx %s0, 3, %s0 +; CHECK-NEXT: .LBB{{[0-9]+}}_6: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; +; PIC-LABEL: br_jt3_m: +; PIC: # %bb.0: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: breq.w 1, %s0, .LBB4_1 +; PIC-NEXT: # %bb.2: +; PIC-NEXT: breq.w 4, %s0, .LBB4_5 +; PIC-NEXT: # %bb.3: +; PIC-NEXT: brne.w 2, %s0, .LBB4_6 +; PIC-NEXT: # %bb.4: +; PIC-NEXT: or %s0, 0, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB4_1: +; PIC-NEXT: or %s0, 3, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB4_5: +; PIC-NEXT: adds.w.sx %s0, %s1, (0)1 +; PIC-NEXT: adds.w.sx %s0, 3, %s0 +; PIC-NEXT: .LBB4_6: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) + switch i32 %0, label %6 [ + i32 1, label %7 + i32 2, label %3 + i32 4, label %4 + ] + +3: ; preds = %2 + br label %7 + +4: ; preds = %2 + %5 = add nsw i32 %1, 3 + br label %7 + +6: ; preds = %2 + br label %7 + +7: ; preds = %2, %6, %4, %3 + %8 = phi i32 [ %0, %6 ], [ %5, %4 ], [ 0, %3 ], [ 3, %2 ] + ret i32 %8 +} + +; Function Attrs: norecurse nounwind readnone +define signext i32 @br_jt4_m(i32 signext %0, i32 signext %1) { +; CHECK-LABEL: br_jt4_m: +; CHECK: # %bb.0: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: adds.w.sx %s2, -1, %s0 +; CHECK-NEXT: cmpu.w %s3, 3, %s2 +; CHECK-NEXT: brgt.w 0, %s3, .LBB{{[0-9]+}}_5 +; CHECK-NEXT: # %bb.1: +; CHECK-NEXT: adds.w.zx %s0, %s2, (0)1 +; CHECK-NEXT: sll %s0, %s0, 3 +; CHECK-NEXT: lea %s2, .LJTI5_0@lo +; CHECK-NEXT: and %s2, %s2, (32)0 +; CHECK-NEXT: lea.sl %s2, .LJTI5_0@hi(, %s2) +; CHECK-NEXT: ld %s2, (%s2, %s0) +; CHECK-NEXT: or %s0, 3, (0)1 +; CHECK-NEXT: b.l.t (, %s2) ; CHECK-NEXT: .LBB{{[0-9]+}}_2: ; CHECK-NEXT: or %s0, 0, (0)1 ; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 @@ -28,89 +375,344 @@ define signext i32 @br_jt(i32 signext %0) { ; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 ; CHECK-NEXT: b.l.t (, %s10) ; CHECK-NEXT: .LBB{{[0-9]+}}_4: -; CHECK-NEXT: or %s0, 7, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s1, (0)1 +; CHECK-NEXT: adds.w.sx %s0, 3, %s0 ; CHECK-NEXT: .LBB{{[0-9]+}}_5: ; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 ; CHECK-NEXT: b.l.t (, %s10) ; -; PIC-LABEL: br_jt: +; PIC-LABEL: br_jt4_m: ; PIC: # %bb.0: -; PIC-NEXT: st %s9, (, %s11) -; PIC-NEXT: st %s10, 8(, %s11) -; PIC-NEXT: st %s15, 24(, %s11) -; PIC-NEXT: st %s16, 32(, %s11) -; PIC-NEXT: or %s9, 0, %s11 -; PIC-NEXT: lea %s13, -176 -; PIC-NEXT: and %s13, %s13, (32)0 -; PIC-NEXT: lea.sl %s11, -1(%s13, %s11) -; PIC-NEXT: brge.l %s11, %s8, .LBB0_7 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: brlt.w 2, %s0, .LBB5_4 +; PIC-NEXT: # %bb.1: +; PIC-NEXT: breq.w 1, %s0, .LBB5_8 +; PIC-NEXT: # %bb.2: +; PIC-NEXT: brne.w 2, %s0, .LBB5_7 +; PIC-NEXT: # %bb.3: +; PIC-NEXT: or %s0, 0, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB5_4: +; PIC-NEXT: breq.w 3, %s0, .LBB5_9 +; PIC-NEXT: # %bb.5: +; PIC-NEXT: brne.w 4, %s0, .LBB5_7 ; PIC-NEXT: # %bb.6: -; PIC-NEXT: ld %s61, 24(, %s14) -; PIC-NEXT: or %s62, 0, %s0 -; PIC-NEXT: lea %s63, 315 -; PIC-NEXT: shm.l %s63, (%s61) -; PIC-NEXT: shm.l %s8, 8(%s61) -; PIC-NEXT: shm.l %s11, 16(%s61) -; PIC-NEXT: monc -; PIC-NEXT: or %s0, 0, %s62 -; PIC-NEXT: .LBB0_7: +; PIC-NEXT: adds.w.sx %s0, %s1, (0)1 +; PIC-NEXT: adds.w.sx %s0, 3, %s0 +; PIC-NEXT: .LBB5_7: ; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 -; PIC-NEXT: adds.w.sx %s1, -1, %s0 -; PIC-NEXT: cmpu.w %s2, 3, %s1 +; PIC-NEXT: 
b.l.t (, %s10) +; PIC-NEXT: .LBB5_8: +; PIC-NEXT: or %s0, 3, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB5_9: +; PIC-NEXT: or %s0, 4, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) + switch i32 %0, label %7 [ + i32 1, label %8 + i32 2, label %3 + i32 3, label %4 + i32 4, label %5 + ] + +3: ; preds = %2 + br label %8 + +4: ; preds = %2 + br label %8 + +5: ; preds = %2 + %6 = add nsw i32 %1, 3 + br label %8 + +7: ; preds = %2 + br label %8 + +8: ; preds = %2, %7, %5, %4, %3 + %9 = phi i32 [ %0, %7 ], [ %6, %5 ], [ 4, %4 ], [ 0, %3 ], [ 3, %2 ] + ret i32 %9 +} + +; Function Attrs: norecurse nounwind readnone +define signext i32 @br_jt7_m(i32 signext %0, i32 signext %1) { +; CHECK-LABEL: br_jt7_m: +; CHECK: # %bb.0: +; CHECK-NEXT: adds.w.sx %s2, %s0, (0)1 +; CHECK-NEXT: adds.w.sx %s0, -1, %s2 +; CHECK-NEXT: cmpu.w %s3, 8, %s0 +; CHECK-NEXT: brgt.w 0, %s3, .LBB{{[0-9]+}}_8 +; CHECK-NEXT: # %bb.1: +; CHECK-NEXT: adds.w.zx %s0, %s0, (0)1 +; CHECK-NEXT: sll %s0, %s0, 3 +; CHECK-NEXT: lea %s3, .LJTI6_0@lo +; CHECK-NEXT: and %s3, %s3, (32)0 +; CHECK-NEXT: lea.sl %s3, .LJTI6_0@hi(, %s3) +; CHECK-NEXT: ld %s3, (%s3, %s0) +; CHECK-NEXT: adds.w.sx %s1, %s1, (0)1 +; CHECK-NEXT: or %s0, 3, (0)1 +; CHECK-NEXT: b.l.t (, %s3) +; CHECK-NEXT: .LBB{{[0-9]+}}_2: +; CHECK-NEXT: or %s0, 0, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_3: +; CHECK-NEXT: or %s0, 4, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_4: +; CHECK-NEXT: adds.w.sx %s0, 3, %s1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_8: +; CHECK-NEXT: or %s0, 0, %s2 +; CHECK-NEXT: .LBB{{[0-9]+}}_9: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_7: +; CHECK-NEXT: or %s0, 11, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_6: +; CHECK-NEXT: or %s0, 10, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_5: +; CHECK-NEXT: adds.w.sx %s0, -2, %s1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; +; PIC-LABEL: br_jt7_m: +; PIC: # %bb.0: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: brge.w 3, %s0, .LBB6_1 +; PIC-NEXT: # %bb.6: +; PIC-NEXT: brlt.w 7, %s0, .LBB6_10 +; PIC-NEXT: # %bb.7: +; PIC-NEXT: adds.w.sx %s1, %s1, (0)1 +; PIC-NEXT: breq.w 4, %s0, .LBB6_14 +; PIC-NEXT: # %bb.8: +; PIC-NEXT: brne.w 7, %s0, .LBB6_16 +; PIC-NEXT: # %bb.9: +; PIC-NEXT: adds.w.sx %s0, -2, %s1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB6_1: +; PIC-NEXT: breq.w 1, %s0, .LBB6_2 +; PIC-NEXT: # %bb.3: +; PIC-NEXT: breq.w 2, %s0, .LBB6_13 +; PIC-NEXT: # %bb.4: +; PIC-NEXT: brne.w 3, %s0, .LBB6_16 +; PIC-NEXT: # %bb.5: +; PIC-NEXT: or %s0, 4, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB6_10: +; PIC-NEXT: breq.w 8, %s0, .LBB6_15 +; PIC-NEXT: # %bb.11: +; PIC-NEXT: brne.w 9, %s0, .LBB6_16 +; PIC-NEXT: # %bb.12: +; PIC-NEXT: or %s0, 10, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB6_14: +; PIC-NEXT: adds.w.sx %s0, 3, %s1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB6_2: +; PIC-NEXT: or %s0, 3, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB6_15: +; PIC-NEXT: or %s0, 
11, (0)1 +; PIC-NEXT: .LBB6_16: +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) +; PIC-NEXT: .LBB6_13: +; PIC-NEXT: or %s0, 0, (0)1 +; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 +; PIC-NEXT: b.l.t (, %s10) + switch i32 %0, label %11 [ + i32 1, label %12 + i32 2, label %3 + i32 3, label %4 + i32 4, label %5 + i32 7, label %7 + i32 9, label %9 + i32 8, label %10 + ] + +3: ; preds = %2 + br label %12 + +4: ; preds = %2 + br label %12 + +5: ; preds = %2 + %6 = add nsw i32 %1, 3 + br label %12 + +7: ; preds = %2 + %8 = add nsw i32 %1, -2 + br label %12 + +9: ; preds = %2 + br label %12 + +10: ; preds = %2 + br label %12 + +11: ; preds = %2 + br label %12 + +12: ; preds = %2, %11, %10, %9, %7, %5, %4, %3 + %13 = phi i32 [ %0, %11 ], [ 11, %10 ], [ 10, %9 ], [ %8, %7 ], [ %6, %5 ], [ 4, %4 ], [ 0, %3 ], [ 3, %2 ] + ret i32 %13 +} + +; Function Attrs: norecurse nounwind readnone +define signext i32 @br_jt8_m(i32 signext %0, i32 signext %1) { +; CHECK-LABEL: br_jt8_m: +; CHECK: # %bb.0: +; CHECK-NEXT: adds.w.sx %s2, %s0, (0)1 +; CHECK-NEXT: adds.w.sx %s0, -1, %s2 +; CHECK-NEXT: cmpu.w %s3, 8, %s0 +; CHECK-NEXT: brgt.w 0, %s3, .LBB{{[0-9]+}}_9 +; CHECK-NEXT: # %bb.1: +; CHECK-NEXT: adds.w.zx %s0, %s0, (0)1 +; CHECK-NEXT: sll %s0, %s0, 3 +; CHECK-NEXT: lea %s3, .LJTI7_0@lo +; CHECK-NEXT: and %s3, %s3, (32)0 +; CHECK-NEXT: lea.sl %s3, .LJTI7_0@hi(, %s3) +; CHECK-NEXT: ld %s3, (%s3, %s0) +; CHECK-NEXT: adds.w.sx %s1, %s1, (0)1 +; CHECK-NEXT: or %s0, 3, (0)1 +; CHECK-NEXT: b.l.t (, %s3) +; CHECK-NEXT: .LBB{{[0-9]+}}_2: +; CHECK-NEXT: or %s0, 0, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_3: +; CHECK-NEXT: or %s0, 4, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_4: +; CHECK-NEXT: adds.w.sx %s0, 3, %s1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_9: +; CHECK-NEXT: or %s0, 0, %s2 +; CHECK-NEXT: .LBB{{[0-9]+}}_10: +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_5: +; CHECK-NEXT: adds.w.sx %s0, -5, %s1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_6: +; CHECK-NEXT: adds.w.sx %s0, -2, %s1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_8: +; CHECK-NEXT: or %s0, 11, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; CHECK-NEXT: .LBB{{[0-9]+}}_7: +; CHECK-NEXT: or %s0, 10, (0)1 +; CHECK-NEXT: adds.w.sx %s0, %s0, (0)1 +; CHECK-NEXT: b.l.t (, %s10) +; +; PIC-LABEL: br_jt8_m: +; PIC: .LBB{{[0-9]+}}_12: +; PIC-NEXT: adds.w.sx %s2, %s0, (0)1 +; PIC-NEXT: adds.w.sx %s0, -1, %s2 +; PIC-NEXT: cmpu.w %s3, 8, %s0 ; PIC-NEXT: lea %s15, _GLOBAL_OFFSET_TABLE_@pc_lo(-24) ; PIC-NEXT: and %s15, %s15, (32)0 ; PIC-NEXT: sic %s16 ; PIC-NEXT: lea.sl %s15, _GLOBAL_OFFSET_TABLE_@pc_hi(%s16, %s15) -; PIC-NEXT: brgt.w 0, %s2, .LBB0_5 +; PIC-NEXT: brgt.w 0, %s3, .LBB7_9 ; PIC-NEXT: # %bb.1: -; PIC-NEXT: adds.w.zx %s0, %s1, (0)1 +; PIC-NEXT: adds.w.sx %s1, %s1, (0)1 +; PIC-NEXT: adds.w.zx %s0, %s0, (0)1 ; PIC-NEXT: sll %s0, %s0, 2 -; PIC-NEXT: lea %s1, .LJTI0_0@gotoff_lo -; PIC-NEXT: and %s1, %s1, (32)0 -; PIC-NEXT: lea.sl %s1, .LJTI0_0@gotoff_hi(%s1, %s15) -; PIC-NEXT: ldl.sx %s0, (%s1, %s0) -; PIC-NEXT: lea %s1, br_jt@gotoff_lo -; PIC-NEXT: and %s1, %s1, (32)0 -; PIC-NEXT: lea.sl %s1, br_jt@gotoff_hi(%s1, %s15) -; PIC-NEXT: adds.l %s1, %s0, %s1 +; PIC-NEXT: lea %s3, 
.LJTI7_0@gotoff_lo +; PIC-NEXT: and %s3, %s3, (32)0 +; PIC-NEXT: lea.sl %s3, .LJTI7_0@gotoff_hi(%s3, %s15) +; PIC-NEXT: ldl.sx %s0, (%s3, %s0) +; PIC-NEXT: lea %s3, br_jt8_m@gotoff_lo +; PIC-NEXT: and %s3, %s3, (32)0 +; PIC-NEXT: lea.sl %s3, br_jt8_m@gotoff_hi(%s3, %s15) +; PIC-NEXT: adds.l %s3, %s0, %s3 ; PIC-NEXT: or %s0, 3, (0)1 -; PIC-NEXT: b.l.t (, %s1) -; PIC-NEXT: .LBB0_2: +; PIC-NEXT: b.l.t (, %s3) +; PIC-NEXT: .LBB7_2: ; PIC-NEXT: or %s0, 0, (0)1 -; PIC-NEXT: br.l.t .LBB0_5 -; PIC-NEXT: .LBB0_3: +; PIC-NEXT: br.l.t .LBB7_10 +; PIC-NEXT: .LBB7_3: ; PIC-NEXT: or %s0, 4, (0)1 -; PIC-NEXT: br.l.t .LBB0_5 -; PIC-NEXT: .LBB0_4: -; PIC-NEXT: or %s0, 7, (0)1 -; PIC-NEXT: .LBB0_5: +; PIC-NEXT: br.l.t .LBB7_10 +; PIC-NEXT: .LBB7_4: +; PIC-NEXT: adds.w.sx %s0, 3, %s1 +; PIC-NEXT: br.l.t .LBB7_10 +; PIC-NEXT: .LBB7_9: +; PIC-NEXT: or %s0, 0, %s2 +; PIC-NEXT: br.l.t .LBB7_10 +; PIC-NEXT: .LBB7_5: +; PIC-NEXT: adds.w.sx %s0, -5, %s1 +; PIC-NEXT: br.l.t .LBB7_10 +; PIC-NEXT: .LBB7_6: +; PIC-NEXT: adds.w.sx %s0, -2, %s1 +; PIC-NEXT: br.l.t .LBB7_10 +; PIC-NEXT: .LBB7_8: +; PIC-NEXT: or %s0, 11, (0)1 +; PIC-NEXT: br.l.t .LBB7_10 +; PIC-NEXT: .LBB7_7: +; PIC-NEXT: or %s0, 10, (0)1 +; PIC-NEXT: .LBB7_10: ; PIC-NEXT: adds.w.sx %s0, %s0, (0)1 ; PIC-NEXT: or %s11, 0, %s9 -; PIC-NEXT: ld %s16, 32(, %s11) -; PIC-NEXT: ld %s15, 24(, %s11) -; PIC-NEXT: ld %s10, 8(, %s11) -; PIC-NEXT: ld %s9, (, %s11) -; PIC-NEXT: b.l.t (, %s10) - switch i32 %0, label %5 [ - i32 1, label %6 - i32 2, label %2 - i32 3, label %3 - i32 4, label %4 + switch i32 %0, label %13 [ + i32 1, label %14 + i32 2, label %3 + i32 3, label %4 + i32 4, label %5 + i32 6, label %7 + i32 7, label %9 + i32 9, label %11 + i32 8, label %12 ] -2: ; preds = %1 - br label %6 +3: ; preds = %2 + br label %14 -3: ; preds = %1 - br label %6 +4: ; preds = %2 + br label %14 -4: ; preds = %1 - br label %6 +5: ; preds = %2 + %6 = add nsw i32 %1, 3 + br label %14 -5: ; preds = %1 - br label %6 +7: ; preds = %2 + %8 = add nsw i32 %1, -5 + br label %14 -6: ; preds = %1, %5, %4, %3, %2 - %7 = phi i32 [ %0, %5 ], [ 7, %4 ], [ 4, %3 ], [ 0, %2 ], [ 3, %1 ] - ret i32 %7 +9: ; preds = %2 + %10 = add nsw i32 %1, -2 + br label %14 + +11: ; preds = %2 + br label %14 + +12: ; preds = %2 + br label %14 + +13: ; preds = %2 + br label %14 + +14: ; preds = %2, %13, %12, %11, %9, %7, %5, %4, %3 + %15 = phi i32 [ %0, %13 ], [ 11, %12 ], [ 10, %11 ], [ %10, %9 ], [ %8, %7 ], [ %6, %5 ], [ 4, %4 ], [ 0, %3 ], [ 3, %2 ] + ret i32 %15 } diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h index ad5a8aa0309862..d0664b08c0fbb6 100644 --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -53,6 +53,16 @@ class GroupType : public Type::TypeBase { using Base::Base; }; +// -------------------------------------------------------------------------- // +// Helper functions of Async dialect transformations. +// -------------------------------------------------------------------------- // + +/// Returns true if the type is reference counted. All async dialect types are +/// reference counted at runtime. 
+inline bool isRefCounted(Type type) {
+  return type.isa<TokenType, ValueType, GroupType>();
+}
+
 } // namespace async
 } // namespace mlir

diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
index e7a5e90298da9a..e33a9e286b7fa7 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
@@ -73,4 +73,8 @@ def Async_AnyValueType : DialectType;
 
+def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType,
+                                    Async_TokenType,
+                                    Async_GroupType]>;
+
 #endif // ASYNC_BASE_TD
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index cc987856a28e36..80aeabf5f9043b 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -227,4 +227,62 @@ def Async_AwaitAllOp : Async_Op<"await_all", []> {
   let assemblyFormat = "$operand attr-dict";
 }
 
+//===----------------------------------------------------------------------===//
+// Async Dialect Automatic Reference Counting Operations.
+//===----------------------------------------------------------------------===//
+
+// All async values (values, tokens, groups) are reference counted at runtime
+// and automatically destructed when the reference count drops to 0.
+//
+// All values are semantically created with a reference count of +1 and it is
+// the responsibility of the last async value user to drop the reference count.
+//
+// Async values are created when:
+//   1. An operation returns an async result (e.g. the result of an
+//      `async.execute`).
+//   2. An async value is passed in as a block argument.
+//
+// It is the responsibility of the async value user to extend the lifetime by
+// adding a +1 reference if the reference counted value is captured by an
+// asynchronously executed region (`async.execute` operation), and to drop it
+// after the last nested use.
+//
+// Reference counting operations can be added to the IR using the automatic
+// reference counting pass, which relies on liveness analysis to find the last
+// uses of all reference counted values and automatically inserts `drop_ref`
+// operations.
+//
+// See `AsyncRefCountingPass` documentation for the implementation details.
+
+def Async_AddRefOp : Async_Op<"add_ref"> {
+  let summary = "adds a reference to async value";
+  let description = [{
+    The `async.add_ref` operation adds references to an async value (token,
+    value or group).
+  }];
+
+  let arguments = (ins Async_AnyAsyncType:$operand,
+                       Confined<I32Attr, [IntPositive]>:$count);
+  let results = (outs );
+
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand)
+  }];
+}
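+// For illustration, a minimal sketch (mirroring the `AsyncRefCountingPass`
+// example later in this change) of how these two operations appear in the IR
+// for a token that is used both inside and after an `async.execute` region:
+//
+//   %token = ...
+//   async.add_ref %token {count = 1 : i32} : !async.token
+//   async.execute {
+//     async.await %token : !async.token
+//     async.drop_ref %token {count = 1 : i32} : !async.token
+//     async.yield
+//   }
+//   async.await %token : !async.token
+//   async.drop_ref %token {count = 1 : i32} : !async.token
+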
+ }]; + + let arguments = (ins Async_AnyAsyncType:$operand, + Confined:$count); + let results = (outs ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) + }]; +} + #endif // ASYNC_OPS diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h index d5a8a82dab49b9..9716bde765935a 100644 --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -19,6 +19,10 @@ namespace mlir { std::unique_ptr> createAsyncParallelForPass(); +std::unique_ptr> createAsyncRefCountingPass(); + +std::unique_ptr> createAsyncRefCountingOptimizationPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td index 51fd4e32c78e42..140a3b41162a1c 100644 --- a/mlir/include/mlir/Dialect/Async/Passes.td +++ b/mlir/include/mlir/Dialect/Async/Passes.td @@ -24,4 +24,18 @@ def AsyncParallelFor : FunctionPass<"async-parallel-for"> { let dependentDialects = ["async::AsyncDialect", "scf::SCFDialect"]; } +def AsyncRefCounting : FunctionPass<"async-ref-counting"> { + let summary = "Automatic reference counting for Async dialect data types"; + let constructor = "mlir::createAsyncRefCountingPass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + +def AsyncRefCountingOptimization : + FunctionPass<"async-ref-counting-optimization"> { + let summary = "Optimize automatic reference counting operations for the" + "Async dialect by removing redundant operations"; + let constructor = "mlir::createAsyncRefCountingOptimizationPass()"; + let dependentDialects = ["async::AsyncDialect"]; +} + #endif // MLIR_DIALECT_ASYNC_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 87ff2a97d93f84..8d531a1e343a08 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -37,6 +37,14 @@ struct TiledLinalgOp { SmallVector tensorResults; }; +struct TiledAndFusedLinalgOps { + LinalgOp op; + SmallVector fusedProducers; + SmallVector originalProducers; + SmallVector fusedLoops; + SmallVector unfusedLoops; +}; + /// Populates patterns for vectorization of all ConvN-D ops. void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns, @@ -65,11 +73,14 @@ void populateLinalgBufferizePatterns(MLIRContext *context, Optional tileLinalgOp(OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options); -/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This -/// proceeds as follows: -/// - Find outer parallel loops in these ops that can be fused. -/// - Tile fusable outer parallel loops of the last operation in the sequence. -/// - Fuse the remaining operations with the tiled operation +/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in +/// three steps +/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile +/// + fuse loops). +/// - Tile just these loops of the consumer (root operation) and fuse with +/// the producer. +/// - Tile again the tiled consumer operation produced above to do rest of +/// the tiling specified by the `tilingOptions`. 
 ///
 /// For example, consider the sequence of matmul below
 ///
@@ -96,39 +107,36 @@ Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
 ///            : memref<256x32xf32> to memref<16x32xf32, #map0>
 ///   %3 = subview %arg1[0, 0] [32, 32] [1, 1]
 ///            : memref<32x32xf32> to memref<32x32xf32, #map1>
-///   %4 = subview %arg3[0, 0] [32, 32] [1, 1]
-///            : memref<32x32xf32> to memref<32x32xf32, #map1>
 ///   linalg.matmul
 ///     ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
 ///     outs(%0 : memref<16x32xf32, #map0>)
-///   linalg.matmul
-///     ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
-///     outs(%1 : memref<16x8xf32, #map0>)
+///   scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
+///     scf.for %arg7 = %c0 to %c32 step %c4 {
+///       %4 = subview %0[0, %arg7] [16, 4] [1, 1]
+///            : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
+///       %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
+///            : memref<32x32xf32> to memref<4x8xf32, #map0>
+///       %6 = subview %1[0, %arg6] [16, 8] [1, 1]
+///            : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
+///       linalg.matmul
+///         ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
+///         outs(%6 : memref<16x8xf32, #map0>)
+///     }
+///     scf.yield
+///   }
+///   scf.yield
 /// }
 ///
-/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
-/// size of the former should be same as size of the latter. Based on how
-/// tile+fuse is implemented, the fused loops are generated based on the last
-/// operation in the sequence. For example, the tile sizes for the fused loops
-/// is obtained from `tilingOptions.back()`. The following tiling options are
-/// handled differently in tile+fuse (compared to tile only)
+/// The following tiling options are handled differently in tile+fuse (compared
+/// to tile only)
 /// - Interchange of the tiling loops is not supported right now.
-/// - Only the fused loops are distributed.
-struct TiledAndFusedLinalgOps {
-  /// Operation obtained by tiling the last operation in sequence of `ops`
-  /// passed to `tileAndFuseLinalgOps`.
-  LinalgOp op;
-  /// The dimension of the loops that are fused.
-  std::set<unsigned> fusedLoopDims;
-  /// The generated fused operations (created within the fused loops).
-  SmallVector<LinalgOp, 1> fusedProducers;
-  /// The fused loop generated.
-  SmallVector<Operation *, 4> fusedLoops;
-};
+/// - Distribution is only done for the tile+fuse loops. The tiled loops
+///   generated by the second tiling are not distributed.
 Optional<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
+tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
                      const LinalgDependenceGraph &dependenceGraph,
-                     const LinalgTilingOptions &tilingOptions);
+                     const LinalgTilingOptions &tilingOptions,
+                     const LinalgFusionOptions &fusionOptions);
 
 /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
 /// This is an in-place transformation controlled by `interchangeVector`.
@@ -234,20 +242,6 @@ struct LinalgPromotionOptions {
   }
 };
 
-/// Creates a new buffer using the `allocationFn` provided. The size of this
-/// buffer is the smallest constant bounding size along each dimension that can
-/// be computed for the size of the result of `subView`. Returns the allocated
-/// buffer as `fullLocalView` and the view that matches the size of the result
-/// of subview operation as `partialLocalView`.
-struct PromotionInfo { - Value fullLocalView; - Value partialLocalView; -}; -Optional -promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView, - AllocBufferCallbackFn allocationFn, - OperationFolder *folder = nullptr); - /// Promotes the `subViews` into a new buffer allocated at the insertion point /// `b`. Promotion occurs in 3 steps: /// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary). diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 1eaf8b0e709c7d..f5669e383368c5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -17,7 +17,6 @@ #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" using mlir::edsc::intrinsics::AffineIndexedValue; @@ -83,13 +82,6 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer); -using FusableOpDependencesTy = llvm::MapVector< - Operation *, - SmallVector>; -FusableOpDependencesTy -findAllFusableDependences(ArrayRef ops, - const LinalgDependenceGraph &dependenceGraph); - /// Fuses producer into consumer if the producer is structurally feasible and /// the fusion would not violate dependencies. /// Implements the fusion part of the "tileAndFuse on buffers" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 8512c933e4245f..1ad3df63c1c968 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2234,6 +2234,7 @@ def LoadOp : Std_Op<"load", operand_range getIndices() { return {operand_begin() + 1, operand_end()}; } }]; + let hasCanonicalizer = 1; let hasFolder = 1; let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h index 12beffe9dd1cd0..26b0a236f0d3fa 100644 --- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h +++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h @@ -48,6 +48,18 @@ typedef struct AsyncGroup MLIR_AsyncGroup; using CoroHandle = void *; // coroutine handle using CoroResume = void (*)(void *); // coroutine resume function +// Async runtime uses reference counting to manage the lifetime of async values +// (values of async types like tokens, values and groups). +using RefCountedObjPtr = void *; + +// Adds references to reference counted runtime object. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void + mlirAsyncRuntimeAddRef(RefCountedObjPtr, int32_t); + +// Drops references from reference counted runtime object. +extern "C" MLIR_ASYNCRUNTIME_EXPORT void + mlirAsyncRuntimeDropRef(RefCountedObjPtr, int32_t); + // Create a new `async.token` in not-ready state. 
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncToken *mlirAsyncRuntimeCreateToken(); diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir index 0618771052833d..74c0556c4bd02a 100644 --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-ref-counting \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir index 79fa4c2e2c3c21..196ab89b59e050 100644 --- a/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir +++ b/mlir/integration_test/Dialect/Async/CPU/test-async-parallel-for-2d.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -async-parallel-for \ +// RUN: -async-ref-counting \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-std-to-llvm \ diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 0cbf3debd89429..b08f7e4c45b7c0 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -33,6 +33,8 @@ static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; // Async Runtime C API declaration. //===----------------------------------------------------------------------===// +static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; +static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; @@ -49,6 +51,12 @@ static constexpr const char *kAwaitAllAndExecute = namespace { // Async Runtime API function types. 
struct AsyncAPI {
+  static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
+    auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
+    auto count = IntegerType::get(32, ctx);
+    return FunctionType::get({ref, count}, {}, ctx);
+  }
+
   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
     return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
   }
@@ -113,6 +121,8 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   };
 
   MLIRContext *ctx = module.getContext();
+  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
+  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
@@ -121,7 +131,8 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
   addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
-  addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitAllAndExecute,
+              AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
 }
 
 //===----------------------------------------------------------------------===//
@@ -588,6 +599,55 @@ class CallOpOpConversion : public ConversionPattern {
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Async reference counting ops lowering (`async.add_ref` and `async.drop_ref`
+// to the corresponding API calls).
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename RefCountingOp>
+class RefCountingOpLowering : public ConversionPattern {
+public:
+  explicit RefCountingOpLowering(MLIRContext *ctx, StringRef apiFunctionName)
+      : ConversionPattern(RefCountingOp::getOperationName(), 1, ctx),
+        apiFunctionName(apiFunctionName) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    RefCountingOp refCountingOp = cast<RefCountingOp>(op);
+
+    auto count = rewriter.create<ConstantOp>(
+        op->getLoc(), rewriter.getI32Type(),
+        rewriter.getI32IntegerAttr(refCountingOp.count()));
+
+    rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), apiFunctionName,
+                                        ValueRange({operands[0], count}));
+
+    return success();
+  }
+
+private:
+  StringRef apiFunctionName;
+};
+
+// async.add_ref op lowering to mlirAsyncRuntimeAddRef function call.
+class AddRefOpLowering : public RefCountingOpLowering<AddRefOp> {
+public:
+  explicit AddRefOpLowering(MLIRContext *ctx)
+      : RefCountingOpLowering(ctx, kAddRef) {}
+};
+
+// async.drop_ref op lowering to mlirAsyncRuntimeDropRef function call.
+class DropRefOpLowering : public RefCountingOpLowering<DropRefOp> {
+public:
+  explicit DropRefOpLowering(MLIRContext *ctx)
+      : RefCountingOpLowering(ctx, kDropRef) {}
+};
+
+} // namespace
+
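+// A minimal sketch (not part of this patch) of what the runtime side of these
+// lowered calls could look like, assuming a hypothetical `RefCounted` base
+// class with an atomic counter; the real definitions live in the async
+// runtime library, not here:
+//
+//   extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,
+//                                          int32_t count) {
+//     static_cast<RefCounted *>(ptr)->addRef(count);
+//   }
+//
+//   extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,
+//                                           int32_t count) {
+//     // Destroys the object when the reference count drops to zero.
+//     static_cast<RefCounted *>(ptr)->dropRef(count);
+//   }
+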
//===----------------------------------------------------------------------===//
// async.create_group op lowering to mlirAsyncRuntimeCreateGroup function call.
//===----------------------------------------------------------------------===//
@@ -794,10 +854,12 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   populateFuncOpTypeConversionPattern(patterns, ctx, converter);
   patterns.insert(ctx);
+  patterns.insert<AddRefOpLowering, DropRefOpLowering>(ctx);
   patterns.insert(ctx);
   patterns.insert(ctx, outlinedFunctions);
 
   ConversionTarget target(*ctx);
+  target.addLegalOp<ConstantOp>();
   target.addLegalDialect();
   target.addIllegalDialect();
   target.addDynamicallyLegalOp(
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
new file mode 100644
index 00000000000000..ea1da590aeea5e
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
@@ -0,0 +1,324 @@
+//===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements automatic reference counting for Async dialect data
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
+public:
+  AsyncRefCountingPass() = default;
+  void runOnFunction() override;
+
+private:
+  /// Adds automatic reference counting to the `value`.
+  ///
+  /// All values are semantically created with a reference count of +1 and it
+  /// is the responsibility of the last async value user to drop the reference
+  /// count.
+  ///
+  /// Async values are created when:
+  ///   1. An operation returns an async result (e.g. the result of an
+  ///      `async.execute`).
+  ///   2. An async value is passed in as a block argument.
+  ///
+  /// To implement automatic reference counting, we must insert a +1 reference
+  /// before each `async.execute` operation using the value, and drop it after
+  /// the last use inside the async body region (we currently drop the
+  /// reference before the `async.yield` terminator).
+  ///
+  /// Automatic reference counting algorithm outline:
+  ///
+  /// 1. `ReturnLike` operations forward the reference counted values without
+  ///    modifying the reference count.
+  ///
+  /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
+  ///    reference counted values ends, and insert `drop_ref` operations after
+  ///    the last use of the value.
+  ///
+  /// 3. Insert `add_ref` before the `async.execute` operation capturing the
+  ///    value, and a pairing `drop_ref` before the async body region
+  ///    terminator, to release the captured reference counted value when
+  ///    execution completes.
+  ///
+  /// 4. If the reference counted value is passed only to some of the block
+  ///    successors, insert `drop_ref` operations in the beginning of the
+  ///    blocks that do not have reference counted value uses.
+  ///
+  ///
+  /// Example:
+  ///
+  ///   %token = ...
+  ///   async.execute {
+  ///     async.await %token : !async.token   // await #1
+  ///     async.yield
+  ///   }
+  ///   async.await %token : !async.token     // await #2
+  ///
+  /// Based on the liveness analysis await #2 is the last use of the %token,
+  /// however the execution of the async region can be delayed, and to
+  /// guarantee that the %token is still alive when await #1 executes we need
+  /// to explicitly extend its lifetime using the `add_ref` operation.
+  ///
+  /// After automatic reference counting:
+  ///
+  ///   %token = ...
+  ///
+  ///   // Make sure that %token is alive inside async.execute.
+  ///   async.add_ref %token {count = 1 : i32} : !async.token
+  ///
+  ///   async.execute {
+  ///     async.await %token : !async.token   // await #1
+  ///
+  ///     // Drop the extra reference added to keep %token alive.
+  ///     async.drop_ref %token {count = 1 : i32} : !async.token
+  ///
+  ///     async.yield
+  ///   }
+  ///   async.await %token : !async.token     // await #2
+  ///
+  ///   // Drop the reference after the last use of %token.
+  ///   async.drop_ref %token {count = 1 : i32} : !async.token
+  ///
+  LogicalResult addAutomaticRefCounting(Value value);
+};
+
+} // namespace
+
+LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
+  MLIRContext *ctx = value.getContext();
+  OpBuilder builder(ctx);
+
+  // Set the insertion point after the operation producing the value, or at
+  // the beginning of the block if the value is defined by a block argument.
+  if (Operation *op = value.getDefiningOp())
+    builder.setInsertionPointAfter(op);
+  else
+    builder.setInsertionPointToStart(value.getParentBlock());
+
+  Location loc = value.getLoc();
+  auto i32 = IntegerType::get(32, ctx);
+
+  // Drop the reference count immediately if the value has no uses.
+  if (value.getUses().empty()) {
+    builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+    return success();
+  }
+
+  // Use liveness analysis to find the placement of `drop_ref` operations.
+  auto liveness = getAnalysis<Liveness>();
+
+  // We analyze only the blocks of the region that defines the `value`, and do
+  // not check nested blocks attached to operations.
+  //
+  // By analyzing only the `definingRegion` CFG we potentially lose an
+  // opportunity to drop the reference count earlier, and can extend the
+  // lifetime of the reference counted value longer than is really required.
+  //
+  // We also assume that all nested regions finish their execution before the
+  // completion of the owner operation. The only exception to this rule is the
+  // `async.execute` operation, which is handled explicitly below.
+  Region *definingRegion = value.getParentRegion();
+
+  // ------------------------------------------------------------------------ //
+  // Find blocks where the `value` dies: the value is in the `liveIn` set and
+  // not in the `liveOut` set. We place `drop_ref` immediately after the last
+  // use of the `value` in such blocks.
+  // ------------------------------------------------------------------------ //
+
+  // Last users of the `value` inside all blocks where the value dies.
+  llvm::SmallSet<Operation *, 4> lastUsers;
+
+  for (Block &block : definingRegion->getBlocks()) {
+    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+    // Value is in the live input set or was defined in the block.
+    bool liveIn = blockLiveness->isLiveIn(value) ||
+                  blockLiveness->getBlock() == value.getParentBlock();
+    if (!liveIn)
+      continue;
+
+    // Value is in the live out set.
+    bool liveOut = blockLiveness->isLiveOut(value);
+    if (liveOut)
+      continue;
+
+    // We proved that `value` dies in the `block`. Now find the last use of
+    // the `value` inside the `block`.
+
+    // Find any user of the `value` inside the block (including uses in nested
+    // regions attached to the operations in the block).
+    Operation *userInTheBlock = nullptr;
+    for (Operation *user : value.getUsers()) {
+      userInTheBlock = block.findAncestorOpInBlock(*user);
+      if (userInTheBlock)
+        break;
+    }
+
+    // Values with zero users were handled explicitly at the beginning; since
+    // the value is live into this block but not live out, it must have at
+    // least one use in the block.
+    assert(userInTheBlock && "value must have a user in the block");
+
+    // Find the last user of the `value` in the block.
+    Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
+    assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
+    lastUsers.insert(lastUser);
+  }
+
+  // Process all the last users of the `value` inside each block where the
+  // value dies.
+  for (Operation *lastUser : lastUsers) {
+    // ReturnLike operations forward the reference count.
+    if (lastUser->hasTrait<OpTrait::ReturnLike>())
+      continue;
+
+    // We can't currently handle other types of terminators.
+    if (lastUser->hasTrait<OpTrait::IsTerminator>())
+      return lastUser->emitError() << "async reference counting can't handle "
+                                      "terminators that are not ReturnLike";
+
+    // Add a drop_ref immediately after the last user.
+    builder.setInsertionPointAfter(lastUser);
+    builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+  }
+
+  // ------------------------------------------------------------------------ //
+  // Find blocks where the `value` is in the `liveOut` set, however it is not
+  // in the `liveIn` set of all successors. If the `value` is not in a
+  // successor's `liveIn` set, we add a `drop_ref` to the beginning of it.
+  // ------------------------------------------------------------------------ //
+
+  // Successors that will need a `drop_ref` for the `value`.
+  llvm::SmallSet<Block *, 4> dropRefSuccessors;
+
+  for (Block &block : definingRegion->getBlocks()) {
+    const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
+
+    // Skip the block if the value is not in the `liveOut` set.
+    if (!blockLiveness->isLiveOut(value))
+      continue;
+
+    // Find successors that do not have `value` in the `liveIn` set.
+    for (Block *successor : block.getSuccessors()) {
+      const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
+
+      if (!succLiveness->isLiveIn(value))
+        dropRefSuccessors.insert(successor);
+    }
+  }
+
+  // Drop the reference in all successor blocks that do not have the `value`
+  // in their `liveIn` set.
+  for (Block *dropRefSuccessor : dropRefSuccessors) {
+    builder.setInsertionPointToStart(dropRefSuccessor);
+    builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+  }
+
+  // ------------------------------------------------------------------------ //
+  // Find all `async.execute` operations that take the `value` as an operand
+  // (dependency token or async value), or capture it implicitly in the nested
+  // region. Each `async.execute` operation will require an `add_ref`
+  // operation to keep all captured values alive until it finishes its
+  // execution.
+  // ------------------------------------------------------------------------ //
+
+  llvm::SmallSet<ExecuteOp, 4> executeOperations;
+
+  auto trackAsyncExecute = [&](Operation *op) {
+    if (auto execute = dyn_cast<ExecuteOp>(op))
+      executeOperations.insert(execute);
+  };
+
+  for (Operation *user : value.getUsers()) {
+    // Follow parent operations up until the operation in the `definingRegion`.
+    while (user->getParentRegion() != definingRegion) {
+      trackAsyncExecute(user);
+      user = user->getParentOp();
+      assert(user != nullptr && "value user lies outside of the value region");
+    }
+
+    // Don't forget to process the parent in the `definingRegion` (can be the
+    // original user operation itself).
+    trackAsyncExecute(user);
+  }
+
+  // Process all `async.execute` operations capturing `value`.
+  for (ExecuteOp execute : executeOperations) {
+    // Add a reference before the execute operation to keep the reference
+    // counted value alive until the async region completes execution.
+    builder.setInsertionPoint(execute.getOperation());
+    builder.create<AddRefOp>(loc, value, IntegerAttr::get(i32, 1));
+
+    // Drop the reference inside the async region before completion.
+    OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
+    executeBuilder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
+  }
+
+  return success();
+}
+
+void AsyncRefCountingPass::runOnFunction() {
+  FuncOp func = getFunction();
+
+  // Check that we do not have explicit `add_ref` or `drop_ref` in the IR
+  // because otherwise automatic reference counting will produce incorrect
+  // results.
+  WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
+    if (isa<AddRefOp, DropRefOp>(op))
+      return op->emitError() << "explicit reference counting is not supported";
+    return WalkResult::advance();
+  });
+
+  if (refCountingWalk.wasInterrupted())
+    signalPassFailure();
+
+  // Add reference counting to block arguments.
+  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
+    for (BlockArgument arg : block->getArguments())
+      if (isRefCounted(arg.getType()))
+        if (failed(addAutomaticRefCounting(arg)))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (blockWalk.wasInterrupted())
+    signalPassFailure();
+
+  // Add reference counting to operation results.
+  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+    for (unsigned i = 0; i < op->getNumResults(); ++i)
+      if (isRefCounted(op->getResultTypes()[i]))
+        if (failed(addAutomaticRefCounting(op->getResult(i))))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (opWalk.wasInterrupted())
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
+  return std::make_unique<AsyncRefCountingPass>();
+}
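For reference, the `-async-ref-counting` pass added here is exercised
end-to-end by the integration tests updated earlier in this change, running
between `-async-parallel-for` and the lowering passes (`input.mlir` stands in
for the test file, which the RUN lines reference as `%s`):

  mlir-opt input.mlir -async-parallel-for \
                      -async-ref-counting \
                      -convert-async-to-llvm \
                      -convert-scf-to-std \
                      -convert-std-to-llvm
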
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
new file mode 100644
index 00000000000000..cbcb30c5276a66
--- /dev/null
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCountingOptimization.cpp
@@ -0,0 +1,218 @@
+//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Optimize Async dialect reference counting operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace mlir;
+using namespace mlir::async;
+
+#define DEBUG_TYPE "async-ref-counting"
+
+namespace {
+
+class AsyncRefCountingOptimizationPass
+    : public AsyncRefCountingOptimizationBase<
+          AsyncRefCountingOptimizationPass> {
+public:
+  AsyncRefCountingOptimizationPass() = default;
+  void runOnFunction() override;
+
+private:
+  LogicalResult optimizeReferenceCounting(Value value);
+};
+
+} // namespace
+
+LogicalResult
+AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
+  Region *definingRegion = value.getParentRegion();
+
+  // Find all users of the `value` inside each block, including operations
+  // that do not use `value` directly, but have a direct use inside nested
+  // region(s).
+  //
+  // Example:
+  //
+  //   ^bb1:
+  //     %token = ...
+  //     scf.if %cond {
+  //       ^bb2:
+  //         async.await %token : !async.token
+  //     }
+  //
+  // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
+  //
+  // In addition to the operation that uses the `value` we also keep track of
+  // whether this user is an `async.execute` operation itself, or has
+  // `async.execute` operations in the nested regions that do use the `value`.
+
+  struct UserInfo {
+    Operation *operation;
+    bool hasExecuteUser;
+  };
+
+  struct BlockUsersInfo {
+    llvm::SmallVector<AddRefOp, 4> addRefs;
+    llvm::SmallVector<DropRefOp, 4> dropRefs;
+    llvm::SmallVector<UserInfo, 4> users;
+  };
+
+  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
+
+  auto updateBlockUsersInfo = [&](UserInfo user) {
+    BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
+    info.users.push_back(user);
+
+    if (auto addRef = dyn_cast<AddRefOp>(user.operation))
+      info.addRefs.push_back(addRef);
+    if (auto dropRef = dyn_cast<DropRefOp>(user.operation))
+      info.dropRefs.push_back(dropRef);
+  };
+
+  for (Operation *user : value.getUsers()) {
+    bool isAsyncUser = isa<ExecuteOp>(user);
+
+    while (user->getParentRegion() != definingRegion) {
+      updateBlockUsersInfo({user, isAsyncUser});
+      user = user->getParentOp();
+      isAsyncUser |= isa<ExecuteOp>(user);
+      assert(user != nullptr && "value user lies outside of the value region");
+    }
+
+    updateBlockUsersInfo({user, isAsyncUser});
+  }
+
+  // Sort all operations found in the block.
+  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
+    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
+      return a->isBeforeInBlock(b);
+    };
+    llvm::sort(info.addRefs, isBeforeInBlock);
+    llvm::sort(info.dropRefs, isBeforeInBlock);
+    llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
+      return isBeforeInBlock(a.operation, b.operation);
+    });
+
+    return info;
+  };
+
+  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
+  // blocks that modify the reference count of the `value`.
+  for (auto &kv : blockUsers) {
+    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
+
+    // Find all cancellable pairs first and erase them later to keep all
+    // pointers in the `info` valid until the end.
+    //
+    // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
+    llvm::SmallDenseMap<Operation *, Operation *> cancellable;
+
+    for (AddRefOp addRef : info.addRefs) {
+      for (DropRefOp dropRef : info.dropRefs) {
+        // Look for a `drop_ref` operation after the `add_ref` with a matching
+        // count.
+        if (dropRef.count() != addRef.count() ||
+            dropRef.getOperation()->isBeforeInBlock(addRef.getOperation()))
+          continue;
+
+        // `drop_ref` was already marked for removal.
+        if (cancellable.find(dropRef.getOperation()) != cancellable.end())
+          continue;
+
+        // Check `value` users between `addRef` and `dropRef` in the `block`.
+        Operation *addRefOp = addRef.getOperation();
+        Operation *dropRefOp = dropRef.getOperation();
+
+        // If there is a "regular" user after the `async.execute` user, it is
+        // unsafe to erase the cancellable pair of reference counting
+        // operations, because the async region can complete before the
+        // "regular" user and destroy the reference counted value.
+        bool hasExecuteUser = false;
+        bool unsafeToCancel = false;
+
+        for (UserInfo &user : info.users) {
+          Operation *op = user.operation;
+
+          // `user` operation lies after `addRef` ...
+          if (op == addRefOp || op->isBeforeInBlock(addRefOp))
+            continue;
+          // ... and before `dropRef`.
+          if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
+            break;
+
+          bool isRegularUser = !user.hasExecuteUser;
+          bool isExecuteUser = user.hasExecuteUser;
+
+          // It is unsafe to cancel the `addRef` / `dropRef` pair.
+          if (isRegularUser && hasExecuteUser) {
+            unsafeToCancel = true;
+            break;
+          }
+
+          hasExecuteUser |= isExecuteUser;
+        }
+
+        // Mark the pair of reference counting operations for removal.
+        if (!unsafeToCancel)
+          cancellable[dropRef.getOperation()] = addRef.getOperation();
+
+        // If it is unsafe to cancel the `addRef <-> dropRef` pair at this
+        // point, all the following pairs will also be unsafe.
+        break;
+      }
+    }
+
+    // Erase all cancellable `addRef <-> dropRef` operation pairs.
+    for (auto &kv : cancellable) {
+      kv.first->erase();
+      kv.second->erase();
+    }
+  }
+
+  return success();
+}
+
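+// For illustration, a sketch of the pattern this pass cancels: the only user
+// between the pair below is the `async.execute` that triggered the +1
+// reference, so the original +1 can be consumed by the region's own
+// `drop_ref` and both outer operations erased:
+//
+//   async.add_ref %token {count = 1 : i32} : !async.token
+//   async.execute {
+//     ...
+//     async.drop_ref %token {count = 1 : i32} : !async.token
+//     async.yield
+//   }
+//   async.drop_ref %token {count = 1 : i32} : !async.token   // last use
+//
+// If a "regular" user (e.g. `async.await %token`) appeared between the
+// `async.execute` and the final `drop_ref`, the pair would be kept, because
+// the async region could complete and destroy %token before that user runs.
+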
+  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
+    for (unsigned i = 0; i < op->getNumResults(); ++i)
+      if (isRefCounted(op->getResultTypes()[i]))
+        if (failed(optimizeReferenceCounting(op->getResult(i))))
+          return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  if (opWalk.wasInterrupted())
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncRefCountingOptimizationPass() {
+  return std::make_unique<AsyncRefCountingOptimizationPass>();
+}
diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
index 9de43873039d00..dccae73d9bee59 100644
--- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRAsyncTransforms
   AsyncParallelFor.cpp
+  AsyncRefCounting.cpp
+  AsyncRefCountingOptimization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Async
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 45a68fcba4a21b..969bea4a4549f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -178,9 +178,6 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
     Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
     for (auto en2 : llvm::enumerate(map.getResults())) {
-      auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
-      if (!dimExpr)
-        continue;
       if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
@@ -193,18 +190,49 @@ static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
   llvm_unreachable("Expect to be able to extract a shape defining loop range");
 }
 
-/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
-/// provides the loop range information for the fused loops. The rest are
-/// obtained from the producer itself, since they are not tiled + fused.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
-                     const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
+/// Fuses the producer of `producerIdx` into the loop immediately enclosing
+/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
+/// is needed just before the `consumer`.
+///
+/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there
+/// are 2 cases:
+///   1. Buffer case: `producerIdx` is the index of the buffer in
+///      `producer.getOutputBuffers()`.
+///   2. Tensor case: `producerIdx` is the index of the tensor in
+///      `producer.getResults()`.
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
+                     LinalgOp consumer, unsigned consumerIdx) {
+  Operation *shapeProducingOp =
+      consumer.getShapedOperand(consumerIdx).getDefiningOp();
+  assert((isa<SubViewOp>(shapeProducingOp) ||
+          isa<SubTensorOp>(shapeProducingOp)) &&
+         "SubviewOp or SubTensorOp expected");
+
+  // loopToOperandRangesMaps are permutations-only by construction:
+  // we can always identify a data dimension with a (at least one) loop
+  // dimension.
+  // TODO: extend this with range inference.
+ AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); + LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx + << ", producer map: " << producerMap << "\n"); unsigned nPar = producer.getNumParallelLoops(); unsigned nRed = producer.getNumReductionLoops(); unsigned nWin = producer.getNumWindowLoops(); SmallVector loopRanges(nPar + nRed + nWin); - for (auto fusedLoops : fusedLoopsAndRanges) - loopRanges[fusedLoops.first] = fusedLoops.second; + + // Iterate over dimensions identified by the producer map for `producerIdx`. + // This defines a subset of the loop ranges that we need to complete later. + auto loc = consumer.getLoc(); + for (auto en : llvm::enumerate(producerMap.getResults())) { + unsigned posInProducerLoop = en.value().cast().getPosition(); + loopRanges[posInProducerLoop] = + isa(shapeProducingOp) + ? cast(shapeProducingOp) + .getOrCreateRanges(b, loc)[en.index()] + : cast(shapeProducingOp) + .getOrCreateRanges(b, loc)[en.index()]; + } // Iterate over all dimensions. For the dimensions not identified by the // producer map for `producerIdx`, we need to explicitly compute the shape @@ -222,45 +250,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, } } - return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges); -} - -/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is -/// expected to be defined by a subview op or a subtensor op. -static Range getRangeFromOperandShape(OpBuilder &b, Location loc, - Value shapedOperand, unsigned dim) { - Operation *shapeProducingOp = shapedOperand.getDefiningOp(); - if (auto subViewOp = dyn_cast(shapeProducingOp)) - return subViewOp.getOrCreateRanges(b, loc)[dim]; - if (auto subTensorOp = dyn_cast(shapeProducingOp)) - return subTensorOp.getOrCreateRanges(b, loc)[dim]; - llvm_unreachable("SubviewOp or SubTensorOp expected"); -} - -/// Fuses the producer of `producerIdx` into the loop immediately enclosing -/// `consumer`. This is achieved by "recomputing" the `producer` at the time it -/// is needed just before the `consumer. -/// -/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are -/// 2 cases: -/// 1. Buffer case: `producerIdx` is the index of the buffer in -/// `producer.getOutputBuffers()`. -/// 2. Tensor case: `producerIdx` is the index of the tensor in -/// `producer.getResults()`. -static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, - LinalgOp consumer, unsigned consumerIdx) { - AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); - LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx - << ", producer map: " << producerMap << "\n"); - DenseMap fusedLoopsAndRanges; - Location loc = consumer.getLoc(); - Value shapedOperand = consumer.getShapedOperand(consumerIdx); - for (auto en : llvm::enumerate(producerMap.getResults())) { - unsigned posInProducerLoop = en.value().cast().getPosition(); - fusedLoopsAndRanges[posInProducerLoop] = - getRangeFromOperandShape(b, loc, shapedOperand, en.index()); - } - return fuse(b, producer, fusedLoopsAndRanges); + return cloneWithLoopRanges(b, loc, producer, loopRanges); } // Encode structural fusion safety preconditions. @@ -531,68 +521,9 @@ static AffineMap pruneReductionDimsFromMap(ArrayRef iteratorTypes, return getProjectedMap(map, projectedDims); } -/// Returns the mapping from iterations in the consumer that write to the same -/// location as the iterations in the producer. 
To do so use -/// - indexing map of the fused view in the consumer : consumerIndexMap -/// - indexing map of the fused view in the producer : producerIndexMap -/// consumerLoopToProducerLoop = -/// inverse(producerIndexMap).compose(consumerIndexMap) -static Optional getConsumerLoopToProducerLoopMap( - LinalgDependenceGraph::LinalgDependenceGraphElem dependence) { - auto producer = cast(dependence.dependentOpView.op); - AffineMap producerIndexingMap = - producer.getIndexingMap(dependence.dependentOpView.operandIndex); - auto consumer = cast(dependence.indexingOpView.op); - AffineMap consumerIndexingMap = - consumer.getIndexingMap(dependence.indexingOpView.operandIndex); - - AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( - producer.iterator_types().getValue(), producerIndexingMap); - if (!prunedProducerIndexingMap.isPermutation()) - return None; - - if (consumerIndexingMap.getNumResults() != - prunedProducerIndexingMap.getNumResults()) - return None; - - LLVM_DEBUG({ - llvm::dbgs() << "\t producerMap : "; - producerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << " pruned : "; - prunedProducerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - llvm::dbgs() << "\t consumerMap : "; - consumerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - - AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap); - if (!invProducerIndexMap) - return None; - - return invProducerIndexMap.compose(consumerIndexingMap); -} - -/// Given a projected permutation `map`, returns true if the map changes the -/// order in which the fused loop dimension appear. -static bool doesTransposeAccess(AffineMap map, - const std::set &fusableLoops) { - Optional lastFusableLoop; - for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) { - return expr.cast().getPosition(); - })) { - if (!fusableLoops.count(pos)) - continue; - if (!lastFusableLoop) { - lastFusableLoop = pos; - continue; - } - if (pos <= lastFusableLoop.getValue()) - return true; - lastFusableLoop = pos; - } - return false; -} +using FusableOpDependencesTy = llvm::MapVector< + Operation *, + SmallVector>; /// Returns the positions of the loop in `op` that can be tiled based on the /// operations that are to be fused with it. For example, in a @@ -607,7 +538,13 @@ static bool doesTransposeAccess(AffineMap map, /// 2. Of the parallel loops only some can be fused. Only those loops can be /// fused such where the fusable loops iteration space only touches one tile /// of the fused operation. This is because the producer (which is writing -/// the fused subview) has update semantics. +/// the fused subview) has update semantics. To compute this, +/// a. Find the mapping from iterations in the consumer that write to the +/// same location as the iterations in the producer. To do so use +/// - indexing map of the fused view in the consumer : consumerIndexMap +/// - indexing map of the fused view in the producer : producerIndexMap +/// consumerLoopToProducerLoop = +/// inverse(producerIndexMap).compose(consumerIndexMap) /// /// Since an inverse computation is needed, we need to consider the projection /// of the producerIndexMap w.r.t the parallel loops. 
The actual fusable loops @@ -645,9 +582,8 @@ static bool doesTransposeAccess(AffineMap map, /// submap with only parallel loops = affine_map<(i, j) -> (j)> /// Fused dimensions : j static std::set -collectFusableLoops(ArrayRef ops, - const FusableOpDependencesTy &fusableDependences) { - assert(!ops.empty()); +collectTileAndFuseLoops(LinalgOp op, + const FusableOpDependencesTy &fusableDependences) { auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { return linalgOp.iterator_types() .getValue() @@ -658,245 +594,289 @@ collectFusableLoops(ArrayRef ops, .size(); }; - size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back()); - for (auto op : ops.drop_back()) { + LLVM_DEBUG({ + llvm::dbgs() << "Op : "; + op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); + llvm::dbgs() << "\n"; + }); + + size_t numOuterParallelLoops = getNumOuterParallelLoops(op); + for (auto dependence : fusableDependences) { + linalg::LinalgOp producer = cast(dependence.first); numOuterParallelLoops = - std::min(numOuterParallelLoops, getNumOuterParallelLoops(op)); + std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer)); } std::set fusableLoops; auto range = llvm::seq(0, numOuterParallelLoops); fusableLoops.insert(range.begin(), range.end()); - - for (auto op : reverse(ops)) { - for (auto dependence : fusableDependences.lookup(op)) { - LLVM_DEBUG({ - llvm::dbgs() << "\t fusable :"; - for (unsigned i : fusableLoops) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - - Optional consumerLoopToProducerLoop = - getConsumerLoopToProducerLoopMap(dependence); - if (!consumerLoopToProducerLoop) { - op.emitRemark("failed to get map from consumer loop to producer loop"); - return {}; - } - // todo: This condition is only an implementation limitation. When fusing - // the operation, if the accesses in the producer/consumer are transposes - // of each other, the loop bounds for the tiled producer can be - // manipulated accordingly. This requires some additional bookkeeping in - // the implementation of tile+fuse that is defered to later. 
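The `consumerLoopToProducerLoop = inverse(producerIndexMap).compose(consumerIndexMap)` relation that both the removed helper and the inlined code below compute can be sanity-checked with plain permutations. A minimal standalone sketch, assuming the maps are pure permutations; `Perm[i]` is the loop dimension produced by result `i`, whereas the real code works on `mlir::AffineMap` via `inversePermutation` and `AffineMap::compose`:

```cpp
#include <cassert>
#include <vector>

using Perm = std::vector<unsigned>; // Perm[i] = loop dim produced by result i

// inverse(map): inv[map[i]] = i, defined only for permutations.
Perm inverse(const Perm &map) {
  Perm inv(map.size());
  for (unsigned i = 0; i < map.size(); ++i)
    inv[map[i]] = i;
  return inv;
}

// a.compose(b) applies b first, then a: result i of the composition reads
// dimension b[a[i]].
Perm compose(const Perm &a, const Perm &b) {
  Perm r(a.size());
  for (unsigned i = 0; i < a.size(); ++i)
    r[i] = b[a[i]];
  return r;
}

int main() {
  // producerIndexMap = (i, j) -> (j, i), consumerIndexMap = (i, j) -> (i, j).
  Perm producer = {1, 0}, consumer = {0, 1};
  Perm consumerLoopToProducerLoop = compose(inverse(producer), consumer);
  // Consumer loop 0 writes the same locations as producer loop 1, and vice
  // versa, i.e. the access is transposed between the two operations.
  assert(consumerLoopToProducerLoop == Perm({1, 0}));
}
```

A transposed result like this is exactly the situation the transpose-access check below rejects as an unhandled fusion.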
- if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) { - op.emitRemark("unhandled fusion when fusion requires permutation"); - return {}; - } - - std::set candidates; - for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) { - unsigned position = expr.cast().getPosition(); - if (fusableLoops.count(position)) - candidates.insert(position); - } - LLVM_DEBUG({ - llvm::dbgs() << "\t candidates :"; - for (unsigned i : candidates) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - if (candidates.empty()) - return {}; - std::swap(candidates, fusableLoops); + for (auto dependence : fusableDependences) { + LLVM_DEBUG({ + llvm::dbgs() << "\t fusable :"; + for (unsigned i : fusableLoops) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + linalg::LinalgOp producer = cast(dependence.first); + + assert(!dependence.second.empty() && + "unexpected producer but not dependences"); + AffineMap producerIndexingMap = producer.getIndexingMap( + dependence.second.front().dependentOpView.operandIndex); + AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( + producer.iterator_types().getValue(), producerIndexingMap); + if (!prunedProducerIndexingMap.isPermutation()) + return {}; + + AffineMap consumerIndexingMap = op.getIndexingMap( + dependence.second.front().indexingOpView.operandIndex); + if (consumerIndexingMap.getNumResults() != + prunedProducerIndexingMap.getNumResults()) + return {}; + + LLVM_DEBUG({ + llvm::dbgs() << "\t producerMap : "; + producerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << " pruned : "; + prunedProducerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + llvm::dbgs() << "\t consumerMap : "; + consumerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + AffineMap invProducerIndexMap = + inversePermutation(prunedProducerIndexingMap); + if (!invProducerIndexMap) + return {}; + + AffineMap consumerLoopToProducerLoop = + invProducerIndexMap.compose(consumerIndexingMap); + + LLVM_DEBUG({ + llvm::dbgs() << "\t consumerLoopToProducerLoop : "; + consumerLoopToProducerLoop.print(llvm::dbgs()); + }); + + std::set candidates; + for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { + AffineDimExpr dimExpr = expr.dyn_cast(); + if (!dimExpr) + continue; + unsigned position = dimExpr.getPosition(); + if (fusableLoops.count(position)) + candidates.insert(position); } + LLVM_DEBUG({ + llvm::dbgs() << "\t candidates :"; + for (unsigned i : candidates) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + if (candidates.empty()) + return {}; + std::swap(candidates, fusableLoops); } return fusableLoops; } -/// Find all dependences that are fusable. -FusableOpDependencesTy mlir::linalg::findAllFusableDependences( - ArrayRef ops, const LinalgDependenceGraph &dependenceGraph) { +/// Find all dependences that are to be fusable. +static FusableOpDependencesTy +findAllFusableDependences(LinalgOp op, + const LinalgDependenceGraph &dependenceGraph, + const LinalgFusionOptions &fusionOptions) { FusableOpDependencesTy fusableDependences; // TODO: Currently fusion would not be legal if the fusable dependence is to // the same producer but different indexing map in the consumer. Fix this, but // in the meanwhile disallow such a fusion. 
DenseMap fusedProducerIndexingMap; - for (LinalgOp op : reverse(ops)) { - for (auto operandIndex : - llvm::seq(0, op.getNumInputsAndOutputBuffers())) { - Optional - fusableDependence = - findFusableProducer(op, operandIndex, dependenceGraph); - if (!fusableDependence) - continue; - LinalgOp producerOp = - cast(fusableDependence->dependentOpView.op); - // Do not fuse dependences that are to operations not in the same basic - // block. This avoid moving fused operations across loops that might - // themselves carry dependency making the fusion illegal. - if (producerOp.getOperation()->getBlock() != - op.getOperation()->getBlock()) { - op.emitRemark("unhandled fusion of ops in different basic blocks"); - return FusableOpDependencesTy{}; - } - // Make sure that the indexing map of the view used for fusion in the - // producer is a projected permutation. - unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; - AffineMap producerMap = producerOp.getIndexingMap(producerIdx); - if (!producerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled non permutation indexing map for fused view in " - "producer for operand at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } - - unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex; - AffineMap consumerMap = op.getIndexingMap(consumerIdx); - if (!consumerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled case where indexing map for fused view in the consumer " - "is " - "not a projected permuration while fusing at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } + for (auto operandIndex : fusionOptions.indicesToFuse) { + auto fusableDependence = + findFusableProducer(op, operandIndex, dependenceGraph); + if (!fusableDependence) + return FusableOpDependencesTy{}; + LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); + // Do not fuse dependences that are to operations not in the same basic + // block. This avoid moving fused operations across loops that might + // themselves carry dependency making the fusion illegal. + if (producerOp.getOperation()->getBlock() != + op.getOperation()->getBlock()) { + op.emitRemark("unhandled fusion of ops in different basic blocks"); + return FusableOpDependencesTy{}; + } + // Make sure that the indexing map of the view used for fusion in the + // producer is a projected permutation. + unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; + AffineMap producerMap = producerOp.getIndexingMap(producerIdx); + if (!producerMap.isProjectedPermutation()) { + op.emitRemark("unhandled non permutation indexing map for fused view in " + "producer for operand at index ") + << operandIndex; + return FusableOpDependencesTy{}; + } - // Check if the producer is already a fusion candidate. Cannot fuse this - // dependence if it has a different indexing map when used in the - // consumer. 
- if (fusedProducerIndexingMap.count(producerOp.getOperation()) && - fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { - op.emitRemark( - "unhandled fusion to the same producer but with different " - "indexing maps"); - return FusableOpDependencesTy{}; - } - fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; + unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex; + AffineMap consumerMap = op.getIndexingMap(consumerIdx); + if (!consumerMap.isProjectedPermutation()) { + op.emitRemark( + "unhandled case where indexing map for fused view in the consumer is " + "not a projected permutation while fusing at index ") + << operandIndex; + return FusableOpDependencesTy{}; + } - fusableDependences[producerOp.getOperation()].push_back( - *fusableDependence); + // Check if the producer is already a fusion candidate. Cannot fuse this + // dependence if it has a different indexing map when used in the consumer. + if (fusedProducerIndexingMap.count(producerOp.getOperation()) && + fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { + op.emitRemark("unhandled fusion to the same producer but with different " + "indexing maps"); + return FusableOpDependencesTy{}; } + fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; + + fusableDependences[producerOp.getOperation()].push_back(*fusableDependence); } return fusableDependences; } -/// Tile the fused loops in the root operation, by setting the tile sizes for -/// all other loops to zero (those will be tiled later). -static Optional tileRootOperation( - OpBuilder &builder, LinalgOp op, ArrayRef tileSizeVector, - const LinalgTilingOptions &options, const std::set &fusedLoops) { - SmallVector tileSizes(tileSizeVector.begin(), tileSizeVector.end()); - auto zero = std_constant_index(0); - for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) - if (!fusedLoops.count(i)) - tileSizes[i] = zero; - LinalgTilingOptions tileFusedLoopsOptions = options; - tileFusedLoopsOptions.setTileSizes(tileSizes); - return tileLinalgOp(builder, op, tileFusedLoopsOptions); -} - -/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected -/// to be a tiled operation such that it is valid to fuse all operations in -/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of -/// `tiledOp`. 
-static SmallVector -fuseOperations(OpBuilder &builder, LinalgOp tiledOp, - ArrayRef fusionCandidates, - const FusableOpDependencesTy &fusableDependences, - const std::set &fusedLoops) { - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(tiledOp); - DenseMap fusedLoopsAndRanges; - for (unsigned loop : fusedLoops) { - ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop); - fusedLoopsAndRanges[loop] = getRangeFromOperandShape( - builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); - } - - SmallVector fusedOps(fusionCandidates.size()); - for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { - LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges); - fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; - builder.setInsertionPoint(fusedOp); - } - return fusedOps; +static bool isZero(Value v) { + if (auto cst = v.getDefiningOp()) + return cst.getValue() == 0; + return false; } template static Optional -tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, +tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions) { - if (ops.empty()) - return llvm::None; - LinalgOp rootOp = ops.back(); - for (auto op : enumerate(ops)) { - // TODO: Nothing in the fusion of sequence of ops is specific to - // buffers. This check can be removed after it is tested on tensors. - LinalgOp linalgOp = op.value(); - if (!linalgOp.hasBufferSemantics()) { - linalgOp.emitError("tile and fuse only tested for buffer operation"); - return llvm::None; - } - } - // TODO: Support interchange with tile + fuse. This might actually help do - // better fusion. + const LinalgTilingOptions &tilingOptions, + const LinalgFusionOptions &fusionOptions) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + // Some of the tiling options might not be supportable with tile and fuse. + // TODO: Support interchange with tile + fuse. if (!tilingOptions.interchangeVector.empty()) { - rootOp.emitError("unable to handle tile and fuse with interchange"); + op.emitError("unable to handle tile and fuse with interchange"); return llvm::None; } - OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(rootOp); - ScopedContext scope(builder, rootOp.getLoc()); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + ScopedContext scope(rewriter, op.getLoc()); // Find all the producers. FusableOpDependencesTy fusableDependences = - findAllFusableDependences(ops, dependenceGraph); + findAllFusableDependences(op, dependenceGraph, fusionOptions); if (fusableDependences.empty()) return llvm::None; + // Enforce the convention that "tiling by zero" skips tiling a particular + // dimension. This convention is significantly simpler to handle instead of + // adjusting affine maps to account for missing dimensions. + auto nLoops = op.getNumLoops(); + SmallVector tileSizeVector = + tilingOptions.tileSizeComputationFunction(rewriter, op); + if (tileSizeVector.size() < nLoops) { + auto zero = std_constant_index(0); + tileSizeVector.append(nLoops - tileSizeVector.size(), zero); + } + TiledAndFusedLinalgOps ret; + // Find the loops that can be tiled and fused. - ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences); + std::set tileFuseLoops = + collectTileAndFuseLoops(op, fusableDependences); // If there are no fusable dependences or there are no tile+fusable loops, // just return. 
-  if (ret.fusedLoopDims.empty()) {
+  if (tileFuseLoops.empty()) {
     return llvm::None;
   }
 
-  // Tile the fused loops in the last operation in the list.
-  SmallVector<Value, 4> tileSizeVector =
-      tilingOptions.tileSizeComputationFunction(builder, rootOp);
-  Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
-      builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
-  if (!tiledRootOp) {
-    rootOp.emitError("failed to tile the fused loops");
+  // Get the tile sizes for the first and second tiling steps. For the first
+  // step the tile sizes are set to zero for the loops that aren't fused.
+  // Similarly for the second step, the tile sizes are set to zero for the
+  // loops that are fused. For example, for the following input
+  //
+  // ```
+  //   linalg.add ins(%a, %b) outs(%c)
+  //   linalg.matmul ins(%d, %c) outs(%e)
+  // ```
+  //
+  // if the tile sizes of the `{i, j, k}` loops were given as `{ti, tj, tk}`
+  // respectively, then, since only `j` can be tiled and fused, the tile sizes
+  // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
+  // loops. The second tiling would use tile sizes of `{t_i, 0, t_k}` to tile
+  // the tiled matmul generated by the first tiling step.
+  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
+  for (auto tileSize : enumerate(tileSizeVector)) {
+    auto zero = std_constant_index(0);
+    if (tileFuseLoops.count(tileSize.index())) {
+      tileAndFuseSizes.push_back(tileSize.value());
+      tileSizes.push_back(zero);
+    } else {
+      tileSizes.push_back(tileSize.value());
+      tileAndFuseSizes.push_back(zero);
+    }
+  }
+
+  // Tile for the loops that can be fused.
+  LinalgTilingOptions firstTilingOptions = tilingOptions;
+  firstTilingOptions.setTileSizes(tileAndFuseSizes);
+  Optional<TiledLinalgOp> firstTiledOp =
+      tileLinalgOp(rewriter, op, firstTilingOptions);
+  if (!firstTiledOp)
     return llvm::None;
+  ret.op = firstTiledOp->op;
+  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
+
+  rewriter.setInsertionPoint(ret.op);
+  // Fuse the operands.
+  for (auto dependence : fusableDependences) {
+    LinalgOp producerOp = cast<LinalgOp>(dependence.first);
+    unsigned producerIdx =
+        dependence.second.front().dependentOpView.operandIndex;
+    unsigned consumerIdx =
+        dependence.second.front().indexingOpView.operandIndex;
+    LinalgOp fusedOp = fuse(rewriter, producerOp,
+                            producerOp.getOutputIndex(producerIdx).getValue(),
+                            ret.op, consumerIdx);
+    ret.fusedProducers.push_back(fusedOp);
+    ret.originalProducers.push_back(producerOp);
+  }
+
+  if (!llvm::all_of(tileSizes, isZero)) {
+    // Tile the remaining loops of the root operation.
+    LinalgTilingOptions secondTilingOptions = tilingOptions;
+    // The distribution is done only for the tile+fused loops.
+    secondTilingOptions.distribution = llvm::None;
+    secondTilingOptions.setTileSizes(tileSizes);
+    Optional<TiledLinalgOp> secondTiledOp =
+        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
+    if (!secondTiledOp)
+      return llvm::None;
+    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
+                            secondTiledOp->loops.end());
+    rewriter.eraseOp(ret.op);
+    ret.op = secondTiledOp->op;
   }
-  ret.op = tiledRootOp->op;
-  ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
 
-  // Fuse the other operations into the fused inter-tile loops produced above.
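The zero-tile-size convention described in the comment above is easy to show in isolation. A minimal sketch, using integers as stand-ins for the `Value` tile sizes; `splitTileSizes` is a hypothetical helper, not part of the patch:

```cpp
#include <cassert>
#include <set>
#include <vector>

// Split tile sizes {ti, tj, tk} between the two tiling steps: the first step
// tiles only the fusable loops, the second tiles the remaining ones. A zero
// tile size means "do not tile this loop".
void splitTileSizes(const std::vector<int> &tileSizes,
                    const std::set<unsigned> &fusedLoops,
                    std::vector<int> &tileAndFuseSizes,
                    std::vector<int> &remainingSizes) {
  for (unsigned i = 0; i < tileSizes.size(); ++i) {
    bool fused = fusedLoops.count(i);
    tileAndFuseSizes.push_back(fused ? tileSizes[i] : 0);
    remainingSizes.push_back(fused ? 0 : tileSizes[i]);
  }
}

int main() {
  // Loops {i, j, k} with sizes {ti=16, tj=32, tk=64}; only j is fusable.
  std::vector<int> first, second;
  splitTileSizes({16, 32, 64}, {1}, first, second);
  assert((first == std::vector<int>{0, 32, 0}));   // step 1: fused loop only
  assert((second == std::vector<int>{16, 0, 64})); // step 2: the rest
}
```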
- ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(), - fusableDependences, ret.fusedLoopDims); return ret; } Optional -mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, +mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions) { + const LinalgTilingOptions &tilingOptions, + const LinalgFusionOptions &fusionOptions) { switch (tilingOptions.loopType) { case LinalgTilingLoopType::Loops: - return tileAndFuseLinalgOpsImpl(builder, ops, dependenceGraph, - tilingOptions); + return tileAndFuseLinalgOpsImpl(rewriter, op, dependenceGraph, + tilingOptions, fusionOptions); case LinalgTilingLoopType::ParallelLoops: return tileAndFuseLinalgOpsImpl( - builder, ops, dependenceGraph, tilingOptions); + rewriter, op, dependenceGraph, tilingOptions, fusionOptions); default:; } return llvm::None; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index a824f6eb620f0e..e002336ed1c65f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -166,6 +166,11 @@ struct LinalgOpInstancePromotionOptions { /// Alignment of promoted buffer. Optional alignment; }; + +struct PromotionInfo { + Value fullLocalView; + Value partialLocalView; +}; } // namespace LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( @@ -228,10 +233,10 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( // To account for general boundary effects, padding must be performed on the // boundary tiles. For now this is done with an unconditional `fill` op followed // by a partial `copy` op. -Optional mlir::linalg::promoteSubviewAsNewBuffer( - OpBuilder &b, Location loc, SubViewOp subView, - AllocBufferCallbackFn allocationFn, OperationFolder *folder) { - ScopedContext scopedContext(b, loc); +static Optional +promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, SubViewOp subView, + LinalgOpInstancePromotionOptions const &options, + OperationFolder *folder) { auto viewType = subView.getType(); auto rank = viewType.getRank(); SmallVector fullSizes, partialSizes; @@ -249,7 +254,8 @@ Optional mlir::linalg::promoteSubviewAsNewBuffer( SmallVector dynSizes(fullSizes.size(), -1); // If a callback is not specified, then use the default implementation for // allocating the promoted buffer. 
- Optional fullLocalView = allocationFn(b, subView, fullSizes, folder); + Optional fullLocalView = + options.allocationFn(b, subView, fullSizes, folder); if (!fullLocalView) return {}; auto zero = folded_std_constant_index(folder, 0); @@ -273,8 +279,8 @@ promoteSubViews(OpBuilder &b, Location loc, for (auto v : options.subViews) { SubViewOp subView = cast(v.second.getDefiningOp()); - Optional promotionInfo = promoteSubviewAsNewBuffer( - b, loc, subView, options.allocationFn, folder); + Optional promotionInfo = + promoteSubviewAsNewBuffer(b, loc, subView, options, folder); if (!promotionInfo) return {}; promotionInfoMap[v.first] = *promotionInfo; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index a855c07cb8d47e..836cc28e0a47f3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -165,69 +165,17 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( if (!linalgOp.hasBufferSemantics()) return failure(); - DenseSet producers; - producers.insert(linalgOp); - for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) { - if (!fusionOptions.indicesToFuse.count( - dependence.indexingOpView.operandIndex)) - continue; - if (isa(dependence.dependentOpView.op)) - producers.insert(dependence.dependentOpView.op); - } - - SmallVector fusionOps; - for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; - ++it) { - auto producerLinalgOp = dyn_cast(&(*it)); - if (producerLinalgOp && producers.count(producerLinalgOp)) - fusionOps.push_back(producerLinalgOp); - } - fusionOps.push_back(linalgOp); - - SmallVector tileSizes = - tilingOptions.tileSizeComputationFunction(rewriter, op); - LinalgTilingOptions instanceTilingOptions = tilingOptions; - instanceTilingOptions.setTileSizes(tileSizes); Optional tiledAndFusedOps = tileAndFuseLinalgOps( - rewriter, fusionOps, dependenceGraph, instanceTilingOptions); + rewriter, op, dependenceGraph, tilingOptions, fusionOptions); if (!tiledAndFusedOps) return failure(); - - // Tile the unfused loops; - SmallVector unfusedLoopTileSizes; - Value zero = rewriter.create(op->getLoc(), 0); - for (auto tileSize : enumerate(tileSizes)) { - if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index())) - unfusedLoopTileSizes.push_back(zero); - else - unfusedLoopTileSizes.push_back(tileSize.value()); - } - // Tile the loop only if there is a non-zero tile size. 
-  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
-    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
-  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
-        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
-          return cst.getValue() != 0;
-        return true;
-      })) {
-    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
-    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
-    Optional<TiledLinalgOp> unfusedTiledOp =
-        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
-    if (!unfusedTiledOp)
-      return failure();
-    rewriter.eraseOp(tiledAndFusedOps->op);
-    tiledAndFusedOps->op = unfusedTiledOp->op;
-  }
-
   marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
   for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
     fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
   }
-  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
+  for (auto origProducerOp : tiledAndFusedOps->originalProducers)
     originalOpMarker.replaceLinalgMarker(rewriter,
                                          origProducerOp.getOperation());
-  }
   rewriter.updateRootInPlace(
       op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
   return success();
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 342d73273dd64c..04efc25a92ee4b 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -916,17 +916,41 @@ bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
   llvm_unreachable("unknown comparison predicate");
 }
 
+// Returns true if the predicate is true for two equal operands.
+static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
+  switch (predicate) {
+  case CmpIPredicate::eq:
+  case CmpIPredicate::sle:
+  case CmpIPredicate::sge:
+  case CmpIPredicate::ule:
+  case CmpIPredicate::uge:
+    return true;
+  case CmpIPredicate::ne:
+  case CmpIPredicate::slt:
+  case CmpIPredicate::sgt:
+  case CmpIPredicate::ult:
+  case CmpIPredicate::ugt:
+    return false;
+  }
+  llvm_unreachable("unknown comparison predicate");
+}
+
 // Constant folding hook for comparisons.
 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two arguments");
 
+  if (lhs() == rhs()) {
+    auto val = applyCmpPredicateToEqualOperands(getPredicate());
+    return BoolAttr::get(val, getContext());
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)
     return {};
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+  return BoolAttr::get(val, getContext());
 }
 
 //===----------------------------------------------------------------------===//
@@ -2269,6 +2293,30 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
   return OpFoldResult();
 }
 
+namespace {
+/// Fold a load on a tensor_to_memref operation into an extract_element on the
+/// corresponding tensor.
+struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
+  using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoadOp load,
+                                PatternRewriter &rewriter) const override {
+    auto tensorToMemref = load.memref().getDefiningOp<TensorToMemrefOp>();
+    if (!tensorToMemref)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
+                                                  load.indices());
+    return success();
+  }
+};
+} // end anonymous namespace.
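The equal-operands fold added above is worth restating in isolation: reflexive predicates are true for `cmpi pred, %x, %x` and irreflexive ones false, independent of the runtime value of `%x`. A standalone sketch with a hypothetical `Pred` enum that mirrors `CmpIPredicate` (not the MLIR implementation itself):

```cpp
#include <cassert>

enum class Pred { eq, ne, slt, sle, sgt, sge, ult, ule, ugt, uge };

// Result of `cmpi pred, %x, %x`: reflexive predicates fold to true,
// irreflexive ones to false, regardless of the value of %x.
bool foldCmpOnEqualOperands(Pred p) {
  switch (p) {
  case Pred::eq:
  case Pred::sle:
  case Pred::sge:
  case Pred::ule:
  case Pred::uge:
    return true;
  default:
    return false;
  }
}

int main() {
  assert(foldCmpOnEqualOperands(Pred::sge));  // x >= x is always true
  assert(!foldCmpOnEqualOperands(Pred::ult)); // x < x is always false
}
```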
+
+void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<LoadOfTensorToMemref>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 332c7ff1e2b974..f769965b26ece3 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -16,6 +16,7 @@
 #ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
 
 #include <iostream>
+#include <memory>
 #include <mutex>
 #include <thread>
 #include <vector>
@@ -27,30 +28,141 @@
 // Async runtime API.
 //===----------------------------------------------------------------------===//
 
-struct AsyncToken {
-  bool ready = false;
+namespace {
+
+// Forward declare class defined below.
+class RefCounted;
+
+// -------------------------------------------------------------------------- //
+// AsyncRuntime orchestrates all async operations, and the Async runtime API is
+// built on top of the default runtime instance.
+// -------------------------------------------------------------------------- //
+
+class AsyncRuntime {
+public:
+  AsyncRuntime() : numRefCountedObjects(0) {}
+
+  ~AsyncRuntime() {
+    assert(getNumRefCountedObjects() == 0 &&
+           "all ref counted objects must be destroyed");
+  }
+
+  int32_t getNumRefCountedObjects() {
+    return numRefCountedObjects.load(std::memory_order_relaxed);
+  }
+
+private:
+  friend class RefCounted;
+
+  // Count the total number of reference counted objects in this instance
+  // of an AsyncRuntime. For debugging purposes only.
+  void addNumRefCountedObjects() {
+    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
+  }
+  void dropNumRefCountedObjects() {
+    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
+  }
+
+  std::atomic<int32_t> numRefCountedObjects;
+};
+
+// Returns the default per-process instance of an async runtime.
+AsyncRuntime *getDefaultAsyncRuntimeInstance() {
+  static auto runtime = std::make_unique<AsyncRuntime>();
+  return runtime.get();
+}
+
+// -------------------------------------------------------------------------- //
+// A base class for all reference counted objects created by the async runtime.
+// -------------------------------------------------------------------------- //
+
+class RefCounted {
+public:
+  RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
+      : runtime(runtime), refCount(refCount) {
+    runtime->addNumRefCountedObjects();
+  }
+
+  virtual ~RefCounted() {
+    assert(refCount.load() == 0 && "reference count must be zero");
+    runtime->dropNumRefCountedObjects();
+  }
+
+  RefCounted(const RefCounted &) = delete;
+  RefCounted &operator=(const RefCounted &) = delete;
+
+  void addRef(int32_t count = 1) { refCount.fetch_add(count); }
+
+  void dropRef(int32_t count = 1) {
+    int32_t previous = refCount.fetch_sub(count);
+    assert(previous >= count && "reference count should not go below zero");
+    if (previous == count)
+      destroy();
+  }
+
+protected:
+  virtual void destroy() { delete this; }
+
+private:
+  AsyncRuntime *runtime;
+  std::atomic<int32_t> refCount;
+};
+
+} // namespace
+
+struct AsyncToken : public RefCounted {
+  // AsyncToken is created with a reference count of 2 because it will be
+  // returned to the `async.execute` caller and will also be emplaced later by
+  // the asynchronously executed task. If the caller drops its reference
+  // immediately, we must still ensure that the token stays alive until the
+  // asynchronous operation is completed.
+  AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*count=*/2) {}
+
+  // Internal state below is guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
+
+  bool ready = false;
   std::vector<std::function<void()>> awaiters;
 };
 
-struct AsyncGroup {
-  std::atomic<int> pendingTokens{0};
-  std::atomic<int> rank{0};
+struct AsyncGroup : public RefCounted {
+  AsyncGroup(AsyncRuntime *runtime)
+      : RefCounted(runtime), pendingTokens(0), rank(0) {}
+
+  std::atomic<int> pendingTokens;
+  std::atomic<int> rank;
+
+  // Internal state below is guarded by a mutex.
   std::mutex mu;
   std::condition_variable cv;
+
   std::vector<std::function<void()>> awaiters;
 };
 
+// Adds references to a reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
+  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
+  refCounted->addRef(count);
+}
+
+// Drops references from a reference counted runtime object.
+extern "C" MLIR_ASYNCRUNTIME_EXPORT void
+mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
+  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
+  refCounted->dropRef(count);
+}
+
 // Create a new `async.token` in not-ready state.
 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
-  AsyncToken *token = new AsyncToken;
+  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
   return token;
 }
 
 // Create a new `async.group` in empty state.
 extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
-  AsyncGroup *group = new AsyncGroup;
+  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
   return group;
 }
 
@@ -59,23 +171,34 @@
 mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
   std::unique_lock<std::mutex> lockToken(token->mu);
   std::unique_lock<std::mutex> lockGroup(group->mu);
 
+  // Get the rank of the token inside the group before we drop the reference.
+  int rank = group->rank.fetch_add(1);
   group->pendingTokens.fetch_add(1);
 
-  auto onTokenReady = [group]() {
+  auto onTokenReady = [group, token](bool dropRef) {
     // Run all group awaiters if it was the last token in the group.
     if (group->pendingTokens.fetch_sub(1) == 1) {
       group->cv.notify_all();
       for (auto &awaiter : group->awaiters)
         awaiter();
     }
+
+    // We no longer need the token or the group, drop references on them.
+    if (dropRef) {
+      group->dropRef();
+      token->dropRef();
+    }
   };
 
-  if (token->ready)
-    onTokenReady();
-  else
-    token->awaiters.push_back([onTokenReady]() { onTokenReady(); });
+  if (token->ready) {
+    onTokenReady(false);
+  } else {
+    group->addRef();
+    token->addRef();
+    token->awaiters.push_back([onTokenReady]() { onTokenReady(true); });
+  }
 
-  return group->rank.fetch_add(1);
+  return rank;
 }
 
 // Switches `async.token` to ready state and runs all awaiters.
@@ -85,6 +208,10 @@ extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
   token->cv.notify_all();
   for (auto &awaiter : token->awaiters)
     awaiter();
+
+  // Async tokens are created with a reference count of 2 to keep the token
+  // alive until the async task completes. Drop this reference explicitly
+  // when the token is emplaced.
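Before the explicit drop that follows, a self-contained sketch of the `dropRef` contract used by these runtime objects: `fetch_sub` returns the previous count, so exactly one dropper observes the transition to zero and frees the object. `RefCountedToy` is hypothetical and omits the `AsyncRuntime` bookkeeping:

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

class RefCountedToy {
public:
  explicit RefCountedToy(int32_t initialCount = 1) : refCount(initialCount) {}

  void addRef(int32_t count = 1) { refCount.fetch_add(count); }

  // fetch_sub returns the value *before* the subtraction, so the caller that
  // takes the count from `count` to zero (previous == count) is the unique
  // owner of the destruction; concurrent droppers never double-free.
  void dropRef(int32_t count = 1) {
    int32_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count must not go below zero");
    if (previous == count)
      delete this;
  }

private:
  ~RefCountedToy() = default; // heap-only: destroyed via dropRef
  std::atomic<int32_t> refCount;
};

int main() {
  // Token-like lifetime: one reference for the caller, one for the task.
  auto *token = new RefCountedToy(/*initialCount=*/2);
  token->dropRef(); // the task emplaced the token
  token->dropRef(); // the caller is done; the object is freed here
}
```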
+ token->dropRef(); } extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) { @@ -114,14 +241,18 @@ extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, CoroResume resume) { std::unique_lock lock(token->mu); - auto execute = [handle, resume]() { + auto execute = [handle, resume, token](bool dropRef) { + if (dropRef) + token->dropRef(); mlirAsyncRuntimeExecute(handle, resume); }; - if (token->ready) - execute(); - else - token->awaiters.push_back([execute]() { execute(); }); + if (token->ready) { + execute(false); + } else { + token->addRef(); + token->awaiters.push_back([execute]() { execute(true); }); + } } extern "C" MLIR_ASYNCRUNTIME_EXPORT void @@ -129,14 +260,18 @@ mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle, CoroResume resume) { std::unique_lock lock(group->mu); - auto execute = [handle, resume]() { + auto execute = [handle, resume, group](bool dropRef) { + if (dropRef) + group->dropRef(); mlirAsyncRuntimeExecute(handle, resume); }; - if (group->pendingTokens == 0) - execute(); - else - group->awaiters.push_back([execute]() { execute(); }); + if (group->pendingTokens == 0) { + execute(false); + } else { + group->addRef(); + group->awaiters.push_back([execute]() { execute(true); }); + } } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir index 1fd71a65379e21..dadb28dbc08218 100644 --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -1,5 +1,20 @@ // RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s +// CHECK-LABEL: reference_counting +func @reference_counting(%arg0: !async.token) { + // CHECK: %[[C2:.*]] = constant 2 : i32 + // CHECK: call @mlirAsyncRuntimeAddRef(%arg0, %[[C2]]) + async.add_ref %arg0 {count = 2 : i32} : !async.token + + // CHECK: %[[C1:.*]] = constant 1 : i32 + // CHECK: call @mlirAsyncRuntimeDropRef(%arg0, %[[C1]]) + async.drop_ref %arg0 {count = 1 : i32} : !async.token + + return +} + +// ----- + // CHECK-LABEL: execute_no_async_args func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) { // CHECK: %[[TOKEN:.*]] = call @async_execute_fn(%arg0, %arg1) diff --git a/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir new file mode 100644 index 00000000000000..6500fa0b1d8aba --- /dev/null +++ b/mlir/test/Dialect/Async/async-ref-counting-optimization.mlir @@ -0,0 +1,113 @@ +// RUN: mlir-opt %s -async-ref-counting-optimization | FileCheck %s + +// CHECK-LABEL: @cancellable_operations_0 +func @cancellable_operations_0(%arg0: !async.token) { + // CHECK-NOT: async.add_ref + // CHECK-NOT: async.drop_ref + async.add_ref %arg0 {count = 1 : i32} : !async.token + async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} + +// CHECK-LABEL: @cancellable_operations_1 +func @cancellable_operations_1(%arg0: !async.token) { + // CHECK-NOT: async.add_ref + // CHECK: async.execute + async.add_ref %arg0 {count = 1 : i32} : !async.token + async.execute [%arg0] { + // CHECK: async.drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK-NEXT: async.yield + async.yield + } + // CHECK-NOT: async.drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} + +// CHECK-LABEL: @cancellable_operations_2 +func 
@cancellable_operations_2(%arg0: !async.token) {
+  // CHECK: async.await
+  // CHECK-NEXT: async.await
+  // CHECK-NEXT: async.await
+  // CHECK-NEXT: return
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.await %arg0 : !async.token
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  async.await %arg0 : !async.token
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  async.await %arg0 : !async.token
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  return
+}
+
+// CHECK-LABEL: @cancellable_operations_3
+func @cancellable_operations_3(%arg0: !async.token) {
+  // CHECK-NOT: add_ref
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  %token = async.execute {
+    async.await %arg0 : !async.token
+    // CHECK: async.drop_ref
+    async.drop_ref %arg0 {count = 1 : i32} : !async.token
+    async.yield
+  }
+  // CHECK-NOT: async.drop_ref
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: async.await
+  async.await %arg0 : !async.token
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @not_cancellable_operations_0
+func @not_cancellable_operations_0(%arg0: !async.token, %arg1: i1) {
+  // It is unsafe to cancel the `add_ref` / `drop_ref` pair because it is
+  // possible that the body of the `async.execute` operation will run before
+  // the await operation in the function body, and will destroy the `%arg0`
+  // token.
+  // CHECK: add_ref
+  async.add_ref %arg0 {count = 1 : i32} : !async.token
+  %token = async.execute {
+    // CHECK: async.await
+    async.await %arg0 : !async.token
+    // CHECK: async.drop_ref
+    async.drop_ref %arg0 {count = 1 : i32} : !async.token
+    // CHECK: async.yield
+    async.yield
+  }
+  // CHECK: async.await
+  async.await %arg0 : !async.token
+  // CHECK: drop_ref
+  async.drop_ref %arg0 {count = 1 : i32} : !async.token
+  // CHECK: return
+  return
+}
+
+// CHECK-LABEL: @not_cancellable_operations_1
+func @not_cancellable_operations_1(%arg0: !async.token, %arg1: i1) {
+  // Same reason as above, although the `async.execute` is inside the nested
+  // region of a "regular" operation.
+  //
+  // NOTE: This test is not correct w.r.t. reference counting, and at runtime
+  // would leak the %arg0 value if %arg1 is false. IR like this will not be
+  // constructed by the automatic reference counting pass, because it would
+  // place `async.add_ref` right before the `async.execute` inside `scf.if`.
+ + // CHECK: async.add_ref + async.add_ref %arg0 {count = 1 : i32} : !async.token + scf.if %arg1 { + %token = async.execute { + async.await %arg0 : !async.token + // CHECK: async.drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + async.yield + } + } + // CHECK: async.await + async.await %arg0 : !async.token + // CHECK: async.drop_ref + async.drop_ref %arg0 {count = 1 : i32} : !async.token + // CHECK: return + return +} diff --git a/mlir/test/Dialect/Async/async-ref-counting.mlir b/mlir/test/Dialect/Async/async-ref-counting.mlir new file mode 100644 index 00000000000000..504a18fba9901a --- /dev/null +++ b/mlir/test/Dialect/Async/async-ref-counting.mlir @@ -0,0 +1,253 @@ +// RUN: mlir-opt %s -async-ref-counting | FileCheck %s + +// CHECK-LABEL: @cond +func private @cond() -> i1 + +// CHECK-LABEL: @token_arg_no_uses +func @token_arg_no_uses(%arg0: !async.token) { + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + return +} + +// CHECK-LABEL: @token_arg_conditional_await +func @token_arg_conditional_await(%arg0: !async.token, %arg1: i1) { + cond_br %arg1, ^bb1, ^bb2 +^bb1: + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + return +^bb2: + // CHECK: async.await %arg0 + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + async.await %arg0 : !async.token + return +} + +// CHECK-LABEL: @token_no_uses +func @token_no_uses() { + // CHECK: %[[TOKEN:.*]] = async.execute + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + %token = async.execute { + async.yield + } + return +} + +// CHECK-LABEL: @token_return +func @token_return() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: return %[[TOKEN]] + return %token : !async.token +} + +// CHECK-LABEL: @token_await +func @token_await() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_await_and_return +func @token_await_and_return() -> !async.token { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + // CHECK-NOT: async.drop_ref + async.await %token : !async.token + // CHECK: return %[[TOKEN]] + return %token : !async.token +} + +// CHECK-LABEL: @token_await_inside_scf_if +func @token_await_inside_scf_if(%arg0: i1) { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: scf.if %arg0 { + scf.if %arg0 { + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + } + // CHECK: } + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_conditional_await +func @token_conditional_await(%arg0: i1) { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + cond_br %arg0, ^bb1, ^bb2 +^bb1: + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + return +^bb2: + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + return +} + +// CHECK-LABEL: @token_await_in_the_loop +func @token_await_in_the_loop() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + br ^bb1 +^bb1: + // CHECK: async.await %[[TOKEN]] + async.await %token : !async.token + %0 = call @cond(): () -> (i1) + cond_br %0, ^bb1, ^bb2 +^bb2: + // CHECK: async.drop_ref 
%[[TOKEN]] {count = 1 : i32} + return +} + +// CHECK-LABEL: @token_defined_in_the_loop +func @token_defined_in_the_loop() { + br ^bb1 +^bb1: + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + %0 = call @cond(): () -> (i1) + cond_br %0, ^bb1, ^bb2 +^bb2: + return +} + +// CHECK-LABEL: @token_capture +func @token_capture() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK-NEXT: async.yield + async.await %token : !async.token + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_nested_capture +func @token_nested_capture() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute { + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: %[[TOKEN_1:.*]] = async.execute + %token_1 = async.execute { + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: %[[TOKEN_2:.*]] = async.execute + %token_2 = async.execute { + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_2]] {count = 1 : i32} + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_1]] {count = 1 : i32} + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.yield + } + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: return + return +} + +// CHECK-LABEL: @token_dependency +func @token_dependency() { + // CHECK: %[[TOKEN:.*]] = async.execute + %token = async.execute { + async.yield + } + + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute[%token] { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK-NEXT: async.yield + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + // CHECK: async.await %[[TOKEN_0]] + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + async.await %token_0 : !async.token + + // CHECK: return + return +} + +// CHECK-LABEL: @value_operand +func @value_operand() -> f32 { + // CHECK: %[[TOKEN:.*]], %[[RESULTS:.*]] = async.execute + %token, %results = async.execute -> !async.value { + %0 = constant 0.0 : f32 + async.yield %0 : f32 + } + + // CHECK: async.add_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.add_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: %[[TOKEN_0:.*]] = async.execute + %token_0 = async.execute[%token](%results as %arg0 : !async.value) { + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32} + // CHECK: async.yield + async.yield + } + + // CHECK: async.await %[[TOKEN]] + // CHECK: async.drop_ref %[[TOKEN]] {count = 1 : i32} + async.await %token : !async.token + + // CHECK: async.await 
%[[TOKEN_0]] + // CHECK: async.drop_ref %[[TOKEN_0]] {count = 1 : i32} + async.await %token_0 : !async.token + + // CHECK: async.await %[[RESULTS]] + // CHECK: async.drop_ref %[[RESULTS]] {count = 1 : i32} + %0 = async.await %results : !async.value + + // CHECK: return + return %0 : f32 +} diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir index a95be650eff78e..54dc6736b4dd9a 100644 --- a/mlir/test/Dialect/Async/ops.mlir +++ b/mlir/test/Dialect/Async/ops.mlir @@ -134,3 +134,17 @@ func @create_group_and_await_all(%arg0: !async.token, %arg1: !async.value) %3 = addi %1, %2 : index return %3 : index } + +// CHECK-LABEL: @add_ref +func @add_ref(%arg0: !async.token) { + // CHECK: async.add_ref %arg0 {count = 1 : i32} + async.add_ref %arg0 {count = 1 : i32} : !async.token + return +} + +// CHECK-LABEL: @drop_ref +func @drop_ref(%arg0: !async.token) { + // CHECK: async.drop_ref %arg0 {count = 1 : i32} + async.drop_ref %arg0 {count = 1 : i32} : !async.token + return +} diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir index fa471811ef4ed7..2ddc66651db20e 100644 --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -47,9 +47,7 @@ module { // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]] // CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] // CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] -// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] -// CHECK: linalg.fill(%[[SV3_2]], %[[CST]]) +// CHECK: linalg.fill(%[[SV3]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" // CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { // CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] @@ -111,12 +109,9 @@ module { // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]] // CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]] // CHECK-SAME: [%[[M]], %[[TILE_N_2]]] -// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]] // CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]] -// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] -// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]] -// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] -// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) +// CHECK-SAME: [%[[K]], %[[TILE_N]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) // CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer" // CHECK-NOT: linalg.fill // CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]] @@ -191,16 +186,11 @@ module { // CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]] // CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] -// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N]]] -// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] -// CHECK: %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] -// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) +// CHECK-SAME: [%[[TILE_M]], %[[K]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" -// CHECK: linalg.fill(%[[SV2_2]], %[[CST]]) +// CHECK: linalg.fill(%[[SV2]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" // CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]] // CHECK: scf.parallel (%[[IV1:.+]]) 
diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
index fa471811ef4ed7..2ddc66651db20e 100644
--- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
@@ -47,9 +47,7 @@ module {
 // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
 // CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
 // CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]]
-// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
-// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
-// CHECK: linalg.fill(%[[SV3_2]], %[[CST]])
+// CHECK: linalg.fill(%[[SV3]], %[[CST]])
 // CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
 // CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
 // CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
@@ -111,12 +109,9 @@ module {
 // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
 // CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
 // CHECK-SAME: [%[[M]], %[[TILE_N_2]]]
-// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
 // CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
-// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
-// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]]
-// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
-// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
+// CHECK-SAME: [%[[K]], %[[TILE_N]]]
+// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
 // CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer"
 // CHECK-NOT: linalg.fill
 // CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
@@ -191,16 +186,11 @@ module {
 // CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]]
 // CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
 // CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
-// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N]]]
-// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
 // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
-// CHECK: %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
-// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
+// CHECK-SAME: [%[[TILE_M]], %[[K]]]
+// CHECK: linalg.copy(%[[SV3]], %[[SV1]])
 // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
-// CHECK: linalg.fill(%[[SV2_2]], %[[CST]])
+// CHECK: linalg.fill(%[[SV2]], %[[CST]])
 // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
 // CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
 // CHECK: scf.parallel (%[[IV1:.+]]) =
@@ -271,18 +261,15 @@ module {
 // CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]]
 // CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
 // CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
-// CHECK: %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]]
 // CHECK: %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
 // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
 // CHECK-SAME: [%[[TILE_M]], %[[K1]]]
-// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
-// CHECK: %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]]
+// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
 // CHECK: linalg.matmul
 // CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
 // CHECK-SAME: ins(%[[SV3]], %[[SV4]]
 // CHECK-SAME: : memref<?x?xf32>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32>)
+// CHECK-SAME: outs(%[[SV1]] : memref<?x?xf32>)
 // CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
 // CHECK: scf.parallel (%[[IV1:.+]]) =
 // CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
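Note: in each hunk above, the updated CHECK lines encode a cleanup in the tile-and-fuse rewrite: the fused producer (linalg.fill or linalg.copy) now writes into the same subview the tiled consumer already uses (%[[SV1]], %[[SV2]], %[[SV3]]) instead of materializing duplicate subviews (the dropped %[[SV1_2]]/%[[SV2_2]]/%[[SV3_2]] values and the extra dim ops that fed them). A schematic sketch of the fused form, with hypothetical names and #map standing for the strided layout produced by subview:

    scf.parallel (%iv0) = (%c0) to (%m) step (%c16) {
      // One subview of the output per tile ...
      %sv_out = subview %out[%iv0, 0] [%tile_m, %n] [1, 1]
          : memref<?x?xf32> to memref<?x?xf32, #map>
      // ... shared by the fused producer and the consumer.
      linalg.fill(%sv_out, %cst) : memref<?x?xf32, #map>, f32
      %sv_lhs = subview %lhs[%iv0, 0] [%tile_m, %k] [1, 1]
          : memref<?x?xf32> to memref<?x?xf32, #map>
      linalg.matmul ins(%sv_lhs, %rhs : memref<?x?xf32, #map>, memref<?x?xf32>)
                   outs(%sv_out : memref<?x?xf32, #map>)
      scf.yield
    }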
diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir
deleted file mode 100644
index a02c878ef34161..00000000000000
--- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir
+++ /dev/null
@@ -1,133 +0,0 @@
-// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
-
-module {
-  func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
-                        %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
-    %cst = constant 0.000000e+00 : f32
-    %c0 = constant 0 : index
-    %c1 = constant 1 : index
-    %d0 = dim %arg0, %c0 : memref<?x?xf32>
-    %d1 = dim %arg1, %c1 : memref<?x?xf32>
-    %0 = alloc(%d0, %d1) : memref<?x?xf32>
-    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
-    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
-      outs(%0 : memref<?x?xf32>)
-    linalg.generic
-      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
-                        affine_map<(d0, d1) -> (d1)>,
-                        affine_map<(d0, d1) -> (d0, d1)>],
-       iterator_types = ["parallel", "parallel"]}
-      ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
-      outs(%arg3 : memref<?x?xf32>) {
-      ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
-        %5 = addf %arg4, %arg5 : f32
-        linalg.yield %5 : f32
-      }
-    return
-  }
-}
-
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK: func @three_op_fusion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
-// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
-// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
-// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
-// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
-// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}})
-// CHECK: linalg.matmul
-// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_TEMP]] : memref<?x?xf32>)
-// CHECK: linalg.generic
-// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32>, memref<?xf32>)
-// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32>)
-// CHECK: scf.yield
-// CHECK: }
-
-// -----
-
-module {
-  func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
-                           %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
-                           %arg4: memref<?x?xf32>) {
-    %cst = constant 0.000000e+00 : f32
-    %c0 = constant 0 : index
-    %c1 = constant 1 : index
-    %m = dim %arg0, %c0 : memref<?x?xf32>
-    %n1 = dim %arg1, %c1 : memref<?x?xf32>
-    %n2 = dim %arg2, %c1 : memref<?x?xf32>
-    %n3 = dim %arg3, %c1 : memref<?x?xf32>
-    %0 = alloc(%m, %n1) : memref<?x?xf32>
-    %1 = alloc(%m, %n2) : memref<?x?xf32>
-    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
-    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
-      outs(%0 : memref<?x?xf32>)
-    linalg.fill(%1, %cst) : memref<?x?xf32>, f32
-    linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
-      outs(%1 : memref<?x?xf32>)
-    linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
-    linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
-      outs(%arg4 : memref<?x?xf32>)
-    return
-  }
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK: func @sequence_of_matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C16:.+]] = constant 16 : index
-// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
-// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
-// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
-// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
-// CHECK-SAME: step (%[[C16]]) {
-// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
-// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
-// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
-// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
-// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
-// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
-// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
-// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
-// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
-// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
-// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
-// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]]
-// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
-// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
-// CHECK-SAME: : memref<?x?xf32>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref<?x?xf32>)
-// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
-// CHECK-SAME: : memref<?x?xf32>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref<?x?xf32>)
-// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}})
-// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
-// CHECK-SAME: : memref<?x?xf32>, memref<?x?xf32>)
-// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32>)
-// CHECK: scf.yield
-// CHECK: }
-
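Note: the deleted test above drove the test-linalg-tile-and-fuse pass that is removed at the end of this patch. The canonicalize.mlir changes that follow add two folds; the first rewrites a load of a tensor_to_memref result into a direct element read of the underlying tensor. A minimal before/after sketch (value names hypothetical, semantics taken from the new test):

    // Before folding: the tensor is staged through a memref.
    %m = tensor_to_memref %t : memref<?x?xf32>
    %v = load %m[%i, %j] : memref<?x?xf32>

    // After folding, the staging memref disappears:
    %v = extract_element %t[%i, %j] : tensor<?x?xf32>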
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 1e2e4a5bf11624..ebc59c8dbeac26 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -45,6 +45,20 @@ func @dim_of_tensor_load(%arg0: memref) -> index {
   return %1 : index
 }
 
+// Test case: Folding of load(tensor_to_memref(%v, %idxs))
+//            -> extract_element(%v, %idx)
+// CHECK-LABEL: func @load_from_tensor_to_memref(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
+// CHECK-NOT: load
+// CHECK: return %[[RES]] : f32
+func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = tensor_to_memref %arg2 : memref<?x?xf32>
+  %1 = load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+
 // Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
 // CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
 // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
@@ -59,3 +73,25 @@ func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
   %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
   return %1 : index
 }
+
+// Test case: Folding of comparisons with equal operands.
+// CHECK-LABEL: @cmpi_equal_operands
+// CHECK-DAG: %[[T:.*]] = constant true
+// CHECK-DAG: %[[F:.*]] = constant false
+// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_operands(%arg0: i64)
+    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+  %0 = cmpi "eq", %arg0, %arg0 : i64
+  %1 = cmpi "sle", %arg0, %arg0 : i64
+  %2 = cmpi "sge", %arg0, %arg0 : i64
+  %3 = cmpi "ule", %arg0, %arg0 : i64
+  %4 = cmpi "uge", %arg0, %arg0 : i64
+  %5 = cmpi "ne", %arg0, %arg0 : i64
+  %6 = cmpi "slt", %arg0, %arg0 : i64
+  %7 = cmpi "sgt", %arg0, %arg0 : i64
+  %8 = cmpi "ult", %arg0, %arg0 : i64
+  %9 = cmpi "ugt", %arg0, %arg0 : i64
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
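Note: the second new test pins down folding of cmpi when both operands are the same SSA value: the reflexive predicates ("eq", "sle", "sge", "ule", "uge") fold to true, and the irreflexive ones ("ne", "slt", "sgt", "ult", "ugt") fold to false. A two-line illustration (names hypothetical):

    %t = cmpi "sge", %x, %x : i64   // folds to: %t = constant true
    %f = cmpi "ult", %x, %x : i64   // folds to: %f = constant false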
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 5289b2d1055f2a..eb9e3a53313832 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -197,44 +197,6 @@ struct TestLinalgGreedyFusion
     }
   }
 };
-
-/// Pass to test tile and fuse of sequence of operations. Intended only for
-/// testing.
-struct TestLinalgTileAndFuseSequencePass
-    : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> {
-  TestLinalgTileAndFuseSequencePass() = default;
-  TestLinalgTileAndFuseSequencePass(
-      const TestLinalgTileAndFuseSequencePass &pass){};
-
-  ListOption<int64_t> tileSizes{
-      *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"),
-      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
-  }
-
-  void runOnFunction() override {
-    FuncOp funcOp = getOperation();
-    auto &blocks = funcOp.getBody().getBlocks();
-    if (!llvm::hasSingleElement(blocks)) {
-      return;
-    }
-    SmallVector<LinalgOp, 2> linalgOps =
-        llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
-    Aliases aliases;
-    LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
-    OpBuilder builder(funcOp.getContext());
-    Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
-        builder, linalgOps, dependenceGraph,
-        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
-            LinalgTilingLoopType::ParallelLoops));
-    if (!tileAndFuseOps)
-      return signalPassFailure();
-    for (auto op : linalgOps)
-      op.erase();
-  }
-};
 } // namespace
 
 namespace mlir {
@@ -249,12 +211,5 @@ void registerTestLinalgGreedyFusion() {
       "test-linalg-greedy-fusion",
      "Test Linalg fusion by applying a greedy test transformation.");
 }
-void registerTestLinalgTileAndFuseSequencePass() {
-  PassRegistration<TestLinalgTileAndFuseSequencePass>
-      testTileAndFuseSequencePass(
-          "test-linalg-tile-and-fuse",
-          "Test Linalg tiling and fusion of a sequence of Linalg operations.");
-}
-
 } // namespace test
 } // namespace mlir
diff --git a/mlir/test/mlir-cpu-runner/async-group.mlir b/mlir/test/mlir-cpu-runner/async-group.mlir
index 87004ff7b38101..50f85ff5460934 100644
--- a/mlir/test/mlir-cpu-runner/async-group.mlir
+++ b/mlir/test/mlir-cpu-runner/async-group.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -convert-async-to-llvm \
+// RUN: mlir-opt %s -async-ref-counting \
+// RUN:             -convert-async-to-llvm \
 // RUN:             -convert-std-to-llvm \
 // RUN: | mlir-cpu-runner \
 // RUN:     -e main -entry-point-result=void -O0 \
diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir
index fd0268e7ac5650..5f06dd17ed6183 100644
--- a/mlir/test/mlir-cpu-runner/async.mlir
+++ b/mlir/test/mlir-cpu-runner/async.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -convert-async-to-llvm \
+// RUN: mlir-opt %s -async-ref-counting \
+// RUN:             -convert-async-to-llvm \
 // RUN:             -convert-linalg-to-loops \
 // RUN:             -convert-linalg-to-llvm \
 // RUN:             -convert-std-to-llvm \
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a0e36cf82534b9..4771b11b20e42f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -74,7 +74,6 @@ void registerTestLinalgCodegenStrategy();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
-void registerTestLinalgTileAndFuseSequencePass();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
@@ -142,7 +141,6 @@ void registerTestPasses() {
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgGreedyFusion();
   test::registerTestLinalgHoisting();
-  test::registerTestLinalgTileAndFuseSequencePass();
   test::registerTestLinalgTransforms();
   test::registerTestLivenessPass();
   test::registerTestLoopFusion();
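Note: with reference counting split into its own pass, the two mlir-cpu-runner tests must now schedule -async-ref-counting ahead of -convert-async-to-llvm so the lowered code carries the add_ref/drop_ref bookkeeping the async runtime relies on. A minimal file of the shape these RUN lines expect might look as follows (illustrative only; the real tests continue with further runner flags that the hunks above truncate):

    // RUN: mlir-opt %s -async-ref-counting \
    // RUN:             -convert-async-to-llvm \
    // RUN:             -convert-std-to-llvm \
    // RUN: | mlir-cpu-runner -e main -entry-point-result=void -O0

    func @main() {
      %token = async.execute {
        async.yield
      }
      async.await %token : !async.token
      return
    }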