Skip to content

Commit 2b176de

Browse files
committed
cleaning up code
1 parent ed8ae06 commit 2b176de

File tree

2 files changed

+97
-125
lines changed

2 files changed

+97
-125
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 92 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -15,56 +15,24 @@ pub(crate) fn handle_gpu_code<'ll>(
1515
_cgcx: &CodegenContext<LlvmCodegenBackend>,
1616
cx: &'ll SimpleCx<'_>,
1717
) {
18-
let (offload_entry_ty, at_one, begin, update, end, tgt_bin_desc, fn_ty) = gen_globals(&cx);
19-
2018
let mut o_types = vec![];
2119
let mut kernels = vec![];
20+
let offload_entry_ty = add_tgt_offload_entry(&cx);
2221
for num in 0..9 {
2322
let kernel = cx.get_function(&format!("kernel_{num}"));
2423
if let Some(kernel) = kernel {
2524
o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
2625
kernels.push(kernel);
2726
}
2827
}
29-
gen_call_handling(&cx, &kernels, at_one, begin, update, end, tgt_bin_desc, fn_ty, &o_types);
30-
}
3128

32-
// The meaning of the __tgt_offload_entry (as per llvm docs) is
33-
// Type, Identifier, Description
34-
// void*, addr, Address of global symbol within device image (function or global)
35-
// char*, name, Name of the symbol
36-
// size_t, size, Size of the entry info (0 if it is a function)
37-
// int32_t, flags, Flags associated with the entry (see Target Region Entry Flags)
38-
// int32_t, reserved, Reserved, to be used by the runtime library.
39-
pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
40-
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
41-
let tptr = cx.type_ptr();
42-
let ti64 = cx.type_i64();
43-
let ti32 = cx.type_i32();
44-
let ti16 = cx.type_i16();
45-
let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
46-
cx.set_struct_body(offload_entry_ty, &entry_elements, false);
47-
offload_entry_ty
29+
gen_call_handling(&cx, &kernels, &o_types);
4830
}
4931

50-
fn gen_globals<'ll>(
51-
cx: &'ll SimpleCx<'_>,
52-
) -> (
53-
&'ll llvm::Type,
54-
&'ll llvm::Value,
55-
&'ll llvm::Value,
56-
&'ll llvm::Value,
57-
&'ll llvm::Value,
58-
&'ll llvm::Type,
59-
&'ll llvm::Type,
60-
) {
61-
let offload_entry_ty = add_tgt_offload_entry(&cx);
62-
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
63-
let tptr = cx.type_ptr();
64-
let ti64 = cx.type_i64();
65-
let ti32 = cx.type_i32();
66-
let tarr = cx.type_array(ti32, 3);
67-
32+
// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
33+
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
34+
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
35+
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
6836
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
6937
let unknown_txt = ";unknown;unknown;0;0;;";
7038
let c_entry_name = CString::new(unknown_txt).unwrap();
@@ -87,11 +55,33 @@ fn gen_globals<'ll>(
8755
cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
8856
let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
8957
llvm::set_alignment(at_one, Align::EIGHT);
58+
at_one
59+
}
9060

91-
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
92-
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
93-
let tgt_bin_desc_name = cx.type_named_struct("struct.__tgt_bin_desc");
94-
cx.set_struct_body(tgt_bin_desc_name, &tgt_bin_desc_ty, false);
61+
// The meaning of the __tgt_offload_entry (as per llvm docs) is
62+
// Type, Identifier, Description
63+
// void*, addr, Address of global symbol within device image (function or global)
64+
// char*, name, Name of the symbol
65+
// size_t, size, Size of the entry info (0 if it is a function)
66+
// int32_t, flags, Flags associated with the entry (see Target Region Entry Flags)
67+
// int32_t, reserved, Reserved, to be used by the runtime library.
68+
pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
69+
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
70+
let tptr = cx.type_ptr();
71+
let ti64 = cx.type_i64();
72+
let ti32 = cx.type_i32();
73+
let ti16 = cx.type_i16();
74+
let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
75+
cx.set_struct_body(offload_entry_ty, &entry_elements, false);
76+
offload_entry_ty
77+
}
78+
79+
fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
80+
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
81+
let tptr = cx.type_ptr();
82+
let ti64 = cx.type_i64();
83+
let ti32 = cx.type_i32();
84+
let tarr = cx.type_array(ti32, 3);
9585

9686
// For each kernel to run on the gpu, we will later generate one entry of this type.
9787
// coppied from LLVM
@@ -111,47 +101,32 @@ fn gen_globals<'ll>(
111101

112102
cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
113103
// For now we don't handle kernels, so for now we just add a global dummy
114-
// to make sure that the __tgt_offload_entrr is defined and handled correctly.
104+
// to make sure that the __tgt_offload_entry is defined and handled correctly.
115105
cx.declare_global("my_struct_global2", kernel_arguments_ty);
106+
}
107+
108+
fn gen_tgt_data_mappers<'ll>(
109+
cx: &'ll SimpleCx<'_>,
110+
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
111+
let tptr = cx.type_ptr();
112+
let ti64 = cx.type_i64();
113+
let ti32 = cx.type_i32();
116114

117-
// Move data to the gpu
118-
let mapper_begin = "__tgt_target_data_begin_mapper";
119-
// Update data on the gpu, currently not used.
120-
let mapper_update = String::from("__tgt_target_data_update_mapper");
121-
// Move data from the GPU
122-
let mapper_end = String::from("__tgt_target_data_end_mapper");
123115
let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
124116
let mapper_fn_ty = cx.type_func(&args, cx.type_void());
125-
let foo = crate::declare::declare_simple_fn(
126-
&cx,
127-
&mapper_begin,
128-
llvm::CallConv::CCallConv,
129-
llvm::UnnamedAddr::No,
130-
llvm::Visibility::Default,
131-
mapper_fn_ty,
132-
);
133-
let bar = crate::declare::declare_simple_fn(
134-
&cx,
135-
&mapper_update,
136-
llvm::CallConv::CCallConv,
137-
llvm::UnnamedAddr::No,
138-
llvm::Visibility::Default,
139-
mapper_fn_ty,
140-
);
141-
let baz = crate::declare::declare_simple_fn(
142-
&cx,
143-
&mapper_end,
144-
llvm::CallConv::CCallConv,
145-
llvm::UnnamedAddr::No,
146-
llvm::Visibility::Default,
147-
mapper_fn_ty,
148-
);
117+
let mapper_begin = "__tgt_target_data_begin_mapper";
118+
let mapper_update = "__tgt_target_data_update_mapper";
119+
let mapper_end = "__tgt_target_data_end_mapper";
120+
let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
121+
let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
122+
let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
123+
149124
let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
150-
attributes::apply_to_llfn(foo, Function, &[nounwind]);
151-
attributes::apply_to_llfn(bar, Function, &[nounwind]);
152-
attributes::apply_to_llfn(baz, Function, &[nounwind]);
125+
attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
126+
attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
127+
attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);
153128

154-
(offload_entry_ty, at_one, foo, bar, baz, tgt_bin_desc_name, mapper_fn_ty)
129+
(begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
155130
}
156131

157132
fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
@@ -223,10 +198,10 @@ fn gen_define_handling<'ll>(
223198

224199
let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
225200
let c_val = c_entry_name.as_bytes_with_nul();
226-
let foo = format!(".offloading.entry_name.{num}");
201+
let offload_entry_name = format!(".offloading.entry_name.{num}");
227202

228203
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
229-
let llglobal = add_unnamed_global(&cx, &foo, initializer, InternalLinkage);
204+
let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
230205
llvm::set_alignment(llglobal, Align::ONE);
231206
let c_section_name = CString::new(".llvm.rodata.offloading").unwrap();
232207
llvm::set_section(llglobal, &c_section_name);
@@ -259,6 +234,21 @@ fn gen_define_handling<'ll>(
259234
o_types
260235
}
261236

237+
fn declare_offload_fn<'ll>(
238+
cx: &'ll SimpleCx<'_>,
239+
name: &str,
240+
ty: &'ll llvm::Type,
241+
) -> &'ll llvm::Value {
242+
crate::declare::declare_simple_fn(
243+
cx,
244+
name,
245+
llvm::CallConv::CCallConv,
246+
llvm::UnnamedAddr::No,
247+
llvm::Visibility::Default,
248+
ty,
249+
)
250+
}
251+
262252
// For each kernel *call*, we now use some of our previous declared globals to move data to and from
263253
// the gpu. We don't have a proper frontend yet, so we assume that every call to a kernel function
264254
// from main is intended to run on the GPU. For now, we only handle the data transfer part of it.
@@ -282,14 +272,18 @@ fn gen_define_handling<'ll>(
282272
fn gen_call_handling<'ll>(
283273
cx: &'ll SimpleCx<'_>,
284274
_kernels: &[&'ll llvm::Value],
285-
s_ident_t: &'ll llvm::Value,
286-
begin: &'ll llvm::Value,
287-
_update: &'ll llvm::Value,
288-
end: &'ll llvm::Value,
289-
tgt_bin_desc: &'ll llvm::Type,
290-
fn_ty: &'ll llvm::Type,
291275
o_types: &[&'ll llvm::Value],
292276
) {
277+
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
278+
let tptr = cx.type_ptr();
279+
let ti32 = cx.type_i32();
280+
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
281+
let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
282+
cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
283+
284+
gen_tgt_kernel_global(&cx);
285+
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
286+
293287
let main_fn = cx.get_function("main");
294288
if let Some(main_fn) = main_fn {
295289
let kernel_name = "kernel_1";
@@ -351,39 +345,17 @@ fn gen_call_handling<'ll>(
351345
Align::from_bytes(8).unwrap(),
352346
);
353347

354-
let tptr = cx.type_ptr();
355-
let mapper_fn_ty = cx.type_func(&[tptr], cx.type_void());
356-
let foo = crate::declare::declare_simple_fn(
357-
&cx,
358-
&"__tgt_register_lib",
359-
llvm::CallConv::CCallConv,
360-
llvm::UnnamedAddr::No,
361-
llvm::Visibility::Default,
362-
mapper_fn_ty,
363-
);
364-
let bar = crate::declare::declare_simple_fn(
365-
&cx,
366-
&"__tgt_unregister_lib",
367-
llvm::CallConv::CCallConv,
368-
llvm::UnnamedAddr::No,
369-
llvm::Visibility::Default,
370-
mapper_fn_ty,
371-
);
348+
let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
349+
let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
350+
let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
372351
let init_ty = cx.type_func(&[], cx.type_void());
373-
let baz = crate::declare::declare_simple_fn(
374-
&cx,
375-
&"__tgt_init_all_rtls",
376-
llvm::CallConv::CCallConv,
377-
llvm::UnnamedAddr::No,
378-
llvm::Visibility::Default,
379-
init_ty,
380-
);
381-
382-
builder.call(mapper_fn_ty, foo, &[tgt_bin_desc_alloca], None);
383-
builder.call(init_ty, baz, &[], None);
352+
let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
384353

385354
// call void @__tgt_register_lib(ptr noundef %6)
355+
builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
386356
// call void @__tgt_init_all_rtls()
357+
builder.call(init_ty, init_rtls_decl, &[], None);
358+
387359
for i in 0..num_args {
388360
let idx = cx.get_const_i32(i);
389361
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
@@ -401,6 +373,7 @@ fn gen_call_handling<'ll>(
401373

402374
let nullptr = cx.const_null(cx.type_ptr());
403375
let o_type = o_types[0];
376+
let s_ident_t = generate_at_one(&cx);
404377
let args = vec![
405378
s_ident_t,
406379
cx.get_const_i64(u64::MAX),
@@ -412,7 +385,7 @@ fn gen_call_handling<'ll>(
412385
nullptr,
413386
nullptr,
414387
];
415-
builder.call(fn_ty, begin, &args, None);
388+
builder.call(fn_ty, begin_mapper_decl, &args, None);
416389

417390
// Step 4)
418391
unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
@@ -434,14 +407,13 @@ fn gen_call_handling<'ll>(
434407
nullptr,
435408
nullptr,
436409
];
437-
builder.call(fn_ty, end, &args, None);
438-
builder.call(mapper_fn_ty, bar, &[tgt_bin_desc_alloca], None);
410+
builder.call(fn_ty, end_mapper_decl, &args, None);
411+
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
439412

413+
// With this we generated the following begin and end mappers. We could easily generate the
414+
// update mapper in an update.
440415
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
441416
// call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
442417
// call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
443-
// What is @1? Random but fixed:
444-
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
445-
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
446418
}
447419
}

tests/codegen/gpu_offload/gpu_host.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@ fn main() {
2020
core::hint::black_box(&x);
2121
}
2222

23-
// CHECK: %struct.ident_t = type { i32, i32, i32, i32, ptr }
24-
// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
2523
// CHECK: %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
24+
// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
25+
// CHECK: %struct.ident_t = type { i32, i32, i32, i32, ptr }
2626
// CHECK: %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
2727

28-
// CHECK: @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
29-
// CHECK: @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
30-
// CHECK: @my_struct_global2 = external global %struct.__tgt_kernel_arguments
3128
// CHECK: @.offload_sizes.1 = private unnamed_addr constant [1 x i64] [i64 1024]
3229
// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 3]
3330
// CHECK: @.kernel_1.region_id = weak unnamed_addr constant i8 0
3431
// CHECK: @.offloading.entry_name.1 = internal unnamed_addr constant [9 x i8] c"kernel_1\00", section ".llvm.rodata.offloading", align 1
3532
// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1
33+
// CHECK: @my_struct_global2 = external global %struct.__tgt_kernel_arguments
34+
// CHECK: @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
35+
// CHECK: @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
3636

3737
// CHECK: Function Attrs:
3838
// CHECK-NEXT: define{{( dso_local)?}} void @main()

0 commit comments

Comments
 (0)