From 8131e190647ac2b5b085b48a6e3b48c1d7520a66 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Thu, 23 Jul 2020 18:43:40 -0700
Subject: [PATCH 1/4] [LegalizeTypes] Teach DAGTypeLegalizer::GenWidenVectorLoads
 to pad with undef if needed when concatenating small loads to match a larger
 load

In the included test case the align 16 allowed the v23f32 load to be handled
as load v16f32, load v4f32, and load v4f32 (one element not used). These
loads all need to be concatenated together into a final vector. In this case
we tried to concatenate the two v4f32 loads to match the type of the v16f32
load so we could do a second concat_vectors, but those loads alone only add
up to v8f32. So we need two v4f32 undefs to pad it.

It appears we've tried to hack around a similar issue in this code before by
adding undef padding to loads in one of the earlier loops in this function.
Originally in r147964 by padding all loads narrower than previous loads to
the same size. Later modified to pad only the last load in r293088. This
patch removes that earlier code and just handles it on demand where we know
we need it.

Fixes PR46820

Differential Revision: https://reviews.llvm.org/D84463
---
 .../SelectionDAG/LegalizeVectorTypes.cpp      | 27 +++++------
 llvm/test/CodeGen/X86/pr46820.ll              | 47 +++++++++++++++++++
 2 files changed, 59 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/pr46820.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index b1ec3050e201d5..1394f084c6dc60 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -4912,7 +4912,8 @@ SDValue DAGTypeLegalizer::GenWidenVectorLoads(SmallVectorImpl<SDValue> &LdChain,
 
   int LdWidth = LdVT.getSizeInBits();
   int WidthDiff = WidenWidth - LdWidth;
-  // Allow wider loads.
+  // Allow wider loads if they are sufficiently aligned to avoid memory faults
+  // and if the original load is simple.
   unsigned LdAlign = (!LD->isSimple()) ? 0 : LD->getAlignment();
 
   // Find the vector type that can load from.
@@ -4964,19 +4965,6 @@ SDValue DAGTypeLegalizer::GenWidenVectorLoads(SmallVectorImpl<SDValue> &LdChain,
                       LD->getPointerInfo().getWithOffset(Offset),
                       LD->getOriginalAlign(), MMOFlags, AAInfo);
       LdChain.push_back(L.getValue(1));
-      if (L->getValueType(0).isVector() && NewVTWidth >= LdWidth) {
-        // Later code assumes the vector loads produced will be mergeable, so we
-        // must pad the final entry up to the previous width. Scalars are
-        // combined separately.
-        SmallVector<SDValue, 16> Loads;
-        Loads.push_back(L);
-        unsigned size = L->getValueSizeInBits(0);
-        while (size < LdOp->getValueSizeInBits(0)) {
-          Loads.push_back(DAG.getUNDEF(L->getValueType(0)));
-          size += L->getValueSizeInBits(0);
-        }
-        L = DAG.getNode(ISD::CONCAT_VECTORS, dl, LdOp->getValueType(0), Loads);
-      }
     } else {
       L = DAG.getLoad(NewVT, dl, Chain, BasePtr,
                       LD->getPointerInfo().getWithOffset(Offset),
@@ -5017,8 +5005,17 @@ SDValue DAGTypeLegalizer::GenWidenVectorLoads(SmallVectorImpl<SDValue> &LdChain,
       EVT NewLdTy = LdOps[i].getValueType();
       if (NewLdTy != LdTy) {
         // Create a larger vector.
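+        // The pieces collected in ConcatOps may cover less than NewLdTy (e.g.
+        // two v4f32 loads for a v16f32 slot), so pad the operand list with
+        // UNDEF values of LdTy until it adds up to NewLdTy.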
+        unsigned NumOps = NewLdTy.getSizeInBits() / LdTy.getSizeInBits();
+        assert(NewLdTy.getSizeInBits() % LdTy.getSizeInBits() == 0);
+        SmallVector<SDValue, 16> WidenOps(NumOps);
+        unsigned j = 0;
+        for (; j != End-Idx; ++j)
+          WidenOps[j] = ConcatOps[Idx+j];
+        for (; j != NumOps; ++j)
+          WidenOps[j] = DAG.getUNDEF(LdTy);
+
         ConcatOps[End-1] = DAG.getNode(ISD::CONCAT_VECTORS, dl, NewLdTy,
-                                       makeArrayRef(&ConcatOps[Idx], End - Idx));
+                                       WidenOps);
         Idx = End - 1;
         LdTy = NewLdTy;
       }
diff --git a/llvm/test/CodeGen/X86/pr46820.ll b/llvm/test/CodeGen/X86/pr46820.ll
new file mode 100644
index 00000000000000..76093801f9d0ab
--- /dev/null
+++ b/llvm/test/CodeGen/X86/pr46820.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=avx512f | FileCheck %s
+
+; The alignment of 16 causes type legalization to split this as 3 loads,
+; v16f32, v4f32, and v4f32. This loads 24 elements, but the load is aligned
+; to 16 bytes so this is safe. There was an issue with type legalization building
+; the proper concat_vectors for this because the two v4f32s don't add up to
+; v16f32 and require padding.
+
+define <23 x float> @load23(<23 x float>* %p) {
+; CHECK-LABEL: load23:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movq %rdi, %rax
+; CHECK-NEXT:    vmovups 64(%rsi), %ymm0
+; CHECK-NEXT:    vmovups (%rsi), %zmm1
+; CHECK-NEXT:    vmovaps 64(%rsi), %xmm2
+; CHECK-NEXT:    vmovss {{.*#+}} xmm3 = mem[0],zero,zero,zero
+; CHECK-NEXT:    vmovss %xmm3, 88(%rdi)
+; CHECK-NEXT:    vmovaps %xmm2, 64(%rdi)
+; CHECK-NEXT:    vmovaps %zmm1, (%rdi)
+; CHECK-NEXT:    vextractf128 $1, %ymm0, %xmm0
+; CHECK-NEXT:    vmovlps %xmm0, 80(%rdi)
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %t0 = load <23 x float>, <23 x float>* %p, align 16
+  ret <23 x float> %t0
+}
+
+; Same test as above with minimal alignment just to demonstrate the different
+; codegen.
+define <23 x float> @load23_align_1(<23 x float>* %p) {
+; CHECK-LABEL: load23_align_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movq %rdi, %rax
+; CHECK-NEXT:    vmovups (%rsi), %zmm0
+; CHECK-NEXT:    vmovups 64(%rsi), %xmm1
+; CHECK-NEXT:    movq 80(%rsi), %rcx
+; CHECK-NEXT:    vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero
+; CHECK-NEXT:    vmovss %xmm2, 88(%rdi)
+; CHECK-NEXT:    movq %rcx, 80(%rdi)
+; CHECK-NEXT:    vmovaps %xmm1, 64(%rdi)
+; CHECK-NEXT:    vmovaps %zmm0, (%rdi)
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %t0 = load <23 x float>, <23 x float>* %p, align 1
+  ret <23 x float> %t0
+}

From d054c7ee2e9f4f98af7f22a5b00a941eb919bd59 Mon Sep 17 00:00:00 2001
From: Fangrui Song
Date: Thu, 23 Jul 2020 19:13:16 -0700
Subject: [PATCH 2/4] Add test utility 'extract'

See https://lists.llvm.org/pipermail/llvm-dev/2020-July/143373.html
"[llvm-dev] Multiple documents in one test file" for some discussions.

`extract part filename` splits the input file into multiple parts separated
by the regex `^(.|//)--- ` and extracts the specified part to stdout or the
output file (if specified).

Use case A (organizing input of different formats (e.g. linker
script+assembly) in one file).

```
// RUN: extract lds %s -o %t.lds
// RUN: extract asm %s -o %t.s
// RUN: llvm-mc %t.s -o %t.o
// RUN: ld.lld -T %t.lds %t.o -o %t
```

This is sometimes better than the %S/Inputs/ approach because the user can see
the auxiliary files immediately and doesn't have to open another file.

Use case B (for utilities which don't have a built-in input splitting feature):

```
// RUN: extract case1 %s | llc | FileCheck %s --check-prefix=CASE1
// RUN: extract case2 %s | llc | FileCheck %s --check-prefix=CASE2
```

Combining tests prudently can improve readability.
This is sometimes better than having multiple test files.

Since this is a new utility, there are no git history concerns for UpperCase
variable names. I use lowerCase variable names like mlir/lld.

Reviewed By: jhenderson

Differential Revision: https://reviews.llvm.org/D83834
---
 lld/test/CMakeLists.txt                       |   2 +-
 lld/test/ELF/linkerscript/noload.s            |  19 +--
 lld/test/lit.cfg.py                           |   8 +-
 llvm/docs/TestingGuide.rst                    |  23 +++-
 llvm/test/CMakeLists.txt                      |   1 +
 llvm/test/lit.cfg.py                          |   1 +
 llvm/test/tools/extract/Inputs/basic-aa.txt   |   6 +
 llvm/test/tools/extract/Inputs/basic-bb.txt   |  10 ++
 llvm/test/tools/extract/basic.test            |  32 +++++
 llvm/test/tools/extract/help.test             |   5 +
 llvm/test/tools/extract/no-leading-lines.test |  10 ++
 llvm/test/tools/gold/X86/multiple-sections.ll |  14 ++-
 .../tools/llvm-objcopy/ELF/strip-symbol.test  |  19 +--
 llvm/test/tools/llvm-strings/radix.test       |  23 ++--
 llvm/tools/extract/.clang-tidy                |  19 +++
 llvm/tools/extract/CMakeLists.txt             |   7 ++
 llvm/tools/extract/extract.cpp                | 113 ++++++++++++++++++
 17 files changed, 276 insertions(+), 36 deletions(-)
 create mode 100644 llvm/test/tools/extract/Inputs/basic-aa.txt
 create mode 100644 llvm/test/tools/extract/Inputs/basic-bb.txt
 create mode 100644 llvm/test/tools/extract/basic.test
 create mode 100644 llvm/test/tools/extract/help.test
 create mode 100644 llvm/test/tools/extract/no-leading-lines.test
 create mode 100644 llvm/tools/extract/.clang-tidy
 create mode 100644 llvm/tools/extract/CMakeLists.txt
 create mode 100644 llvm/tools/extract/extract.cpp

diff --git a/lld/test/CMakeLists.txt b/lld/test/CMakeLists.txt
index 4fbd2534b5a977..7831bb1a8de063 100644
--- a/lld/test/CMakeLists.txt
+++ b/lld/test/CMakeLists.txt
@@ -34,7 +34,7 @@ configure_lit_site_cfg(
 set(LLD_TEST_DEPS lld)
 if (NOT LLD_BUILT_STANDALONE)
   list(APPEND LLD_TEST_DEPS
-    FileCheck count llc llvm-ar llvm-as llvm-bcanalyzer llvm-config llvm-cvtres
+    FileCheck count extract llc llvm-ar llvm-as llvm-bcanalyzer llvm-config llvm-cvtres
     llvm-dis llvm-dwarfdump llvm-lib llvm-lipo llvm-mc llvm-nm llvm-objcopy
     llvm-objdump llvm-pdbutil llvm-readelf llvm-readobj llvm-strip not obj2yaml
     opt yaml2obj
diff --git a/lld/test/ELF/linkerscript/noload.s b/lld/test/ELF/linkerscript/noload.s
index 2f52b465854e28..c2014722985d3d 100644
--- a/lld/test/ELF/linkerscript/noload.s
+++ b/lld/test/ELF/linkerscript/noload.s
@@ -1,11 +1,7 @@
 # REQUIRES: x86
-# RUN: llvm-mc -filetype=obj -triple=x86_64-unknown-linux %s -o %t.o
-# RUN: echo "SECTIONS { \
-# RUN:   .data_noload_a (NOLOAD) : { *(.data_noload_a) } \
-# RUN:   .data_noload_b (0x10000) (NOLOAD) : { *(.data_noload_b) } \
-# RUN:   .no_input_sec_noload (NOLOAD) : { . 
+= 1; } \ -# RUN: .text (0x20000) : { *(.text) } };" > %t.script -# RUN: ld.lld -o %t --script %t.script %t.o +# RUN: extract asm %s -o %t.s && extract lds %s -o %t.lds +# RUN: llvm-mc -filetype=obj -triple=x86_64 %t.s -o %t.o +# RUN: ld.lld -o %t --script %t.lds %t.o # RUN: llvm-readelf -S -l %t | FileCheck %s # CHECK: Name Type Address Off Size @@ -16,6 +12,7 @@ # CHECK: Type Offset VirtAddr PhysAddr # CHECK-NEXT: LOAD 0x001000 0x0000000000020000 0x0000000000020000 +#--- asm .section .text,"ax",@progbits nop @@ -24,3 +21,11 @@ .section .data_noload_b,"aw",@progbits .zero 4096 + +#--- lds +SECTIONS { + .data_noload_a (NOLOAD) : { *(.data_noload_a) } + .data_noload_b (0x10000) (NOLOAD) : { *(.data_noload_b) } + .no_input_sec_noload (NOLOAD) : { . += 1; } + .text (0x20000) : { *(.text) } +} diff --git a/lld/test/lit.cfg.py b/lld/test/lit.cfg.py index 267f8c5178584a..0fa9b48c3c792d 100644 --- a/lld/test/lit.cfg.py +++ b/lld/test/lit.cfg.py @@ -39,9 +39,9 @@ llvm_config.use_lld() tool_patterns = [ - 'llc', 'llvm-as', 'llvm-mc', 'llvm-nm', 'llvm-objdump', 'llvm-pdbutil', - 'llvm-dwarfdump', 'llvm-readelf', 'llvm-readobj', 'obj2yaml', 'yaml2obj', - 'opt', 'llvm-dis'] + 'extract', 'llc', 'llvm-as', 'llvm-mc', 'llvm-nm', 'llvm-objdump', + 'llvm-pdbutil', 'llvm-dwarfdump', 'llvm-readelf', 'llvm-readobj', + 'obj2yaml', 'yaml2obj', 'opt', 'llvm-dis'] llvm_config.add_tool_substitutions(tool_patterns) @@ -87,7 +87,7 @@ # Indirectly check if the mt.exe Microsoft utility exists by searching for # cvtres, which always accompanies it. Alternatively, check if we can use # libxml2 to merge manifests. -if (lit.util.which('cvtres', config.environment['PATH']) or +if (lit.util.which('cvtres', config.environment['PATH']) or config.llvm_libxml2_enabled): config.available_features.add('manifest_tool') diff --git a/llvm/docs/TestingGuide.rst b/llvm/docs/TestingGuide.rst index 2e937f00062728..6fd9ab2d24ca40 100644 --- a/llvm/docs/TestingGuide.rst +++ b/llvm/docs/TestingGuide.rst @@ -271,8 +271,27 @@ adding your code there instead of creating a new file. Extra files ----------- -If your test requires extra files besides the file containing the ``RUN:`` -lines, the idiomatic place to put them is in a subdirectory ``Inputs``. +If your test requires extra files besides the file containing the ``RUN:`` lines +and the extra files are small, consider specifying them in the same file and +using ``extract`` to extract them. For example, + +.. code-block:: llvm + + ; RUN: extract b %s -o %tb.ll + ; RUN: extract a %s | llvm-link - %tb.ll -S | FileCheck %s + + ; CHECK: ... + + ;--- a + ... + ;--- b + ... + +The parts are separated by the regex ``^(.|//)--- ``. By default the +extracted content has leading empty lines to preserve line numbers. Specify +``--no-leading-lines`` to drop leading lines. + +If the extra files are large, the idiomatic place to put them is in a subdirectory ``Inputs``. You can then refer to the extra files as ``%S/Inputs/foo.bar``. For example, consider ``test/Linker/ident.ll``. 
The directory structure is diff --git a/llvm/test/CMakeLists.txt b/llvm/test/CMakeLists.txt index 6994c29efa9a5a..529c06c82b2476 100644 --- a/llvm/test/CMakeLists.txt +++ b/llvm/test/CMakeLists.txt @@ -52,6 +52,7 @@ set(LLVM_TEST_DEPENDS UnitTests bugpoint count + extract llc lli lli-child-target diff --git a/llvm/test/lit.cfg.py b/llvm/test/lit.cfg.py index 0a3289fcc4ad4c..49bd8ddfb2dc58 100644 --- a/llvm/test/lit.cfg.py +++ b/llvm/test/lit.cfg.py @@ -130,6 +130,7 @@ def get_asan_rtlib(): config.llvm_locstats_used = os.path.exists(llvm_locstats_tool) tools = [ + ToolSubst('%extract', FindTool('extract')), ToolSubst('%lli', FindTool('lli'), post='.', extra_args=lli_args), ToolSubst('%llc_dwarf', FindTool('llc'), extra_args=llc_args), ToolSubst('%go', config.go_executable, unresolved='ignore'), diff --git a/llvm/test/tools/extract/Inputs/basic-aa.txt b/llvm/test/tools/extract/Inputs/basic-aa.txt new file mode 100644 index 00000000000000..9eac3fdccbee43 --- /dev/null +++ b/llvm/test/tools/extract/Inputs/basic-aa.txt @@ -0,0 +1,6 @@ + + + +aa +; BB-NOT: {{.}} +; BB: {{^}}bb{{$}} diff --git a/llvm/test/tools/extract/Inputs/basic-bb.txt b/llvm/test/tools/extract/Inputs/basic-bb.txt new file mode 100644 index 00000000000000..de17efab6fb6b6 --- /dev/null +++ b/llvm/test/tools/extract/Inputs/basic-bb.txt @@ -0,0 +1,10 @@ + + + + + + + +bb + +// CC: // Comments are preserved. diff --git a/llvm/test/tools/extract/basic.test b/llvm/test/tools/extract/basic.test new file mode 100644 index 00000000000000..9f9413106cc752 --- /dev/null +++ b/llvm/test/tools/extract/basic.test @@ -0,0 +1,32 @@ +# AA-NOT: {{.}} +# AA: {{^}}aa{{$}} +#--- aa +aa +; BB-NOT: {{.}} +; BB: {{^}}bb{{$}} +;--- bb +bb + +// CC: // Comments are preserved. +//--- cc +cc +// Comments are preserved. +;--- dup +;--- dup + +# RUN: extract aa %s | diff %S/Inputs/basic-aa.txt - +# RUN: extract bb - < %s | diff %S/Inputs/basic-bb.txt - +# RUN: extract cc %s -o %t +# RUN: FileCheck %s --check-prefix=CC < %t + +# RUN: not %extract aa 2>&1 | FileCheck %s --check-prefix=NO_INPUT + +# NO_INPUT: extract: error: input filename is not specified + +# RUN: not %extract dup %s 2>&1 | FileCheck %s --check-prefix=DUP + +# DUP: extract: error: {{.*}}.test: ';--- dup' occurs more than once + +# RUN: not %extract not_exist %s 2>&1 | FileCheck %s --check-prefix=NOT_EXIST + +# NOT_EXIST: extract: error: {{.*}}.test: ';--- not_exist' was not found diff --git a/llvm/test/tools/extract/help.test b/llvm/test/tools/extract/help.test new file mode 100644 index 00000000000000..282052869116c2 --- /dev/null +++ b/llvm/test/tools/extract/help.test @@ -0,0 +1,5 @@ +RUN: extract --help 2>&1 | FileCheck --implicit-check-not='General Options:' %s +CHECK: OVERVIEW: Split input {{.*}} +CHECK: Generic Options: +CHECK: extract Options: +CHECK: -o diff --git a/llvm/test/tools/extract/no-leading-lines.test b/llvm/test/tools/extract/no-leading-lines.test new file mode 100644 index 00000000000000..f0efff5475afb4 --- /dev/null +++ b/llvm/test/tools/extract/no-leading-lines.test @@ -0,0 +1,10 @@ +## With --no-leading-lines, don't add leading lines (which is used to preserve line numbers). 
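+## Without the option, the output would instead begin with enough newlines for
+## the extracted part to keep the line numbers it has in this file.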
+ +# RUN: extract --no-leading-lines input %s -o %t +# RUN: count 1 < %t +# RUN: FileCheck %s < %t + +# CHECK: input + +#--- input +input diff --git a/llvm/test/tools/gold/X86/multiple-sections.ll b/llvm/test/tools/gold/X86/multiple-sections.ll index facbd8d992ed78..31a89a9d3b4844 100644 --- a/llvm/test/tools/gold/X86/multiple-sections.ll +++ b/llvm/test/tools/gold/X86/multiple-sections.ll @@ -1,10 +1,8 @@ -; RUN: echo ".text.tin" > %t_order_lto.txt -; RUN: echo ".text._start" >> %t_order_lto.txt -; RUN: echo ".text.pat" >> %t_order_lto.txt -; RUN: llvm-as %s -o %t.o +; RUN: extract order %s -o %t.order +; RUN: extract ir %s | llvm-as -o %t.o ; RUN: %gold -plugin %llvmshlibdir/LLVMgold%shlibext \ ; RUN: -m elf_x86_64 -o %t.exe %t.o \ -; RUN: --section-ordering-file=%t_order_lto.txt +; RUN: --section-ordering-file=%t.order ; RUN: llvm-readelf -s %t.exe | FileCheck %s ; Check that the order of the sections is tin -> _start -> pat. @@ -13,6 +11,12 @@ ; CHECK: 00000000004000b0 1 FUNC LOCAL DEFAULT 1 tin ; CHECK: 00000000004000c0 15 FUNC GLOBAL DEFAULT 1 _start +;--- order +.text.tin +.text._start +.text.pat + +;--- ir target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-linux-gnu" diff --git a/llvm/test/tools/llvm-objcopy/ELF/strip-symbol.test b/llvm/test/tools/llvm-objcopy/ELF/strip-symbol.test index 78de46cc47b5d4..ad71e81eab8309 100644 --- a/llvm/test/tools/llvm-objcopy/ELF/strip-symbol.test +++ b/llvm/test/tools/llvm-objcopy/ELF/strip-symbol.test @@ -1,19 +1,24 @@ -# RUN: yaml2obj %s -o %t +# RUN: extract yaml %s | yaml2obj - -o %t # RUN: llvm-objcopy --strip-symbol baz -N bar %t %t2 # RUN: llvm-readobj --symbols --sections %t2 | FileCheck %s # RUN: llvm-strip --strip-symbol baz -N bar %t -o %t3 # RUN: cmp %t2 %t3 # RUN: llvm-strip --regex --strip-symbol '^b.*' -N bar %t -o %t4 # RUN: cmp %t3 %t4 -# RUN: echo " bar # bar" > %t-list.txt -# RUN: echo " baz # baz" >> %t-list.txt -# RUN: echo " # no symbol" >> %t-list.txt -# RUN: llvm-objcopy --strip-symbols %t-list.txt %t %t5 +# RUN: extract list1 %s -o %t-list.txt && llvm-objcopy --strip-symbols %t-list.txt %t %t5 # RUN: cmp %t3 %t5 -# RUN: echo "b.* # bar & baz" > %t-list2.txt -# RUN: llvm-objcopy --regex --strip-symbols %t-list2.txt %t %t6 +# RUN: extract list2 %s -o %t-list2.txt && llvm-objcopy --regex --strip-symbols %t-list2.txt %t %t6 # RUN: cmp %t3 %t6 +#--- list1 +bar # bar +baz # baz +# no symbol + +#--- list2 +b.* # bar & baz + +#--- yaml !ELF FileHeader: Class: ELFCLASS64 diff --git a/llvm/test/tools/llvm-strings/radix.test b/llvm/test/tools/llvm-strings/radix.test index d23fb3cddc8f80..d9796a937d9050 100644 --- a/llvm/test/tools/llvm-strings/radix.test +++ b/llvm/test/tools/llvm-strings/radix.test @@ -1,15 +1,18 @@ ## Show that llvm-strings can handle the -t/--radix switch properly. 
-RUN: echo one > %t -RUN: echo two >> %t -RUN: echo three >> %t -RUN: echo four >> %t -RUN: echo five >> %t -RUN: echo six >> %t -RUN: echo seven >> %t -RUN: echo eight >> %t -RUN: echo nine >> %t -RUN: echo ten >> %t +RUN: extract --no-leading-lines input %s -o %t +#--- input +one +two +three +four +five +six +seven +eight +nine +ten +#--- end RUN: llvm-strings %t | FileCheck %s -check-prefix CHECK-NONE --implicit-check-not={{.}} RUN: llvm-strings -t d %t | FileCheck %s -check-prefix CHECK-DEC --strict-whitespace --implicit-check-not={{.}} diff --git a/llvm/tools/extract/.clang-tidy b/llvm/tools/extract/.clang-tidy new file mode 100644 index 00000000000000..87ec2ff53af6e8 --- /dev/null +++ b/llvm/tools/extract/.clang-tidy @@ -0,0 +1,19 @@ +# Almost identical to the top-level .clang-tidy, except that {Member,Parameter,Variable}Case use camelBack. +Checks: '-*,clang-diagnostic-*,llvm-*,misc-*,-misc-unused-parameters,-misc-non-private-member-variables-in-classes,readability-identifier-naming' +CheckOptions: + - key: readability-identifier-naming.ClassCase + value: CamelCase + - key: readability-identifier-naming.EnumCase + value: CamelCase + - key: readability-identifier-naming.FunctionCase + value: camelBack + - key: readability-identifier-naming.MemberCase + value: camelBack + - key: readability-identifier-naming.ParameterCase + value: camelBack + - key: readability-identifier-naming.UnionCase + value: CamelCase + - key: readability-identifier-naming.VariableCase + value: camelBack + - key: readability-identifier-naming.IgnoreMainLikeFunctions + value: 1 diff --git a/llvm/tools/extract/CMakeLists.txt b/llvm/tools/extract/CMakeLists.txt new file mode 100644 index 00000000000000..dae1f463f06669 --- /dev/null +++ b/llvm/tools/extract/CMakeLists.txt @@ -0,0 +1,7 @@ +set(LLVM_LINK_COMPONENTS + Support + ) + +add_llvm_tool(extract + extract.cpp + ) diff --git a/llvm/tools/extract/extract.cpp b/llvm/tools/extract/extract.cpp new file mode 100644 index 00000000000000..8ccb5391561455 --- /dev/null +++ b/llvm/tools/extract/extract.cpp @@ -0,0 +1,113 @@ +//===- extract.cpp - Input splitting utility ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Split input into multipe parts separated by regex '^(.|//)--- ' and extract +// the specified part. 
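+//
+// For example, given an input file
+//
+//   ;--- first
+//   1
+//   ;--- second
+//   2
+//
+// 'extract second file' writes part "second" (the line "2"), preceded by
+// enough newlines to preserve its original line number unless
+// --no-leading-lines is given.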
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileOutputBuffer.h" +#include "llvm/Support/LineIterator.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/WithColor.h" +#include + +using namespace llvm; + +static cl::OptionCategory cat("extract Options"); + +static cl::opt part(cl::Positional, cl::desc("part"), + cl::cat(cat)); + +static cl::opt input(cl::Positional, cl::desc("filename"), + cl::cat(cat)); + +static cl::opt output("o", cl::desc("Output filename"), + cl::value_desc("filename"), cl::init("-"), + cl::cat(cat)); + +static cl::opt noLeadingLines("no-leading-lines", + cl::desc("Don't preserve line numbers"), + cl::cat(cat)); + +static StringRef toolName; + +LLVM_ATTRIBUTE_NORETURN static void error(StringRef filename, + const Twine &message) { + if (filename.empty()) + WithColor::error(errs(), toolName) << message << '\n'; + else + WithColor::error(errs(), toolName) << filename << ": " << message << '\n'; + exit(1); +} + +static void handle(MemoryBuffer &inputBuf, StringRef input) { + const char *partBegin = nullptr, *partEnd = nullptr; + int numEmptyLines = 0; + StringRef separator; + for (line_iterator i(inputBuf, /*SkipBlanks=*/false, '\0'); !i.is_at_eof();) { + StringRef line = *i++; + size_t markerLen = line.startswith("//") ? 6 : 5; + if (!(line.size() > markerLen && + line.substr(markerLen - 4).startswith("--- "))) + continue; + separator = line.substr(0, markerLen); + StringRef cur = line.substr(markerLen); + if (cur == part) { + if (partBegin) + error(input, "'" + separator + cur + "' occurs more than once"); + if (!noLeadingLines) + numEmptyLines = i.line_number() - 1; + if (i.is_at_eof()) + break; + partBegin = i->data(); + } else if (partBegin && !partEnd) { + partEnd = line.data(); + } + } + if (!partBegin) + error(input, "'" + separator + part + "' was not found"); + if (!partEnd) + partEnd = inputBuf.getBufferEnd(); + + Expected> outputBuf = + FileOutputBuffer::create(output, numEmptyLines + (partEnd - partBegin)); + if (!outputBuf) + error(input, toString(outputBuf.takeError())); + uint8_t *buf = (*outputBuf)->getBufferStart(); + + // If --no-leading-lines is not specified, numEmptyLines is 0. Append newlines + // so that the extracted part preserves line numbers. 
+ std::fill_n(buf, numEmptyLines, '\n'); + std::copy(partBegin, partEnd, buf + numEmptyLines); + if (Error e = (*outputBuf)->commit()) + error(input, toString(std::move(e))); +} + +int main(int argc, const char **argv) { + toolName = sys::path::stem(argv[0]); + cl::HideUnrelatedOptions({&cat}); + cl::ParseCommandLineOptions( + argc, argv, + "Split input into multiple parts separated by regex '^(.|//)--- ' and " + "extract the part specified by '^(.|//)--- '\n", + nullptr, + /*EnvVar=*/nullptr, + /*LongOptionsUseDoubleDash=*/true); + + if (input.empty()) + error("", "input filename is not specified"); + ErrorOr> bufferOrErr = + MemoryBuffer::getFileOrSTDIN(input); + if (std::error_code ec = bufferOrErr.getError()) + error(input, ec.message()); + handle(**bufferOrErr, input); +} From ab73b6da95750164daac4cfbd351ca96e1084117 Mon Sep 17 00:00:00 2001 From: Nico Weber Date: Thu, 23 Jul 2020 22:28:00 -0400 Subject: [PATCH 3/4] [gn build] (manually) merge d054c7ee2e9 --- llvm/utils/gn/secondary/lld/test/BUILD.gn | 1 + llvm/utils/gn/secondary/llvm/test/BUILD.gn | 1 + llvm/utils/gn/secondary/llvm/tools/extract/BUILD.gn | 4 ++++ 3 files changed, 6 insertions(+) create mode 100644 llvm/utils/gn/secondary/llvm/tools/extract/BUILD.gn diff --git a/llvm/utils/gn/secondary/lld/test/BUILD.gn b/llvm/utils/gn/secondary/lld/test/BUILD.gn index 96a6b07d39c1b1..dac50890d4ac9f 100644 --- a/llvm/utils/gn/secondary/lld/test/BUILD.gn +++ b/llvm/utils/gn/secondary/lld/test/BUILD.gn @@ -78,6 +78,7 @@ group("test") { ":lit_unit_site_cfg", "//lld/tools/lld:symlinks", "//lld/unittests", + "//llvm/tools/extract", "//llvm/tools/llc", "//llvm/tools/llvm-ar:symlinks", "//llvm/tools/llvm-as", diff --git a/llvm/utils/gn/secondary/llvm/test/BUILD.gn b/llvm/utils/gn/secondary/llvm/test/BUILD.gn index 550793839183f2..3e72ef3c3a446c 100644 --- a/llvm/utils/gn/secondary/llvm/test/BUILD.gn +++ b/llvm/utils/gn/secondary/llvm/test/BUILD.gn @@ -203,6 +203,7 @@ group("test") { "//llvm/lib/Testing/Support", "//llvm/tools/bugpoint", "//llvm/tools/dsymutil", + "//llvm/tools/extract", "//llvm/tools/llc", "//llvm/tools/lli", "//llvm/tools/lli/ChildTarget:lli-child-target", diff --git a/llvm/utils/gn/secondary/llvm/tools/extract/BUILD.gn b/llvm/utils/gn/secondary/llvm/tools/extract/BUILD.gn new file mode 100644 index 00000000000000..f4553476f1e8b7 --- /dev/null +++ b/llvm/utils/gn/secondary/llvm/tools/extract/BUILD.gn @@ -0,0 +1,4 @@ +executable("extract") { + deps = [ "//llvm/lib/Support" ] + sources = [ "extract.cpp" ] +} From 4589dd924dfc43c846652b85825e291af0d7428a Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 23 Jul 2020 19:38:30 -0700 Subject: [PATCH 4/4] [mlir][DialectConversion] Enable deeper integration of type conversions This revision adds support for much deeper type conversion integration into the conversion process, and enables auto-generating cast operations when necessary. Type conversions are now largely automatically managed by the conversion infra when using a ConversionPattern with a provided TypeConverter. This removes the need for patterns to do type cast wrapping themselves and moves the burden to the infra. This makes it much easier to perform partial lowerings when type conversions are involved, as any lingering type conversions will be automatically resolved/legalized by the conversion infra. To support this new integration, a few changes have been made to the type materialization API on TypeConverter. 
Materialization has been split into three separate categories: * Argument Materialization: This type of materialization is used when converting the type of block arguments when calling `convertRegionTypes`. This is useful for contextually inserting additional conversion operations when converting a block argument type, such as when converting the types of a function signature. * Source Materialization: This type of materialization is used to convert a legal type of the converter into a non-legal type, generally a source type. This may be called when uses of a non-legal type persist after the conversion process has finished. * Target Materialization: This type of materialization is used to convert a non-legal, or source, type into a legal, or target, type. This type of materialization is used when applying a pattern on an operation, but the types of the operands have not yet been converted. Differential Revision: https://reviews.llvm.org/D82831 --- .../mlir/Transforms/DialectConversion.h | 81 +++- .../StandardToLLVM/StandardToLLVM.cpp | 41 +- .../Transforms/LowerABIAttributesPass.cpp | 10 + mlir/lib/IR/Value.cpp | 6 +- mlir/lib/Transforms/DialectConversion.cpp | 447 +++++++++++++++--- .../StandardToLLVM/standard-to-llvm.mlir | 12 - .../SPIRV/Transforms/abi-load-store.mlir | 9 +- .../test-legalize-type-conversion.mlir | 64 +++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 116 ++++- 9 files changed, 654 insertions(+), 132 deletions(-) create mode 100644 mlir/test/Transforms/test-legalize-type-conversion.mlir diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 26b7ce6ea6c35c..8bffb9649d1f84 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -113,20 +113,40 @@ class TypeConverter { /// Register a materialization function, which must be convertible to the /// following form: - /// `Optional(PatternRewriter &, T, ValueRange, Location)`, + /// `Optional(OpBuilder &, T, ValueRange, Location)`, /// where `T` is any subclass of `Type`. This function is responsible for - /// creating an operation, using the PatternRewriter and Location provided, - /// that "casts" a range of values into a single value of the given type `T`. - /// It must return a Value of the converted type on success, an `llvm::None` - /// if it failed but other materialization can be attempted, and `nullptr` on + /// creating an operation, using the OpBuilder and Location provided, that + /// "casts" a range of values into a single value of the given type `T`. It + /// must return a Value of the converted type on success, an `llvm::None` if + /// it failed but other materialization can be attempted, and `nullptr` on /// unrecoverable failure. It will only be called for (sub)types of `T`. /// Materialization functions must be provided when a type conversion /// results in more than one type, or if a type conversion may persist after /// the conversion has finished. + /// + /// This method registers a materialization that will be called when + /// converting an illegal block argument type, to a legal type. template ::template arg_t<1>> - void addMaterialization(FnT &&callback) { - registerMaterialization( + void addArgumentMaterialization(FnT &&callback) { + argumentMaterializations.emplace_back( + wrapMaterialization(std::forward(callback))); + } + /// This method registers a materialization that will be called when + /// converting a legal type to an illegal source type. 
This is used when + /// conversions to an illegal type must persist beyond the main conversion. + template ::template arg_t<1>> + void addSourceMaterialization(FnT &&callback) { + sourceMaterializations.emplace_back( + wrapMaterialization(std::forward(callback))); + } + /// This method registers a materialization that will be called when + /// converting type from an illegal, or source, type to a legal type. + template ::template arg_t<1>> + void addTargetMaterialization(FnT &&callback) { + targetMaterializations.emplace_back( wrapMaterialization(std::forward(callback))); } @@ -182,9 +202,24 @@ class TypeConverter { Optional convertBlockSignature(Block *block); /// Materialize a conversion from a set of types into one result type by - /// generating a cast operation of some kind. - Value materializeConversion(PatternRewriter &rewriter, Location loc, - Type resultType, ValueRange inputs); + /// generating a cast sequence of some kind. See the respective + /// `add*Materialization` for more information on the context for these + /// methods. + Value materializeArgumentConversion(OpBuilder &builder, Location loc, + Type resultType, ValueRange inputs) { + return materializeConversion(argumentMaterializations, builder, loc, + resultType, inputs); + } + Value materializeSourceConversion(OpBuilder &builder, Location loc, + Type resultType, ValueRange inputs) { + return materializeConversion(sourceMaterializations, builder, loc, + resultType, inputs); + } + Value materializeTargetConversion(OpBuilder &builder, Location loc, + Type resultType, ValueRange inputs) { + return materializeConversion(targetMaterializations, builder, loc, + resultType, inputs); + } private: /// The signature of the callback used to convert a type. If the new set of @@ -193,8 +228,15 @@ class TypeConverter { using ConversionCallbackFn = std::function(Type, SmallVectorImpl &)>; - using MaterializationCallbackFn = std::function( - PatternRewriter &, Type, ValueRange, Location)>; + /// The signature of the callback used to materialize a conversion. + using MaterializationCallbackFn = + std::function(OpBuilder &, Type, ValueRange, Location)>; + + /// Attempt to materialize a conversion using one of the provided + /// materialization functions. + Value materializeConversion( + MutableArrayRef materializations, + OpBuilder &builder, Location loc, Type resultType, ValueRange inputs); /// Generate a wrapper for the given callback. This allows for accepting /// different callback forms, that all compose into a single version. @@ -240,24 +282,21 @@ class TypeConverter { template MaterializationCallbackFn wrapMaterialization(FnT &&callback) { return [callback = std::forward(callback)]( - PatternRewriter &rewriter, Type resultType, ValueRange inputs, + OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Optional { if (T derivedType = resultType.dyn_cast()) - return callback(rewriter, derivedType, inputs, loc); + return callback(builder, derivedType, inputs, loc); return llvm::None; }; } - /// Register a materialization. - void registerMaterialization(MaterializationCallbackFn &&callback) { - materializations.emplace_back(std::move(callback)); - } - /// The set of registered conversion functions. SmallVector conversions; /// The list of registered materialization functions. - SmallVector materializations; + SmallVector argumentMaterializations; + SmallVector sourceMaterializations; + SmallVector targetMaterializations; /// A set of cached conversions to avoid recomputing in the common case. 
/// Direct 1-1 conversions are the most common, so this cache stores the @@ -325,7 +364,7 @@ class ConversionPattern : public RewritePattern { protected: /// An optional type converter for use by this pattern. - TypeConverter *typeConverter; + TypeConverter *typeConverter = nullptr; private: using RewritePattern::rewrite; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 91a4867ad30754..080264e666cfe0 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -150,19 +150,42 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, // Materialization for memrefs creates descriptor structs from individual // values constituting them, when descriptors are used, i.e. more than one // value represents a memref. - addMaterialization([&](PatternRewriter &rewriter, - UnrankedMemRefType resultType, ValueRange inputs, - Location loc) -> Optional { + addArgumentMaterialization( + [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, + inputs); + }); + addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, + ValueRange inputs, + Location loc) -> Optional { if (inputs.size() == 1) return llvm::None; - return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType, - inputs); + return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); }); - addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType, - ValueRange inputs, Location loc) -> Optional { - if (inputs.size() == 1) + // Add generic source and target materializations to handle cases where + // non-LLVM types persist after an LLVM conversion. + addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() != 1) + return llvm::None; + // FIXME: These should check LLVM::DialectCastOp can actually be constructed + // from the input and result. + return builder.create(loc, resultType, inputs[0]) + .getResult(); + }); + addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() != 1) return llvm::None; - return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs); + // FIXME: These should check LLVM::DialectCastOp can actually be constructed + // from the input and result. + return builder.create(loc, resultType, inputs[0]) + .getResult(); }); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index be1d2714139018..aa376993ae7185 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -222,6 +222,16 @@ void LowerABIAttributesPass::runOnOperation() { spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module)); SPIRVTypeConverter typeConverter(targetEnv); + + // Insert a bitcast in the case of a pointer type change. 
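+  // Only a single pointer-typed input is handled here; any other case returns
+  // a null Value, i.e. no conversion is materialized.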
+ typeConverter.addSourceMaterialization([](OpBuilder &builder, + spirv::PointerType type, + ValueRange inputs, Location loc) { + if (inputs.size() != 1 || !inputs[0].getType().isa()) + return Value(); + return builder.create(loc, type, inputs[0]).getResult(); + }); + OwningRewritePatternList patterns; patterns.insert(context, typeConverter); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 6467a7f2295b3d..776b32a73d58fb 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -77,7 +77,11 @@ Operation *Value::getDefiningOp() const { Location Value::getLoc() const { if (auto *op = getDefiningOp()) return op->getLoc(); - return UnknownLoc::get(getContext()); + + // Use the location of the parent operation if this is a block argument. + // TODO: Should we just add locations to block arguments? + Operation *parentOp = cast().getOwner()->getParentOp(); + return parentOp ? parentOp->getLoc() : UnknownLoc::get(getContext()); } /// Return the Region in which this Value is defined. diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index b9ed64f573f228..9778958a458851 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" using namespace mlir; @@ -106,8 +107,15 @@ namespace { /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { /// Lookup a mapped value within the map. If a mapping for the provided value - /// does not exist then return the provided value. - Value lookupOrDefault(Value from) const; + /// does not exist then return the provided value. If `desiredType` is + /// non-null, returns the most recently mapped value with that type. If an + /// operand of that type does not exist, defaults to normal behavior. + Value lookupOrDefault(Value from, Type desiredType = nullptr) const; + + /// Lookup a mapped value within the map, or return null if a mapping does not + /// exist. If a mapping exists, this follows the same behavior of + /// `lookupOrDefault`. + Value lookupOrNull(Value from) const; /// Map a value to the one provided. void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); } @@ -121,14 +129,36 @@ struct ConversionValueMapping { }; } // end anonymous namespace -/// Lookup a mapped value within the map. If a mapping for the provided value -/// does not exist then return the provided value. -Value ConversionValueMapping::lookupOrDefault(Value from) const { - // If this value had a valid mapping, unmap that value as well in the case - // that it was also replaced. - while (auto mappedValue = mapping.lookupOrNull(from)) +Value ConversionValueMapping::lookupOrDefault(Value from, + Type desiredType) const { + // If there was no desired type, simply find the leaf value. + if (!desiredType) { + // If this value had a valid mapping, unmap that value as well in the case + // that it was also replaced. + while (auto mappedValue = mapping.lookupOrNull(from)) + from = mappedValue; + return from; + } + + // Otherwise, try to find the deepest value that has the desired type. 
+ Value desiredValue; + do { + if (from.getType() == desiredType) + desiredValue = from; + + Value mappedValue = mapping.lookupOrNull(from); + if (!mappedValue) + break; from = mappedValue; - return from; + } while (true); + + // If the desired value was found use it, otherwise default to the leaf value. + return desiredValue ? desiredValue : from; +} + +Value ConversionValueMapping::lookupOrNull(Value from) const { + Value result = lookupOrDefault(from); + return result == from ? nullptr : result; } //===----------------------------------------------------------------------===// @@ -209,10 +239,17 @@ struct ArgConverter { /// its original state. void discardRewrites(Block *block); - /// Fully replace uses of the old arguments with the new, materializing cast - /// operations as necessary. + /// Fully replace uses of the old arguments with the new. void applyRewrites(ConversionValueMapping &mapping); + /// Materialize any necessary conversions for converted arguments that have + /// live users, using the provided `findLiveUser` to search for a user that + /// survives the conversion process. + LogicalResult + materializeLiveConversions(ConversionValueMapping &mapping, + OpBuilder &builder, + function_ref findLiveUser); + //===--------------------------------------------------------------------===// // Conversion //===--------------------------------------------------------------------===// @@ -307,7 +344,6 @@ void ArgConverter::discardRewrites(Block *block) { void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { for (auto &info : conversionInfo) { - Block *newBlock = info.first; ConvertedBlockInfo &blockInfo = info.second; Block *origBlock = blockInfo.origBlock; @@ -318,24 +354,8 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { // Handle the case of a 1->0 value mapping. if (!argInfo) { - // If a replacement value was given for this argument, use that to - // replace all uses. - auto argReplacementValue = mapping.lookupOrDefault(origArg); - if (argReplacementValue != origArg) { - origArg.replaceAllUsesWith(argReplacementValue); - continue; - } - // If there are any dangling uses then replace the argument with one - // generated by the type converter. This is necessary as the cast must - // persist in the IR after conversion. - if (!origArg.use_empty()) { - rewriter.setInsertionPointToStart(newBlock); - Value newArg = blockInfo.converter->materializeConversion( - rewriter, origArg.getLoc(), origArg.getType(), llvm::None); - assert(newArg && - "Couldn't materialize a block argument after 1->0 conversion"); + if (Value newArg = mapping.lookupOrNull(origArg)) origArg.replaceAllUsesWith(newArg); - } continue; } @@ -355,6 +375,59 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { } } +LogicalResult ArgConverter::materializeLiveConversions( + ConversionValueMapping &mapping, OpBuilder &builder, + function_ref findLiveUser) { + for (auto &info : conversionInfo) { + Block *newBlock = info.first; + ConvertedBlockInfo &blockInfo = info.second; + Block *origBlock = blockInfo.origBlock; + + // Process the remapping for each of the original arguments. + for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { + // FIXME: We should run the below checks even if the type conversion was + // 1->N, but a lot of existing lowering rely on the block argument being + // blindly replaced. Those usages should be updated, and this if should be + // removed. 
+ if (blockInfo.argInfo[i]) + continue; + + // If the type of this argument changed and the argument is still live, we + // need to materialize a conversion. + BlockArgument origArg = origBlock->getArgument(i); + auto argReplacementValue = mapping.lookupOrDefault(origArg); + bool isDroppedArg = argReplacementValue == origArg; + if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg) + continue; + Operation *liveUser = findLiveUser(origArg); + if (!liveUser) + continue; + + if (OpResult result = argReplacementValue.dyn_cast()) + rewriter.setInsertionPointAfter(result.getOwner()); + else + rewriter.setInsertionPointToStart(newBlock); + Value newArg = blockInfo.converter->materializeSourceConversion( + rewriter, origArg.getLoc(), origArg.getType(), + isDroppedArg ? ValueRange() : ValueRange(argReplacementValue)); + if (!newArg) { + InFlightDiagnostic diag = + emitError(origArg.getLoc()) + << "failed to materialize conversion for block argument #" << i + << " that remained live after conversion, type was " + << origArg.getType(); + if (!isDroppedArg) + diag << ", with target type " << argReplacementValue.getType(); + diag.attachNote(liveUser->getLoc()) + << "see existing live user here: " << *liveUser; + return failure(); + } + mapping.map(origArg, newArg); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // Conversion @@ -417,8 +490,8 @@ Block *ArgConverter::applySignatureConversion( // to pack the new values. For 1->1 mappings, if there is no materialization // provided, use the argument directly instead. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Value newArg = converter.materializeConversion(rewriter, origArg.getLoc(), - origArg.getType(), replArgs); + Value newArg = converter.materializeArgumentConversion( + rewriter, origArg.getLoc(), origArg.getType(), replArgs); if (!newArg) { assert(replArgs.size() == 1 && "couldn't materialize the result of 1->N conversion"); @@ -516,13 +589,15 @@ class OperationTransactionState { SmallVector successors; }; -/// This class represents one requested operation replacement via 'replaceOp'. +/// This class represents one requested operation replacement via 'replaceOp' or +/// 'eraseOp`. struct OpReplacement { OpReplacement() = default; - OpReplacement(ValueRange newValues) - : newValues(newValues.begin(), newValues.end()) {} + OpReplacement(TypeConverter *converter) : converter(converter) {} - SmallVector newValues; + /// An optional type converter that can be used to materialize conversions + /// between the new and old values if necessary. + TypeConverter *converter = nullptr; }; /// The kind of the block action performed during the rewrite. Actions can be @@ -611,9 +686,14 @@ struct ConversionPatternRewriterImpl { /// "numActionsToKeep" actions remains. void undoBlockActions(unsigned numActionsToKeep = 0); - /// Remap the given operands to those with potentially different types. - void remapValues(Operation::operand_range operands, - SmallVectorImpl &remapped); + /// Remap the given operands to those with potentially different types. The + /// provided type converter is used to ensure that the remapped types are + /// legal. Returns success if the operands could be remapped, failure + /// otherwise. 
+ LogicalResult remapValues(Location loc, PatternRewriter &rewriter, + TypeConverter *converter, + Operation::operand_range operands, + SmallVectorImpl &remapped); /// Returns true if the given operation is ignored, and does not need to be /// converted. @@ -666,6 +746,11 @@ struct ConversionPatternRewriterImpl { void notifyRegionWasClonedBefore(iterator_range &blocks, Location origRegionLoc); + /// Notifies that a pattern match failed for the given reason. + LogicalResult + notifyMatchFailure(Location loc, + function_ref reasonCallback); + //===--------------------------------------------------------------------===// // State //===--------------------------------------------------------------------===// @@ -712,6 +797,10 @@ struct ConversionPatternRewriterImpl { /// explicitly provided. TypeConverter defaultTypeConverter; + /// The current conversion pattern that is being rewritten, or nullptr if + /// called from outside of a conversion pattern rewrite. + const ConversionPattern *currentConversionPattern = nullptr; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -759,11 +848,9 @@ void ConversionPatternRewriterImpl::discardRewrites() { void ConversionPatternRewriterImpl::applyRewrites() { // Apply all of the rewrites replacements requested during conversion. for (auto &repl : replacements) { - for (unsigned i = 0, e = repl.second.newValues.size(); i != e; ++i) { - if (auto newValue = repl.second.newValues[i]) - repl.first->getResult(i).replaceAllUsesWith( - mapping.lookupOrDefault(newValue)); - } + for (OpResult result : repl.first->getResults()) + if (Value newValue = mapping.lookupOrNull(result)) + result.replaceAllUsesWith(newValue); // If this operation defines any regions, drop any pending argument // rewrites. @@ -905,11 +992,61 @@ void ConversionPatternRewriterImpl::undoBlockActions( blockActions.resize(numActionsToKeep); } -void ConversionPatternRewriterImpl::remapValues( +LogicalResult ConversionPatternRewriterImpl::remapValues( + Location loc, PatternRewriter &rewriter, TypeConverter *converter, Operation::operand_range operands, SmallVectorImpl &remapped) { remapped.reserve(llvm::size(operands)); - for (Value operand : operands) - remapped.push_back(mapping.lookupOrDefault(operand)); + + SmallVector legalTypes; + for (auto it : llvm::enumerate(operands)) { + Value operand = it.value(); + Type origType = operand.getType(); + + // If a converter was provided, get the desired legal types for this + // operand. + Type desiredType; + if (converter) { + // If there is no legal conversion, fail to match this pattern. + legalTypes.clear(); + if (failed(converter->convertType(origType, legalTypes))) { + return notifyMatchFailure(loc, [=](Diagnostic &diag) { + diag << "unable to convert type for operand #" << it.index() + << ", type was " << origType; + }); + } + // TODO: There currently isn't any mechanism to do 1->N type conversion + // via the PatternRewriter replacement API, so for now we just ignore it. + if (legalTypes.size() == 1) + desiredType = legalTypes.front(); + } else { + // TODO: What we should do here is just set `desiredType` to `origType` + // and then handle the necessary type conversions after the conversion + // process has finished. Unfortunately a lot of patterns currently rely on + // receiving the new operands even if the types change, so we keep the + // original behavior here for now until all of the patterns relying on + // this get updated. 
+ } + Value newOperand = mapping.lookupOrDefault(operand, desiredType); + + // Handle the case where the conversion was 1->1 and the new operand type + // isn't legal. + Type newOperandType = newOperand.getType(); + if (converter && desiredType && newOperandType != desiredType) { + // Attempt to materialize a conversion for this new value. + newOperand = converter->materializeTargetConversion( + rewriter, loc, desiredType, newOperand); + if (!newOperand) { + return notifyMatchFailure(loc, [=](Diagnostic &diag) { + diag << "unable to materialize a conversion for " + "operand #" + << it.index() << ", from " << newOperandType << " to " + << desiredType; + }); + } + } + remapped.push_back(newOperand); + } + return success(); } bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { @@ -987,16 +1124,22 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, Value newValue, result; for (auto it : llvm::zip(newValues, op->getResults())) { std::tie(newValue, result) = it; - if (!newValue) + if (!newValue) { resultChanged = true; - else - mapping.map(result, newValue); + continue; + } + // Remap, and check for any result type changes. + mapping.map(result, newValue); + resultChanged |= (newValue.getType() != result.getType()); } if (resultChanged) operationsWithChangedResults.push_back(replacements.size()); // Record the requested operation replacement. - replacements.insert(std::make_pair(op, OpReplacement(newValues))); + TypeConverter *converter = nullptr; + if (currentConversionPattern) + converter = currentConversionPattern->getTypeConverter(); + replacements.insert(std::make_pair(op, OpReplacement(converter))); // Mark this operation as recursively ignored so that we don't need to // convert any nested operations. @@ -1041,6 +1184,16 @@ void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( assert(succeeded(result) && "expected region to have no unreachable blocks"); } +LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure( + Location loc, function_ref reasonCallback) { + LLVM_DEBUG({ + Diagnostic diag(loc, DiagnosticSeverity::Remark); + reasonCallback(diag); + logger.startLine() << "** Failure : " << diag.str() << "\n"; + }); + return failure(); +} + //===----------------------------------------------------------------------===// // ConversionPatternRewriter //===----------------------------------------------------------------------===// @@ -1200,12 +1353,7 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { /// PatternRewriter hook for notifying match failure reasons. LogicalResult ConversionPatternRewriter::notifyMatchFailure( Operation *op, function_ref reasonCallback) { - LLVM_DEBUG({ - Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); - reasonCallback(diag); - impl->logger.startLine() << "** Failure : " << diag.str() << "\n"; - }); - return failure(); + return impl->notifyMatchFailure(op->getLoc(), reasonCallback); } /// Return a reference to the internal implementation. @@ -1221,9 +1369,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - SmallVector operands; auto &dialectRewriter = static_cast(rewriter); - dialectRewriter.getImpl().remapValues(op->getOperands(), operands); + auto &rewriterImpl = dialectRewriter.getImpl(); + + // Track the current conversion pattern in the rewriter. 
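+  // notifyOpReplaced consults this to remember which TypeConverter produced
+  // the replacement values for an op, so result types can be legalized later.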
+ assert(!rewriterImpl.currentConversionPattern && + "already inside of a pattern rewrite"); + llvm::SaveAndRestore currentPatternGuard( + rewriterImpl.currentConversionPattern, this); + + // Remap the operands of the operation. + SmallVector operands; + if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter, + getTypeConverter(), op->getOperands(), + operands))) { + return failure(); + } return matchAndRewrite(op, operands, dialectRewriter); } @@ -1878,6 +2039,24 @@ struct OperationConverter { /// remaining artifacts and complete the conversion. LogicalResult finalize(ConversionPatternRewriter &rewriter); + /// Legalize the types of converted block arguments. + LogicalResult + legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl); + + /// Legalize an operation result that was marked as "erased". + LogicalResult + legalizeErasedResult(Operation *op, OpResult result, + ConversionPatternRewriterImpl &rewriterImpl); + + /// Legalize an operation result that was replaced with a value of a different + /// type. + LogicalResult + legalizeChangedResultType(Operation *op, OpResult result, Value newValue, + TypeConverter *replConverter, + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl); + /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -1961,33 +2140,145 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { LogicalResult OperationConverter::finalize(ConversionPatternRewriter &rewriter) { ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - auto isOpDead = [&](Operation *op) { return rewriterImpl.isOpIgnored(op); }; - // Process the operations with changed results. - for (unsigned replIdx : rewriterImpl.operationsWithChangedResults) { + // Legalize converted block arguments. + if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) + return failure(); + + // Process requested operation replacements. + for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size(); + i != e; ++i) { + unsigned replIdx = rewriterImpl.operationsWithChangedResults[i]; auto &repl = *(rewriterImpl.replacements.begin() + replIdx); - for (auto it : llvm::zip(repl.first->getResults(), repl.second.newValues)) { - Value result = std::get<0>(it), newValue = std::get<1>(it); + for (OpResult result : repl.first->getResults()) { + Value newValue = rewriterImpl.mapping.lookupOrNull(result); // If the operation result was replaced with null, all of the uses of this // value should be replaced. - if (newValue) + if (!newValue) { + if (failed(legalizeErasedResult(repl.first, result, rewriterImpl))) + return failure(); + continue; + } + + // Otherwise, check to see if the type of the result changed. + if (result.getType() == newValue.getType()) continue; - auto liveUserIt = llvm::find_if_not(result.getUsers(), isOpDead); - if (liveUserIt != result.user_end()) { - InFlightDiagnostic diag = repl.first->emitError() - << "failed to legalize operation '" - << repl.first->getName() - << "' marked as erased"; - diag.attachNote(liveUserIt->getLoc()) - << "found live user of result #" - << result.cast().getResultNumber() << ": " << *liveUserIt; + // Legalize this result. 
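+      // (i.e. materialize a cast from the new value back to the original
+      // result type for any users that survive the conversion).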
+ rewriter.setInsertionPoint(repl.first); + if (failed(legalizeChangedResultType(repl.first, result, newValue, + repl.second.converter, rewriter, + rewriterImpl))) return failure(); - } + + // Update the end iterator for this loop in the case it was updated + // when legalizing generated conversion operations. + e = rewriterImpl.operationsWithChangedResults.size(); + } + } + return success(); +} + +LogicalResult OperationConverter::legalizeConvertedArgumentTypes( + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl) { + // Functor used to check if all users of a value will be dead after + // conversion. + auto findLiveUser = [&](Value val) { + auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + return liveUserIt == val.user_end() ? nullptr : *liveUserIt; + }; + + // Materialize any necessary conversions for converted block arguments that + // are still live. + size_t numCreatedOps = rewriterImpl.createdOps.size(); + if (failed(rewriterImpl.argConverter.materializeLiveConversions( + rewriterImpl.mapping, rewriter, findLiveUser))) + return failure(); + + // Legalize any newly created operations during argument materialization. + for (int i : llvm::seq(numCreatedOps, rewriterImpl.createdOps.size())) { + if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) { + return rewriterImpl.createdOps[i]->emitError() + << "failed to legalize conversion operation generated for block " + "argument that remained live after conversion"; + } + } + return success(); +} + +LogicalResult OperationConverter::legalizeErasedResult( + Operation *op, OpResult result, + ConversionPatternRewriterImpl &rewriterImpl) { + // If the operation result was replaced with null, all of the uses of this + // value should be replaced. + auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + if (liveUserIt != result.user_end()) { + InFlightDiagnostic diag = op->emitError("failed to legalize operation '") + << op->getName() << "' marked as erased"; + diag.attachNote(liveUserIt->getLoc()) + << "found live user of result #" << result.getResultNumber() << ": " + << *liveUserIt; + return failure(); + } + return success(); +} + +LogicalResult OperationConverter::legalizeChangedResultType( + Operation *op, OpResult result, Value newValue, + TypeConverter *replConverter, ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl) { + // Walk the users of this value to see if there are any live users that + // weren't replaced during conversion. + auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { + return rewriterImpl.isOpIgnored(user); + }); + if (liveUserIt == result.user_end()) + return success(); + + // If the replacement has a type converter, attempt to materialize a + // conversion back to the original type. + if (!replConverter) { + // TODO: We should emit an error here, similarly to the case where the + // result is replaced with null. Unfortunately a lot of existing + // patterns rely on this behavior, so until those patterns are updated + // we keep the legacy behavior here of just forwarding the new value. + return success(); + } + + // Track the number of created operations so that new ones can be legalized. + size_t numCreatedOps = rewriterImpl.createdOps.size(); + + // Materialize a conversion for this live result value. 
+ Type resultType = result.getType(); + Value convertedValue = replConverter->materializeSourceConversion( + rewriter, op->getLoc(), resultType, newValue); + if (!convertedValue) { + InFlightDiagnostic diag = op->emitError() + << "failed to materialize conversion for result #" + << result.getResultNumber() << " of operation '" + << op->getName() + << "' that remained live after conversion"; + diag.attachNote(liveUserIt->getLoc()) + << "see existing live user here: " << *liveUserIt; + return failure(); + } + + // Legalize all of the newly created conversion operations. + for (int i : llvm::seq(numCreatedOps, rewriterImpl.createdOps.size())) { + if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) { + return op->emitError("failed to legalize conversion operation generated ") + << "for result #" << result.getResultNumber() << " of operation '" + << op->getName() << "' that remained live after conversion"; } } + rewriterImpl.mapping.map(result, convertedValue); return success(); } @@ -2136,11 +2427,11 @@ LogicalResult TypeConverter::convertSignatureArgs(TypeRange types, return success(); } -Value TypeConverter::materializeConversion(PatternRewriter &rewriter, - Location loc, Type resultType, - ValueRange inputs) { +Value TypeConverter::materializeConversion( + MutableArrayRef materializations, + OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) { for (MaterializationCallbackFn &fn : llvm::reverse(materializations)) - if (Optional result = fn(rewriter, resultType, inputs, loc)) + if (Optional result = fn(builder, resultType, inputs, loc)) return result.getValue(); return nullptr; } diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir index b8ebdfbf35f1c2..3b0a17be640b4c 100644 --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -75,15 +75,3 @@ func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) { %0 = rsqrt %arg0 : vector<4x3xf32> std.return } - -// ----- - -// This should not crash. The first operation cannot be converted, so the -// second should not match. This attempts to convert `return` to `llvm.return` -// and complains about non-LLVM types. 
-func @unknown_source() -> i32 { - %0 = "foo"() : () -> i32 - %1 = addi %0, %0 : i32 - // expected-error@+1 {{must be LLVM dialect type}} - return %1 : i32 -} diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir index 4e4bf06e6f7376..3d37f35b1c466e 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -57,9 +57,12 @@ spv.module Logical GLSL450 { // CHECK: [[CONST3:%.*]] = spv.constant 0 : i32 // CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]] // CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]] - // CHECK: [[ARG2:%.*]] = spv._address_of [[VAR2]] - // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]] - // CHECK: [[ARG0:%.*]] = spv._address_of [[VAR0]] + // CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]] + // CHECK: [[ARG2:%.*]] = spv.Bitcast [[ADDRESSARG2]] + // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]] + // CHECK: [[ARG1:%.*]] = spv.Bitcast [[ADDRESSARG1]] + // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]] + // CHECK: [[ARG0:%.*]] = spv.Bitcast [[ADDRESSARG0]] %0 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr, Input> %1 = spv.Load "Input" %0 : vector<3xi32> %2 = spv.CompositeExtract %1[0 : i32] : vector<3xi32> diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir new file mode 100644 index 00000000000000..c56b3c8ca1e2d7 --- /dev/null +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -0,0 +1,64 @@ +// RUN: mlir-opt %s -test-legalize-type-conversion -allow-unregistered-dialect -split-input-file -verify-diagnostics | FileCheck %s + +// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}} +func @test_invalid_arg_materialization(%arg0: i16) { + // expected-note@below {{see existing live user here}} + "foo.return"(%arg0) : (i16) -> () +} + +// ----- + +// expected-error@below {{failed to legalize conversion operation generated for block argument}} +func @test_invalid_arg_illegal_materialization(%arg0: i32) { + "foo.return"(%arg0) : (i32) -> () +} + +// ----- + +// CHECK-LABEL: func @test_valid_arg_materialization +func @test_valid_arg_materialization(%arg0: i64) { + // CHECK: %[[ARG:.*]] = "test.type_producer" + // CHECK: "foo.return"(%[[ARG]]) : (i64) + + "foo.return"(%arg0) : (i64) -> () +} + +// ----- + +func @test_invalid_result_materialization() { + // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}} + %result = "test.type_producer"() : () -> f16 + + // expected-note@below {{see existing live user here}} + "foo.return"(%result) : (f16) -> () +} + +// ----- + +func @test_invalid_result_materialization() { + // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}} + %result = "test.type_producer"() : () -> f16 + + // expected-note@below {{see existing live user here}} + "foo.return"(%result) : (f16) -> () +} + +// ----- + +func @test_invalid_result_legalization() { + // expected-error@below {{failed to legalize conversion operation generated for result #0 of operation 'test.type_producer' that remained live after conversion}} + %result = "test.type_producer"() : () -> i16 + "foo.return"(%result) : (i16) -> () +} + +// ----- + +// 
CHECK-LABEL: func @test_valid_result_legalization
func @test_valid_result_legalization() {
  // CHECK: %[[RESULT:.*]] = "test.type_producer"() : () -> f64
  // CHECK: %[[CAST:.*]] = "test.cast"(%[[RESULT]]) : (f64) -> f32
  // CHECK: "foo.return"(%[[CAST]]) : (f32)

  %result = "test.type_producer"() : () -> f32
  "foo.return"(%result) : (f32) -> ()
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 255b1c152a3652..5bc947fc8c9164 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -485,8 +485,9 @@ struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
   TestTypeConverter() {
     addConversion(convertType);
-    addMaterialization(materializeCast);
-    addMaterialization(materializeOneToOneCast);
+    addArgumentMaterialization(materializeCast);
+    addArgumentMaterialization(materializeOneToOneCast);
+    addSourceMaterialization(materializeCast);
   }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -519,21 +520,20 @@ struct TestTypeConverter : public TypeConverter {
 
   /// Hook for materializing a conversion. This is necessary because we generate
   /// 1->N type mappings.
-  static Optional<Value> materializeCast(PatternRewriter &rewriter,
-                                         Type resultType, ValueRange inputs,
-                                         Location loc) {
+  static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
+                                         ValueRange inputs, Location loc) {
     if (inputs.size() == 1)
       return inputs[0];
-    return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
+    return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
   }
 
   /// Materialize the cast for one-to-one conversion from i64 to f64.
-  static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
+  static Optional<Value> materializeOneToOneCast(OpBuilder &builder,
                                                  IntegerType resultType,
                                                  ValueRange inputs,
                                                  Location loc) {
     if (resultType.getWidth() == 42 && inputs.size() == 1)
-      return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
+      return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
     return llvm::None;
   }
 };
@@ -742,6 +742,102 @@ struct TestUnknownRootOpDriver
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Test type conversions
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestTypeConversionProducer
+    : public OpConversionPattern<TestTypeProducerOp> {
+  using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Type resultType = op.getType();
+    if (resultType.isa<FloatType>())
+      resultType = rewriter.getF64Type();
+    else if (resultType.isInteger(16))
+      resultType = rewriter.getIntegerType(64);
+    else
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
+    return success();
+  }
+};
+
+struct TestTypeConversionDriver
+    : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
+  void runOnOperation() override {
+    // Initialize the type converter.
+    TypeConverter converter;
+
+    /// Add the legal set of type conversions.
+    converter.addConversion([](Type type) -> Type {
+      // Treat F64 as legal.
+      if (type.isF64())
+        return type;
+      // Allow converting BF16/F16/F32 to F64.
+      if (type.isBF16() || type.isF16() || type.isF32())
+        return FloatType::getF64(type.getContext());
+      // Otherwise, the type is illegal.
+      return nullptr;
+    });
+    converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
+      // Drop all integer types.
+      return success();
+    });
+
+    /// Add the legal set of type materializations.
+    converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
+                                          ValueRange inputs,
+                                          Location loc) -> Value {
+      // Allow casting from F64 back to F32.
+      if (!resultType.isF16() && inputs.size() == 1 &&
+          inputs[0].getType().isF64())
+        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
+      // Allow producing an i32 or i64 from nothing.
+      if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
+          inputs.empty())
+        return builder.create<TestTypeProducerOp>(loc, resultType);
+      // Allow producing an i64 from an integer.
+      if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
+          inputs[0].getType().isa<IntegerType>())
+        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
+      // Otherwise, fail.
+      return nullptr;
+    });
+
+    // Initialize the conversion target.
+    mlir::ConversionTarget target(getContext());
+    target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
+      return op.getType().isF64() || op.getType().isInteger(64);
+    });
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return converter.isSignatureLegal(op.getType()) &&
+             converter.isLegal(&op.getBody());
+    });
+    target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
+      // Allow casts from F64 to F32.
+      return (*op.operand_type_begin()).isF64() && op.getType().isF32();
+    });
+
+    // Initialize the set of rewrite patterns.
+    OwningRewritePatternList patterns;
+    patterns.insert<TestTypeConversionProducer>(converter, &getContext());
+    mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
+                                              converter);
+
+    if (failed(applyPartialConversion(getOperation(), target, patterns)))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// PassRegistration
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 void registerPatternsTestPass() {
   PassRegistration<TestReturnTypeDriver>("test-return-type",
@@ -766,5 +862,9 @@ void registerPatternsTestPass() {
   PassRegistration<TestUnknownRootOpDriver>(
       "test-legalize-unknown-root-patterns",
      "Test public remapped value mechanism in ConversionPatternRewriter");
+
+  PassRegistration<TestTypeConversionDriver>(
+      "test-legalize-type-conversion",
+      "Test various type conversion functionalities in DialectConversion");
 }
 } // namespace mlir
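For downstream users, the hooks exercised by the test pass above are registered directly on a
TypeConverter. The fragment below is only an illustrative sketch, not part of the patch: it
assumes the usual DialectConversion includes, and `MyCastOp` stands in for whatever single-result
cast-like operation a client dialect provides. Everything else (addConversion,
addSourceMaterialization, the callback signature, returning null on failure) follows the API
added in this change.

```
// Sketch: convert f32 to f64, and cast back to f32 for any user that still
// expects the original type after conversion.
TypeConverter converter;

// 1->1 type conversion. Returning the input type marks it as legal; returning
// a different type requests a conversion; returning null marks it as illegal.
converter.addConversion([](Type type) -> Type {
  if (type.isF32())
    return FloatType::getF64(type.getContext());
  return type;
});

// Invoked via materializeSourceConversion when a replaced value has a live
// user that still expects the original (source) type.
converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
                                      ValueRange inputs,
                                      Location loc) -> Value {
  if (inputs.size() != 1)
    return nullptr; // Null reports a failed materialization.
  // MyCastOp is a placeholder for a dialect-provided cast operation.
  return builder.create<MyCastOp>(loc, resultType, inputs[0]).getResult();
});
```

Argument materializations (addArgumentMaterialization) use the same callback signature and are
consulted for converted block arguments that remain live, as handled by
legalizeConvertedArgumentTypes above.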