Skip to content

Commit 779ac0f

Browse files
cychen2021wsmosesZuseZ4
authored
Support Rust types by retrieving them from debug info (rust-lang#307)
* Complete the prototype of Rust debug info parser * Change the uses of TypeTree class to a more appropriate pattern * Complete the Rust debug info parser for pointers and arrays * Complete the support for structs * Complete the debug info type parser for tuples and fix some bugs * Add support for Vecs and Boxes * Document the Rust debug info parsing code * Wrap Rust type info parsing into an if-statement, so it won't be invoked when the Rust type option is switched off * Add support for unions * Add a regression test for rust f32 type * Update rustf32.ll * Reduce the rustf32.ll test case to the minimum * Add build dir generated by Clion and .idea dir to .gitignore * Add a regression test for rust f64 type * Add regression tests for rust integer types * Add a test case for the rust struct type * Delete some unnecessary chars from f32 and i8's test cases * Add test cases for the rust array type * Add a test case for the rust Vec type * Add a test case for the rust Box type * Add test cases for rust ref types * Add test cases for rust pointer types * Fix a bug related with the union type * Add a regression test for the rust union type * Revert "Add build dir generated by Clion and .idea dir to .gitignore" This reverts commit b08016cb93e8ccf5cde8034c03a5b7f2ba2a185b. * Make the rust type parser's code compatible with LLVM version under 9 * Make the test cases compatible with LLVM version under 9 * Change some code format Co-authored-by: William Moses <gh@wsmoses.com> Co-authored-by: Manuel Drehwald <git@manuel.drehwald.info>
1 parent deba550 commit 779ac0f

27 files changed

+1655
-2
lines changed

enzyme/Enzyme/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ set(CMAKE_CXX_STANDARD 17)
99
set(CMAKE_CXX_STANDARD_REQUIRED ON)
1010

1111
list(APPEND ENZYME_SRC SCEV/ScalarEvolutionExpander.cpp)
12-
list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp TypeAnalysis/TypeAnalysisPrinter.cpp)
12+
list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp TypeAnalysis/TypeAnalysisPrinter.cpp TypeAnalysis/RustDebugInfo.cpp)
1313

1414
if (${LLVM_VERSION_MAJOR} LESS 8)
1515
add_llvm_loadable_module( LLVMEnzyme-${LLVM_VERSION_MAJOR}
@@ -68,7 +68,7 @@ endif()
6868
if (${ENZYME_EXTERNAL_SHARED_LIB})
6969
add_library( Enzyme-${LLVM_VERSION_MAJOR}
7070
SHARED
71-
${ENZYME_SRC}
71+
${ENZYME_SRC}
7272
)
7373
target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} LLVM)
7474
install(TARGETS Enzyme-${LLVM_VERSION_MAJOR}
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
//===- RustDebugInfo.cpp - Implementation of Rust Debug Info Parser ---===//
2+
//
3+
// Enzyme Project
4+
//
5+
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
// If using this code in an academic setting, please cite the following:
10+
// @incollection{enzymeNeurips,
11+
// title = {Instead of Rewriting Foreign Code for Machine Learning,
12+
// Automatically Synthesize Fast Gradients},
13+
// author = {Moses, William S. and Churavy, Valentin},
14+
// booktitle = {Advances in Neural Information Processing Systems 33},
15+
// year = {2020},
16+
// note = {To appear in},
17+
// }
18+
//
19+
//===-------------------------------------------------------------------===//
20+
//
21+
// This file implements the Rust debug info parsing function. It will get the
22+
// description of types from debug info of an instruction and pass it to
23+
// concrete functions according to the kind of a description and construct
24+
// the type tree recursively.
25+
//
26+
//===-------------------------------------------------------------------===//
27+
#include "llvm/IR/DataLayout.h"
28+
#include "llvm/IR/DebugInfo.h"
29+
#include "llvm/Support/CommandLine.h"
30+
31+
#include "RustDebugInfo.h"
32+
33+
TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL);
34+
35+
/// Translate a Rust basic (scalar) debug-info type into a TypeTree.
///
/// The decision is made purely on the type's name: "f64" and "f32" become the
/// matching LLVM floating-point types, the Rust integer names become
/// BaseType::Integer, and every other name becomes BaseType::Unknown. In all
/// cases the resulting tree is wrapped with Only(0).
///
/// \param Type basic type description taken from the debug info
/// \param I    instruction whose LLVMContext is used to create the float types
/// \param DL   unused here; kept so all parseDIType overloads share a signature
TypeTree parseDIType(DIBasicType &Type, Instruction &I, DataLayout &DL) {
  const std::string Name = Type.getName().str();

  if (Name == "f64")
    return TypeTree(Type::getDoubleTy(I.getContext())).Only(0);

  if (Name == "f32")
    return TypeTree(Type::getFloatTy(I.getContext())).Only(0);

  // Names of the Rust integer types recognised by the parser.
  static const char *const IntegerNames[] = {
      "i8", "i16", "i32", "i64", "isize", "u8",
      "u16", "u32", "u64", "usize", "i128", "u128"};
  for (const char *IntName : IntegerNames) {
    if (Name == IntName)
      return TypeTree(ConcreteType(BaseType::Integer)).Only(0);
  }

  // Anything else (bool, char, ...) is not classified.
  return TypeTree(ConcreteType(BaseType::Unknown)).Only(0);
}
52+
53+
/// Build the TypeTree of a composite debug-info type (array, struct or union).
///
/// Arrays: the element type is parsed once and its tree is replicated at each
/// element offset; offsets advance by the element size rounded up to the
/// array's alignment. A subrange count of -1 stops the replication.
///
/// Structs/unions: every DW_TAG_member element is parsed and shifted to its
/// byte offset. Struct members are merged with |=; union members are combined
/// with &= (the first member initialises the result).
///
/// Malformed or unsupported input (missing element type, elements that are not
/// subranges/members, non-constant array counts, other composite tags) yields
/// whatever has been collected so far or an empty TypeTree instead of
/// dereferencing a null pointer or running off the end of the function.
///
/// \param Type composite type description taken from the debug info
/// \param I    instruction forwarded to the recursive parsers
/// \param DL   data layout forwarded to TypeTree::ShiftIndices
/// \returns the assembled TypeTree; empty when nothing can be stated
TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) {
  TypeTree Result;
  if (Type.getTag() == dwarf::DW_TAG_array_type) {
#if LLVM_VERSION_MAJOR >= 9
    DIType *SubType = Type.getBaseType();
#else
    DIType *SubType = Type.getBaseType().resolve();
#endif
    // Without an element type there is nothing to replicate.
    if (!SubType) {
      return Result;
    }
    TypeTree SubTT = parseDIType(*SubType, I, DL);
    // The alignment may be reported as 0 when the debug info does not record
    // one; treat that as byte alignment so the rounding below cannot divide
    // by zero.
    size_t Align = Type.getAlignInBytes();
    if (Align == 0) {
      Align = 1;
    }
    const size_t SubSize = SubType->getSizeInBits() / 8;
    const size_t Size = Type.getSizeInBits() / 8;
    DINodeArray Subranges = Type.getElements();
    size_t pos = 0;
    for (auto r : Subranges) {
      DISubrange *Subrange = dyn_cast<DISubrange>(r);
      if (!Subrange) {
        assert(0 && "Array elements are expected to be subranges");
        continue;
      }
      // dyn_cast (rather than get) so that a non-constant count reaches the
      // diagnostic below instead of asserting inside PointerUnion.
      auto Count = Subrange->getCount().dyn_cast<ConstantInt *>();
      if (!Count) {
        assert(0 && "There shouldn't be non-constant-size arrays in Rust");
        continue;
      }
      const int64_t count = Count->getSExtValue();
      if (count == -1) {
        break;
      }
      for (int64_t i = 0; i < count; i++) {
        Result |= SubTT.ShiftIndices(DL, 0, Size, pos);
        // Step past this element and round up to the next multiple of Align.
        pos = ((pos + SubSize + Align - 1) / Align) * Align;
      }
    }
    return Result;
  } else if (Type.getTag() == dwarf::DW_TAG_structure_type ||
             Type.getTag() == dwarf::DW_TAG_union_type) {
    const bool IsStruct = Type.getTag() == dwarf::DW_TAG_structure_type;
    DINodeArray Elements = Type.getElements();
    const size_t Size = Type.getSizeInBits() / 8;
    bool firstSubTT = true;
    for (auto e : Elements) {
      DIDerivedType *Member = dyn_cast<DIDerivedType>(e);
      // Elements that are not plain members cannot be described by this
      // parser; claim nothing about the whole type rather than dereference a
      // null pointer or return a partially-combined result.
      if (!Member || Member->getTag() != dwarf::DW_TAG_member) {
        return TypeTree();
      }
      TypeTree SubTT = parseDIType(*Member, I, DL);
      const size_t Offset = Member->getOffsetInBits() / 8;
      SubTT = SubTT.ShiftIndices(DL, 0, Size, Offset);
      if (IsStruct) {
        Result |= SubTT;
      } else if (firstSubTT) {
        // The first union member seeds the result ...
        Result = SubTT;
      } else {
        // ... and every later member is combined with &=.
        Result &= SubTT;
      }
      firstSubTT = false;
    }
    return Result;
  } else {
    assert(0 && "Composite types other than arrays, structs and unions are not "
                "supported by Rust debug info parser");
    // Reached only when assertions are disabled: report no information.
    return Result;
  }
}
118+
119+
/// Build the TypeTree of a derived debug-info type (pointer or struct member).
///
/// Pointers: the result starts as BaseType::Pointer, the pointee's tree is
/// merged in, and the whole tree is wrapped with Only(0). When the pointee is
/// a basic type its tree is first passed through ShiftIndices(DL, 0, 1, -1)
/// (NOTE(review): presumably so the scalar type applies at every offset of the
/// pointee - confirm against TypeTree::ShiftIndices). A pointer without a base
/// type yields just the pointer itself.
///
/// Members: the member's underlying type is parsed and returned unchanged; a
/// member without a base type yields an empty TypeTree.
///
/// Any other tag triggers an assertion and, when assertions are disabled,
/// returns an empty TypeTree.
///
/// \param Type derived type description taken from the debug info
/// \param I    instruction forwarded to the recursive parsers
/// \param DL   data layout forwarded to TypeTree::ShiftIndices
TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) {
  if (Type.getTag() == dwarf::DW_TAG_pointer_type) {
    TypeTree Result(BaseType::Pointer);
#if LLVM_VERSION_MAJOR >= 9
    DIType *SubType = Type.getBaseType();
#else
    DIType *SubType = Type.getBaseType().resolve();
#endif
    // The base type may be absent; then only the pointer itself is known.
    if (SubType) {
      TypeTree SubTT = parseDIType(*SubType, I, DL);
      if (isa<DIBasicType>(SubType)) {
        Result |= SubTT.ShiftIndices(DL, 0, 1, -1);
      } else {
        Result |= SubTT;
      }
    }
    return Result.Only(0);
  }

  if (Type.getTag() == dwarf::DW_TAG_member) {
#if LLVM_VERSION_MAJOR >= 9
    DIType *SubType = Type.getBaseType();
#else
    DIType *SubType = Type.getBaseType().resolve();
#endif
    if (!SubType) {
      return TypeTree();
    }
    return parseDIType(*SubType, I, DL);
  }

  assert(0 && "Derived types other than pointers and members are not "
              "supported by Rust debug info parser");
  // Reached only when assertions are disabled: report no information.
  return TypeTree();
}
147+
148+
/// Dispatch a generic debug-info type to the matching specialised parser.
///
/// Zero-sized types produce an empty TypeTree. Otherwise the type is handed to
/// the DIBasicType, DICompositeType or DIDerivedType overload. Any other kind
/// of DIType triggers an assertion and, when assertions are disabled, returns
/// an empty TypeTree instead of falling off the end of the function.
///
/// \param Type type description taken from the debug info
/// \param I    instruction forwarded to the specialised parsers
/// \param DL   data layout forwarded to the specialised parsers
TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) {
  if (Type.getSizeInBits() == 0) {
    return TypeTree();
  }

  if (auto BT = dyn_cast<DIBasicType>(&Type)) {
    return parseDIType(*BT, I, DL);
  }
  if (auto CT = dyn_cast<DICompositeType>(&Type)) {
    return parseDIType(*CT, I, DL);
  }
  if (auto DT = dyn_cast<DIDerivedType>(&Type)) {
    return parseDIType(*DT, I, DL);
  }

  assert(0 && "Types other than floating-points, integers, arrays, pointers, "
              "slices, and structs are not supported by debug info parser");
  // Reached only when assertions are disabled: report no information.
  return TypeTree();
}
164+
165+
/// Return true when \p type is a pointer whose pointee is the basic type
/// named "u8".
///
/// Every other shape - a non-pointer, a pointer without a base type, or a
/// pointer to anything other than a basic type called "u8" - returns false.
bool isU8PointerType(DIType &type) {
  if (type.getTag() != dwarf::DW_TAG_pointer_type) {
    return false;
  }
  auto PTy = dyn_cast<DIDerivedType>(&type);
  if (!PTy) {
    return false;
  }
#if LLVM_VERSION_MAJOR >= 9
  DIType *SubType = PTy->getBaseType();
#else
  DIType *SubType = PTy->getBaseType().resolve();
#endif
  // dyn_cast must not be handed a null pointer, so rule that out first.
  if (!SubType) {
    return false;
  }
  auto BTy = dyn_cast<DIBasicType>(SubType);
  return BTy && BTy->getName() == "u8";
}
182+
183+
/// Construct the type tree of the variable described by a llvm.dbg.declare.
///
/// \param I  the dbg.declare intrinsic call; its variable's debug-info type is
///           parsed
/// \param DL data layout forwarded to the recursive parsers
/// \returns the parsed TypeTree, or an empty TypeTree when the variable has no
///          recorded type or its type is a pointer to u8
TypeTree parseDIType(DbgDeclareInst &I, DataLayout &DL) {
#if LLVM_VERSION_MAJOR >= 9
  DIType *type = I.getVariable()->getType();
#else
  DIType *type = I.getVariable()->getType().resolve();
#endif

  // No type recorded for the variable: nothing can be deduced.
  if (!type) {
    return TypeTree();
  }

  // If the type is *u8, do nothing, since the underlying type of data pointed
  // by a *u8 can be anything
  if (isU8PointerType(*type)) {
    return TypeTree();
  }
  return parseDIType(*type, I, DL);
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===- RustDebugInfo.h - Declaration of Rust Debug Info Parser -------===//
2+
//
3+
// Enzyme Project
4+
//
5+
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
// If using this code in an academic setting, please cite the following:
10+
// @incollection{enzymeNeurips,
11+
// title = {Instead of Rewriting Foreign Code for Machine Learning,
12+
// Automatically Synthesize Fast Gradients},
13+
// author = {Moses, William S. and Churavy, Valentin},
14+
// booktitle = {Advances in Neural Information Processing Systems 33},
15+
// year = {2020},
16+
// note = {To appear in},
17+
// }
18+
//
19+
//===-------------------------------------------------------------------===//
20+
//
21+
// This file contains the declaration of the Rust debug info parsing function
22+
// which parses the debug info appended to LLVM IR generated by rustc and
23+
// extracts useful type info from it. The type info will be used to initialize
24+
// the following type analysis.
25+
//
26+
//===-------------------------------------------------------------------===//
27+
#ifndef ENZYME_RUSTDEBUGINFO_H
28+
#define ENZYME_RUSTDEBUGINFO_H 1
29+
30+
#include "llvm/IR/Instructions.h"
31+
#include "llvm/IR/IntrinsicInst.h"
32+
33+
using namespace llvm;
34+
35+
#include "TypeTree.h"
36+
37+
/// Construct the type tree from debug info of an instruction
38+
TypeTree parseDIType(DbgDeclareInst &I, DataLayout &DL);
39+
40+
#endif // ENZYME_RUSTDEBUGINFO_H

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "../FunctionUtils.h"
4949
#include "../LibraryFuncs.h"
5050

51+
#include "RustDebugInfo.h"
5152
#include "TBAA.h"
5253

5354
extern "C" {
@@ -1645,6 +1646,10 @@ void TypeAnalyzer::visitBitCastInst(BitCastInst &I) {
16451646
Type *et1 = cast<PointerType>(I.getType())->getElementType();
16461647
Type *et2 = cast<PointerType>(I.getOperand(0)->getType())->getElementType();
16471648

1649+
TypeTree Debug = getAnalysis(I.getOperand(0)).Data0();
1650+
DataLayout DL = fntypeinfo.Function->getParent()->getDataLayout();
1651+
TypeTree Debug1 = Debug.KeepForCast(DL, et2, et1);
1652+
16481653
if (direction & DOWN)
16491654
updateAnalysis(
16501655
&I,
@@ -4546,6 +4551,9 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) {
45464551
}
45474552

45484553
analysis.prepareArgs();
4554+
if (RustTypeRules) {
4555+
analysis.considerRustDebugInfo();
4556+
}
45494557
analysis.considerTBAA();
45504558
analysis.run();
45514559

@@ -4726,6 +4734,24 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val,
47264734
return dt;
47274735
}
47284736

4737+
/// Parse the debug info generated by rustc and retrieve useful type info if
4738+
/// possible
4739+
void TypeAnalyzer::considerRustDebugInfo() {
4740+
DataLayout DL = fntypeinfo.Function->getParent()->getDataLayout();
4741+
for (BasicBlock &BB : *fntypeinfo.Function) {
4742+
for (Instruction &I : BB) {
4743+
if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(&I)) {
4744+
TypeTree TT = parseDIType(*DDI, DL);
4745+
if (!TT.isKnown()) {
4746+
continue;
4747+
}
4748+
TT |= TypeTree(BaseType::Pointer);
4749+
updateAnalysis(DDI->getAddress(), TT.Only(-1), DDI);
4750+
}
4751+
}
4752+
}
4753+
}
4754+
47294755
Function *TypeResults::getFunction() const {
47304756
return analyzer.fntypeinfo.Function;
47314757
}

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,10 @@ class TypeAnalyzer : public llvm::InstVisitor<TypeAnalyzer> {
255255
/// Analyze type info given by the TBAA, possibly adding to work queue
256256
void considerTBAA();
257257

258+
/// Parse the debug info generated by rustc and retrieve useful type info if
259+
/// possible
260+
void considerRustDebugInfo();
261+
258262
/// Run the interprocedural type analysis starting from this function
259263
void run();
260264

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme-rust-type -print-type-analysis -type-analysis-func=callee -o /dev/null | FileCheck %s
2+
3+
4+
5+
declare void @llvm.dbg.declare(metadata, metadata, metadata)
6+
7+
define internal void @callee(i8* %arg) !dbg !373 {
8+
start:
9+
%t = bitcast i8* %arg to [2 x [2 x [2 x float]]]*
10+
call void @llvm.dbg.declare(metadata [2 x [2 x [2 x float]]]* %t, metadata !384, metadata !DIExpression()), !dbg !385
11+
ret void
12+
}
13+
14+
!llvm.module.flags = !{!14, !15, !16, !17}
15+
!llvm.dbg.cu = !{!18}
16+
17+
!0 = !DIGlobalVariableExpression(var: !1, expr: !DIExpression())
18+
!1 = distinct !DIGlobalVariable(name: "vtable", scope: null, file: !2, type: !3, isLocal: true, isDefinition: true)
19+
!2 = !DIFile(filename: "<unknown>", directory: "")
20+
!3 = !DICompositeType(tag: DW_TAG_structure_type, name: "vtable", file: !2, align: 64, flags: DIFlagArtificial, elements: !4, vtableHolder: !5, identifier: "vtable")
21+
!4 = !{}
22+
!5 = !DICompositeType(tag: DW_TAG_structure_type, name: "{closure#0}", scope: !6, file: !2, size: 64, align: 64, elements: !9, templateParams: !4, identifier: "c211ca2a5a4c8dd717d1e5fba4a6ae0")
23+
!6 = !DINamespace(name: "lang_start", scope: !7)
24+
!7 = !DINamespace(name: "rt", scope: !8)
25+
!8 = !DINamespace(name: "std", scope: null)
26+
!9 = !{!10}
27+
!10 = !DIDerivedType(tag: DW_TAG_member, name: "main", scope: !5, file: !2, baseType: !11, size: 64, align: 64)
28+
!11 = !DIDerivedType(tag: DW_TAG_pointer_type, name: "fn()", baseType: !12, size: 64, align: 64, dwarfAddressSpace: 0)
29+
!12 = !DISubroutineType(types: !13)
30+
!13 = !{null}
31+
!14 = !{i32 7, !"PIC Level", i32 2}
32+
!15 = !{i32 7, !"PIE Level", i32 2}
33+
!16 = !{i32 2, !"RtLibUseGOT", i32 1}
34+
!17 = !{i32 2, !"Debug Info Version", i32 3}
35+
!18 = distinct !DICompileUnit(language: DW_LANG_Rust, file: !19, producer: "clang LLVM (rustc version 1.56.0 (09c42c458 2021-10-18))", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, enums: !20, globals: !37)
36+
!19 = !DIFile(filename: "rust3darray.rs", directory: "/home/nomanous/Space/Tmp/Enzyme")
37+
!20 = !{!21, !28}
38+
!21 = !DICompositeType(tag: DW_TAG_enumeration_type, name: "Result", scope: !22, file: !2, baseType: !24, size: 8, align: 8, elements: !25)
39+
!22 = !DINamespace(name: "result", scope: !23)
40+
!23 = !DINamespace(name: "core", scope: null)
41+
!24 = !DIBasicType(name: "u8", size: 8, encoding: DW_ATE_unsigned)
42+
!25 = !{!26, !27}
43+
!26 = !DIEnumerator(name: "Ok", value: 0)
44+
!27 = !DIEnumerator(name: "Err", value: 1)
45+
!28 = !DICompositeType(tag: DW_TAG_enumeration_type, name: "Alignment", scope: !29, file: !2, baseType: !24, size: 8, align: 8, elements: !32)
46+
!29 = !DINamespace(name: "v1", scope: !30)
47+
!30 = !DINamespace(name: "rt", scope: !31)
48+
!31 = !DINamespace(name: "fmt", scope: !23)
49+
!32 = !{!33, !34, !35, !36}
50+
!33 = !DIEnumerator(name: "Left", value: 0)
51+
!34 = !DIEnumerator(name: "Right", value: 1)
52+
!35 = !DIEnumerator(name: "Center", value: 2)
53+
!36 = !DIEnumerator(name: "Unknown", value: 3)
54+
!37 = !{!0}
55+
!156 = !DIBasicType(name: "f32", size: 32, encoding: DW_ATE_float)
56+
!373 = distinct !DISubprogram(name: "callee", linkageName: "_ZN11rust3darray6callee17h37b114a70360ce19E", scope: !375, file: !374, line: 1, type: !376, scopeLine: 1, flags: DIFlagPrototyped, unit: !18, templateParams: !4, retainedNodes: !383)
57+
!374 = !DIFile(filename: "rust3darray.rs", directory: "/home/nomanous/Space/Tmp/Enzyme", checksumkind: CSK_MD5, checksum: "adf66a1fcb26c178e41abd9c50aa582a")
58+
!375 = !DINamespace(name: "rust3darray", scope: null)
59+
!376 = !DISubroutineType(types: !377)
60+
!377 = !{!156, !378}
61+
!378 = !DICompositeType(tag: DW_TAG_array_type, baseType: !379, size: 256, align: 32, elements: !381)
62+
!379 = !DICompositeType(tag: DW_TAG_array_type, baseType: !380, size: 128, align: 32, elements: !381)
63+
!380 = !DICompositeType(tag: DW_TAG_array_type, baseType: !156, size: 64, align: 32, elements: !381)
64+
!381 = !{!382}
65+
!382 = !DISubrange(count: 2, lowerBound: 0)
66+
!383 = !{!384}
67+
!384 = !DILocalVariable(name: "t", arg: 1, scope: !373, file: !374, line: 1, type: !378)
68+
!385 = !DILocation(line: 1, column: 11, scope: !373)
69+
70+
; CHECK: callee - {} |{}:{}
71+
; CHECK-NEXT: i8* %arg: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float, [-1,12]:Float@float, [-1,16]:Float@float, [-1,20]:Float@float, [-1,24]:Float@float, [-1,28]:Float@float}
72+
; CHECK-NEXT: start
73+
; CHECK-NEXT: %t = bitcast i8* %arg to [2 x [2 x [2 x float]]]*: {[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Float@float, [-1,12]:Float@float, [-1,16]:Float@float, [-1,20]:Float@float, [-1,24]:Float@float, [-1,28]:Float@float}
74+
; CHECK-NEXT: call void @llvm.dbg.declare(metadata [2 x [2 x [2 x float]]]* %t, metadata !50, metadata !DIExpression()), !dbg !51: {}
75+
; CHECK-NEXT: ret void: {}

0 commit comments

Comments
 (0)