@@ -170,6 +170,14 @@ impl KernelArgsTy {
170170 }
171171}
172172
173+ /// LLVM values needed to manage offloading for a single kernel; returned by `gen_define_handling` and consumed by `gen_call_handling`.
174+ pub ( crate ) struct OffloadKernelData < ' ll > {
175+ pub offload_sizes : & ' ll llvm:: Value , // read by a dummy load in `gen_call_handling` so it is not optimized away (temporary workaround)
176+ pub memtransfer_types : & ' ll llvm:: Value , // passed to the begin/end mapper calls and to `KernelArgsTy::new` in `gen_call_handling`
177+ pub region_id : & ' ll llvm:: Value , // NOTE(review): presumably identifies the kernel's target region; its use is not visible here, confirm
178+ pub offload_entry : & ' ll llvm:: Value , // placed in the `llvm_offload_entries` section by `gen_define_handling`; also kept alive by a dummy load
179+ }
180+
173181fn gen_tgt_data_mappers < ' ll > (
174182 cx : & ' ll SimpleCx < ' _ > ,
175183) -> ( & ' ll llvm:: Value , & ' ll llvm:: Value , & ' ll llvm:: Value , & ' ll llvm:: Type ) {
@@ -238,7 +246,7 @@ pub(crate) fn gen_define_handling<'ll>(
238246 metadata : & [ OffloadMetadata ] ,
239247 types : & [ & Type ] ,
240248 symbol : & str ,
241- ) -> ( & ' ll llvm :: Value , & ' ll llvm :: Value , & ' ll llvm :: Value , & ' ll llvm :: Value ) {
249+ ) -> OffloadKernelData < ' ll > {
242250 // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
243251 // reference) types.
244252 let ptr_meta = types. iter ( ) . zip ( metadata) . filter_map ( |( & x, meta) | match cx. type_kind ( x) {
@@ -290,7 +298,7 @@ pub(crate) fn gen_define_handling<'ll>(
290298 let c_section_name = CString :: new ( "llvm_offload_entries" ) . unwrap ( ) ;
291299 llvm:: set_section ( offload_entry, & c_section_name) ;
292300
293- ( offload_sizes, memtransfer_types, region_id, offload_entry)
301+ OffloadKernelData { offload_sizes, memtransfer_types, region_id, offload_entry }
294302}
295303
296304fn declare_offload_fn < ' ll > (
@@ -330,14 +338,13 @@ fn declare_offload_fn<'ll>(
330338pub ( crate ) fn gen_call_handling < ' ll > (
331339 cx : & SimpleCx < ' ll > ,
332340 bb : & BasicBlock ,
333- offload_sizes : & ' ll llvm:: Value ,
334- offload_entry : & ' ll llvm:: Value ,
335- memtransfer_types : & ' ll llvm:: Value ,
336- region_id : & ' ll llvm:: Value ,
341+ offload_data : & OffloadKernelData < ' ll > ,
337342 args : & [ & ' ll Value ] ,
338343 types : & [ & Type ] ,
339344 metadata : & [ OffloadMetadata ] ,
340345) {
346+ let OffloadKernelData { offload_sizes, offload_entry, memtransfer_types, region_id } =
347+ offload_data;
341348 let ( tgt_decl, tgt_target_kernel_ty) = generate_launcher ( & cx) ;
342349 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
343350 let tptr = cx. type_ptr ( ) ;
@@ -351,7 +358,11 @@ pub(crate) fn gen_call_handling<'ll>(
351358
352359 let mut builder = SBuilder :: build ( cx, bb) ;
353360
354- // prevent these globals from being optimized away
361+ let num_args = types. len ( ) as u64 ;
362+ let ip = unsafe { llvm:: LLVMRustGetInsertPoint ( & builder. llbuilder ) } ;
363+
364+ // FIXME(Sa4dUs): dummy loads are a temporary workaround; we should find a proper way to prevent
365+ // these variables from being optimized away.
355366 for val in [ offload_sizes, offload_entry] {
356367 unsafe {
357368 let dummy = llvm:: LLVMBuildLoad2 (
@@ -364,11 +375,13 @@ pub(crate) fn gen_call_handling<'ll>(
364375 }
365376 }
366377
367- let num_args = types. len ( ) as u64 ;
368-
369378 // Step 0)
370379 // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
371380 // %6 = alloca %struct.__tgt_bin_desc, align 8
381+ let llfn = unsafe { llvm:: LLVMGetBasicBlockParent ( bb) } ;
382+ unsafe {
383+ llvm:: LLVMRustPositionBuilderPastAllocas ( & builder. llbuilder , llfn) ;
384+ }
372385 let tgt_bin_desc_alloca = builder. direct_alloca ( tgt_bin_desc, Align :: EIGHT , "EmptyDesc" ) ;
373386
374387 let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
@@ -384,6 +397,9 @@ pub(crate) fn gen_call_handling<'ll>(
384397 let a5 = builder. direct_alloca ( tgt_kernel_decl, Align :: EIGHT , "kernel_args" ) ;
385398
386399 // Step 1)
400+ unsafe {
401+ llvm:: LLVMRustRestoreInsertPoint ( & builder. llbuilder , ip) ;
402+ }
387403 builder. memset ( tgt_bin_desc_alloca, cx. get_const_i8 ( 0 ) , cx. get_const_i64 ( 32 ) , Align :: EIGHT ) ;
388404
389405 // Now we allocate once per function param, a copy to be passed to one of our maps.
@@ -458,9 +474,17 @@ pub(crate) fn gen_call_handling<'ll>(
458474
459475 // Step 2)
460476 let s_ident_t = generate_at_one ( & cx) ;
461- let o = memtransfer_types;
462477 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
463- generate_mapper_call ( & mut builder, & cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t) ;
478+ generate_mapper_call (
479+ & mut builder,
480+ & cx,
481+ geps,
482+ memtransfer_types,
483+ begin_mapper_decl,
484+ fn_ty,
485+ num_args,
486+ s_ident_t,
487+ ) ;
464488 let values = KernelArgsTy :: new ( & cx, num_args, memtransfer_types, geps) ;
465489
466490 // Step 3)
@@ -485,7 +509,16 @@ pub(crate) fn gen_call_handling<'ll>(
485509
486510 // Step 4)
487511 let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
488- generate_mapper_call ( & mut builder, & cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t) ;
512+ generate_mapper_call (
513+ & mut builder,
514+ & cx,
515+ geps,
516+ memtransfer_types,
517+ end_mapper_decl,
518+ fn_ty,
519+ num_args,
520+ s_ident_t,
521+ ) ;
489522
490523 builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
491524
0 commit comments