Skip to content

Commit 1a7e216

Browse files
committed
Docs and minor fixes
1 parent 0cee58b commit 1a7e216

File tree

5 files changed

+94
-32
lines changed

5 files changed

+94
-32
lines changed

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ impl KernelArgsTy {
170170
}
171171
}
172172

173+
// Contains LLVM values needed to manage offloading for a single kernel.
//
// Returned by `gen_define_handling` and passed by reference to `gen_call_handling`, which
// destructures it by field name. Grouping the values in a struct replaces the previous
// four-element tuple / four positional parameters, so they can no longer be mixed up by position.
//
// Fields, as used by the code visible in this file:
// - `offload_sizes`: loaded (volatile dummy load) in `gen_call_handling` so that it is not
//   optimized away; presumably the per-kernel `@.offload_sizes.*` global -- TODO confirm.
// - `memtransfer_types`: forwarded to both mapper calls (begin/end) and to `KernelArgsTy::new`.
// - `region_id`: produced by `gen_define_handling`; its consumer is not visible here.
// - `offload_entry`: the value placed in the `llvm_offload_entries` section by
//   `gen_define_handling`; also kept alive with a dummy load in `gen_call_handling`.
174+
pub(crate) struct OffloadKernelData<'ll> {
175+
pub offload_sizes: &'ll llvm::Value,
176+
pub memtransfer_types: &'ll llvm::Value,
177+
pub region_id: &'ll llvm::Value,
178+
pub offload_entry: &'ll llvm::Value,
179+
}
180+
173181
fn gen_tgt_data_mappers<'ll>(
174182
cx: &'ll SimpleCx<'_>,
175183
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
@@ -238,7 +246,7 @@ pub(crate) fn gen_define_handling<'ll>(
238246
metadata: &[OffloadMetadata],
239247
types: &[&Type],
240248
symbol: &str,
241-
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value) {
249+
) -> OffloadKernelData<'ll> {
242250
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
243251
// reference) types.
244252
let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
@@ -290,7 +298,7 @@ pub(crate) fn gen_define_handling<'ll>(
290298
let c_section_name = CString::new("llvm_offload_entries").unwrap();
291299
llvm::set_section(offload_entry, &c_section_name);
292300

293-
(offload_sizes, memtransfer_types, region_id, offload_entry)
301+
OffloadKernelData { offload_sizes, memtransfer_types, region_id, offload_entry }
294302
}
295303

296304
fn declare_offload_fn<'ll>(
@@ -330,14 +338,13 @@ fn declare_offload_fn<'ll>(
330338
pub(crate) fn gen_call_handling<'ll>(
331339
cx: &SimpleCx<'ll>,
332340
bb: &BasicBlock,
333-
offload_sizes: &'ll llvm::Value,
334-
offload_entry: &'ll llvm::Value,
335-
memtransfer_types: &'ll llvm::Value,
336-
region_id: &'ll llvm::Value,
341+
offload_data: &OffloadKernelData<'ll>,
337342
args: &[&'ll Value],
338343
types: &[&Type],
339344
metadata: &[OffloadMetadata],
340345
) {
346+
let OffloadKernelData { offload_sizes, offload_entry, memtransfer_types, region_id } =
347+
offload_data;
341348
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
342349
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
343350
let tptr = cx.type_ptr();
@@ -351,7 +358,11 @@ pub(crate) fn gen_call_handling<'ll>(
351358

352359
let mut builder = SBuilder::build(cx, bb);
353360

354-
// prevent these globals from being optimized away
361+
let num_args = types.len() as u64;
362+
let ip = unsafe { llvm::LLVMRustGetInsertPoint(&builder.llbuilder) };
363+
364+
// FIXME(Sa4dUs): dummy loads are a temp workaround, we should find a proper way to prevent these
365+
// variables from being optimized away
355366
for val in [offload_sizes, offload_entry] {
356367
unsafe {
357368
let dummy = llvm::LLVMBuildLoad2(
@@ -364,11 +375,13 @@ pub(crate) fn gen_call_handling<'ll>(
364375
}
365376
}
366377

367-
let num_args = types.len() as u64;
368-
369378
// Step 0)
370379
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
371380
// %6 = alloca %struct.__tgt_bin_desc, align 8
381+
let llfn = unsafe { llvm::LLVMGetBasicBlockParent(bb) };
382+
unsafe {
383+
llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, llfn);
384+
}
372385
let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
373386

374387
let ty = cx.type_array(cx.type_ptr(), num_args);
@@ -384,6 +397,9 @@ pub(crate) fn gen_call_handling<'ll>(
384397
let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
385398

386399
// Step 1)
400+
unsafe {
401+
llvm::LLVMRustRestoreInsertPoint(&builder.llbuilder, ip);
402+
}
387403
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
388404

389405
// Now we allocate once per function param, a copy to be passed to one of our maps.
@@ -458,9 +474,17 @@ pub(crate) fn gen_call_handling<'ll>(
458474

459475
// Step 2)
460476
let s_ident_t = generate_at_one(&cx);
461-
let o = memtransfer_types;
462477
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
463-
generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
478+
generate_mapper_call(
479+
&mut builder,
480+
&cx,
481+
geps,
482+
memtransfer_types,
483+
begin_mapper_decl,
484+
fn_ty,
485+
num_args,
486+
s_ident_t,
487+
);
464488
let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);
465489

466490
// Step 3)
@@ -485,7 +509,16 @@ pub(crate) fn gen_call_handling<'ll>(
485509

486510
// Step 4)
487511
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
488-
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
512+
generate_mapper_call(
513+
&mut builder,
514+
&cx,
515+
geps,
516+
memtransfer_types,
517+
end_mapper_decl,
518+
fn_ty,
519+
num_args,
520+
s_ident_t,
521+
);
489522

490523
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
491524

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,10 @@ fn codegen_autodiff<'ll, 'tcx>(
12431243
);
12441244
}
12451245

1246+
// Generates the LLVM code to offload a Rust function to a target device (e.g., GPU).
1247+
// For each kernel call, it generates the necessary globals (including metadata such as
1248+
// size and pass mode), manages memory mapping to and from the device, handles all
1249+
// data transfers, and launches the kernel on the target device.
12461250
fn codegen_offload<'ll, 'tcx>(
12471251
bx: &mut Builder<'_, 'll, 'tcx>,
12481252
tcx: TyCtxt<'tcx>,
@@ -1282,27 +1286,16 @@ fn codegen_offload<'ll, 'tcx>(
12821286

12831287
let types = inputs.iter().map(|ty| cx.layout_of(*ty).llvm_type(cx)).collect::<Vec<_>>();
12841288

1285-
let (offload_sizes, memtransfer_types, region_id, offload_entry) =
1286-
crate::builder::gpu_offload::gen_define_handling(
1287-
cx,
1288-
offload_entry_ty,
1289-
&metadata,
1290-
&types,
1291-
&target_symbol,
1292-
);
1293-
1294-
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
1295-
crate::builder::gpu_offload::gen_call_handling(
1289+
let offload_data = crate::builder::gpu_offload::gen_define_handling(
12961290
cx,
1297-
bb,
1298-
offload_sizes,
1299-
offload_entry,
1300-
memtransfer_types,
1301-
region_id,
1302-
&args,
1303-
&types,
1291+
offload_entry_ty,
13041292
&metadata,
1293+
&types,
1294+
&target_symbol,
13051295
);
1296+
1297+
let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) };
1298+
crate::builder::gpu_offload::gen_call_handling(cx, bb, &offload_data, &args, &types, &metadata);
13061299
}
13071300

13081301
fn get_args_from_tuple<'ll, 'tcx>(

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,7 +2448,10 @@ unsafe extern "C" {
24482448

24492449
pub(crate) fn LLVMRustSetDataLayoutFromTargetMachine<'a>(M: &'a Module, TM: &'a TargetMachine);
24502450

2451+
pub(crate) fn LLVMRustPositionBuilderPastAllocas<'a>(B: &Builder<'a>, Fn: &'a Value);
24512452
pub(crate) fn LLVMRustPositionBuilderAtStart<'a>(B: &Builder<'a>, BB: &'a BasicBlock);
2453+
pub(crate) fn LLVMRustGetInsertPoint<'a>(B: &Builder<'a>) -> &'a Value;
2454+
pub(crate) fn LLVMRustRestoreInsertPoint<'a>(B: &Builder<'a>, IP: &'a Value);
24522455

24532456
pub(crate) fn LLVMRustSetModulePICLevel(M: &Module);
24542457
pub(crate) fn LLVMRustSetModulePIELevel(M: &Module);

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,39 @@ extern "C" void LLVMRustPositionAfter(LLVMBuilderRef B, LLVMValueRef Instr) {
13781378
}
13791379
}
13801380

1381+
// Reports the builder's current insertion position as the instruction it would insert before.
//
// Returns null when there is no such instruction: either the builder has no insertion block,
// or it is positioned at the end of its block. Only the instruction is returned; the block
// itself is not part of the result.
extern "C" LLVMValueRef LLVMRustGetInsertPoint(LLVMBuilderRef B) {
  llvm::IRBuilderBase::InsertPoint Saved = unwrap(B)->saveIP();
  llvm::BasicBlock *Block = Saved.getBlock();

  // Nothing to point at: no block at all, or the cursor sits past the last instruction.
  if (!Block || Saved.getPoint() == Block->end())
    return nullptr;

  llvm::Instruction *Next = &*Saved.getPoint();
  return wrap(Next);
}
1398+
1399+
// Moves the builder to a position previously obtained from LLVMRustGetInsertPoint.
//
// A non-null `Instr` positions the builder immediately before that instruction. A null
// `Instr` positions the builder at the end of the block it is currently inserting into, or
// leaves it untouched when it has no insertion block.
//
// NOTE(review): in the null case the target is the builder's *current* block, not the block
// that was active when the position was saved. If the builder was moved to a different block
// in between, this does not return to the original one -- confirm callers only rely on this
// within a single block.
extern "C" void LLVMRustRestoreInsertPoint(LLVMBuilderRef B, LLVMValueRef Instr) {
  llvm::IRBuilderBase &Builder = *unwrap(B);

  if (Instr) {
    Builder.SetInsertPoint(unwrap<llvm::Instruction>(Instr));
    return;
  }

  if (llvm::BasicBlock *Current = Builder.GetInsertBlock())
    Builder.SetInsertPoint(Current);
}
1412+
1413+
13811414
extern "C" LLVMValueRef
13821415
LLVMRustGetFunctionCall(LLVMValueRef Fn, const char *Name, size_t NameLen) {
13831416
auto targetName = StringRef(Name, NameLen);

tests/codegen-llvm/gpu_offload/gpu_host.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ fn main() {
5151

5252
// CHECK: define{{( dso_local)?}} void @kernel_1(ptr noalias noundef align 4 dereferenceable(1024) %x)
5353
// CHECK-NEXT: start:
54-
// CHECK-NEXT: %dummy = load volatile ptr, ptr @.offload_sizes._kernel_1, align 8
55-
// CHECK-NEXT: %dummy1 = load volatile ptr, ptr @.offloading.entry._kernel_1, align 8
5654
// CHECK-NEXT: %EmptyDesc = alloca %struct.__tgt_bin_desc, align 8
5755
// CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8
5856
// CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8
5957
// CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8
6058
// CHECK-NEXT: %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
59+
// CHECK-NEXT: %dummy = load volatile ptr, ptr @.offload_sizes._kernel_1, align 8
60+
// CHECK-NEXT: %dummy1 = load volatile ptr, ptr @.offloading.entry._kernel_1, align 8
6161
// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) %EmptyDesc, i8 0, i64 32, i1 false)
6262
// CHECK-NEXT: call void @__tgt_register_lib(ptr nonnull %EmptyDesc)
6363
// CHECK-NEXT: call void @__tgt_init_all_rtls()

0 commit comments

Comments
 (0)