Skip to content

Commit e79f1fa

Browse files
committed
gpu offload memory-transfer mvp
1 parent fa72869 commit e79f1fa

File tree

13 files changed

+837
-5
lines changed

13 files changed

+837
-5
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ pub(crate) fn run_pass_manager(
653653
// We then run the llvm_optimize function a second time, to optimize the code which we generated
654654
// in the enzyme differentiation pass.
655655
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
656+
let enable_gpu = config.offload.contains(&config::Offload::Enable);
656657
let stage = if thin {
657658
write::AutodiffStage::PreAD
658659
} else {
@@ -667,6 +668,13 @@ pub(crate) fn run_pass_manager(
667668
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
668669
}
669670

671+
if cfg!(llvm_enzyme) && enable_gpu && !thin {
672+
dbg!(&enable_gpu);
673+
let cx =
674+
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
675+
crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);
676+
}
677+
670678
if cfg!(llvm_enzyme) && enable_ad && !thin {
671679
let cx =
672680
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::ops::Deref;
33
use std::{iter, ptr};
44

55
pub(crate) mod autodiff;
6+
pub(crate) mod gpu_device;
7+
pub(crate) mod gpu_offload;
68

79
use libc::{c_char, c_uint, size_t};
810
use rustc_abi as abi;
@@ -117,6 +119,70 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
117119
}
118120
bx
119121
}
122+
123+
pub(crate) fn my_alloca2(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value {
124+
let val = unsafe {
125+
let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED);
126+
llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint);
127+
// Cast to default addrspace if necessary
128+
llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED)
129+
};
130+
if name != "" {
131+
let name = std::ffi::CString::new(name).unwrap();
132+
llvm::set_value_name(val, &name.as_bytes());
133+
}
134+
val
135+
}
136+
137+
pub(crate) fn inbounds_gep(
138+
&mut self,
139+
ty: &'ll Type,
140+
ptr: &'ll Value,
141+
indices: &[&'ll Value],
142+
) -> &'ll Value {
143+
unsafe {
144+
llvm::LLVMBuildGEPWithNoWrapFlags(
145+
self.llbuilder,
146+
ty,
147+
ptr,
148+
indices.as_ptr(),
149+
indices.len() as c_uint,
150+
UNNAMED,
151+
GEPNoWrapFlags::InBounds,
152+
)
153+
}
154+
}
155+
156+
pub(crate) fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value {
157+
debug!("Store {:?} -> {:?}", val, ptr);
158+
assert_eq!(self.cx.type_kind(self.cx.val_ty(ptr)), TypeKind::Pointer);
159+
unsafe {
160+
let store = llvm::LLVMBuildStore(self.llbuilder, val, ptr);
161+
llvm::LLVMSetAlignment(store, align.bytes() as c_uint);
162+
store
163+
}
164+
}
165+
166+
pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value {
167+
unsafe {
168+
let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED);
169+
llvm::LLVMSetAlignment(load, align.bytes() as c_uint);
170+
load
171+
}
172+
}
173+
174+
fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) {
175+
unsafe {
176+
llvm::LLVMRustBuildMemSet(
177+
self.llbuilder,
178+
ptr,
179+
align.bytes() as c_uint,
180+
fill_byte,
181+
size,
182+
false,
183+
);
184+
}
185+
}
120186
}
121187

122188
/// Empty string, to be used where LLVM expects an instruction name, indicating

0 commit comments

Comments
 (0)