Skip to content

Commit e6f3994

Browse files
bjorn3eggyal
andauthored
Atomic hotswapping in JIT mode (#2786)
* Introduce new_got_entry and new_plt_entry functions * Return NonNull<*const u8> from get_got_address * Make GOT entry writes atomic * Defer GOT updates until relocations and protection Co-authored-by: Alan Egerton <eggyal@gmail.com>
1 parent 884a650 commit e6f3994

File tree

1 file changed

+97
-104
lines changed

1 file changed

+97
-104
lines changed

cranelift/jit/src/backend.rs

Lines changed: 97 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use std::ffi::CString;
2323
use std::io::Write;
2424
use std::ptr;
2525
use std::ptr::NonNull;
26+
use std::sync::atomic::{AtomicPtr, Ordering};
2627
use target_lexicon::PointerWidth;
2728
#[cfg(windows)]
2829
use winapi;
@@ -129,6 +130,15 @@ impl JITBuilder {
129130
}
130131
}
131132

133+
/// A pending update to the GOT.
134+
struct GotUpdate {
135+
/// The entry that is to be updated.
136+
entry: NonNull<AtomicPtr<u8>>,
137+
138+
/// The new value of the entry.
139+
ptr: *const u8,
140+
}
141+
132142
/// A `JITModule` implements `Module` and emits code and data into memory where it can be
133143
/// directly called and accessed.
134144
///
@@ -140,15 +150,18 @@ pub struct JITModule {
140150
libcall_names: Box<dyn Fn(ir::LibCall) -> String>,
141151
memory: MemoryHandle,
142152
declarations: ModuleDeclarations,
143-
function_got_entries: SecondaryMap<FuncId, Option<NonNull<*const u8>>>,
153+
function_got_entries: SecondaryMap<FuncId, Option<NonNull<AtomicPtr<u8>>>>,
144154
function_plt_entries: SecondaryMap<FuncId, Option<NonNull<[u8; 16]>>>,
145-
data_object_got_entries: SecondaryMap<DataId, Option<NonNull<*const u8>>>,
146-
libcall_got_entries: HashMap<ir::LibCall, NonNull<*const u8>>,
155+
data_object_got_entries: SecondaryMap<DataId, Option<NonNull<AtomicPtr<u8>>>>,
156+
libcall_got_entries: HashMap<ir::LibCall, NonNull<AtomicPtr<u8>>>,
147157
libcall_plt_entries: HashMap<ir::LibCall, NonNull<[u8; 16]>>,
148158
compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
149159
compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
150160
functions_to_finalize: Vec<FuncId>,
151161
data_objects_to_finalize: Vec<DataId>,
162+
163+
/// Updates to the GOT awaiting relocations to be made and region protections to be set
164+
pending_got_updates: Vec<GotUpdate>,
152165
}
153166

154167
/// A handle to allow freeing memory allocated by the `Module`.
@@ -180,54 +193,53 @@ impl JITModule {
180193
.or_else(|| lookup_with_dlsym(name))
181194
}
182195

183-
fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) {
196+
fn new_got_entry(&mut self, val: *const u8) -> NonNull<AtomicPtr<u8>> {
184197
let got_entry = self
185198
.memory
186199
.writable
187200
.allocate(
188-
std::mem::size_of::<*const u8>(),
189-
std::mem::align_of::<*const u8>().try_into().unwrap(),
201+
std::mem::size_of::<AtomicPtr<u8>>(),
202+
std::mem::align_of::<AtomicPtr<u8>>().try_into().unwrap(),
190203
)
191204
.unwrap()
192-
.cast::<*const u8>();
193-
self.function_got_entries[id] = Some(NonNull::new(got_entry).unwrap());
205+
.cast::<AtomicPtr<u8>>();
194206
unsafe {
195-
std::ptr::write(got_entry, val);
207+
std::ptr::write(got_entry, AtomicPtr::new(val as *mut _));
196208
}
209+
NonNull::new(got_entry).unwrap()
210+
}
211+
212+
fn new_plt_entry(&mut self, got_entry: NonNull<AtomicPtr<u8>>) -> NonNull<[u8; 16]> {
197213
let plt_entry = self
198214
.memory
199215
.code
200216
.allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT)
201217
.unwrap()
202218
.cast::<[u8; 16]>();
219+
unsafe {
220+
Self::write_plt_entry_bytes(plt_entry, got_entry);
221+
}
222+
NonNull::new(plt_entry).unwrap()
223+
}
224+
225+
fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) {
226+
let got_entry = self.new_got_entry(val);
227+
self.function_got_entries[id] = Some(got_entry);
228+
let plt_entry = self.new_plt_entry(got_entry);
203229
self.record_function_for_perf(
204-
plt_entry as *mut _,
230+
plt_entry.as_ptr().cast(),
205231
std::mem::size_of::<[u8; 16]>(),
206232
&format!("{}@plt", self.declarations.get_function_decl(id).name),
207233
);
208-
self.function_plt_entries[id] = Some(NonNull::new(plt_entry).unwrap());
209-
unsafe {
210-
Self::write_plt_entry_bytes(plt_entry, got_entry);
211-
}
234+
self.function_plt_entries[id] = Some(plt_entry);
212235
}
213236

214237
fn new_data_got_entry(&mut self, id: DataId, val: *const u8) {
215-
let got_entry = self
216-
.memory
217-
.writable
218-
.allocate(
219-
std::mem::size_of::<*const u8>(),
220-
std::mem::align_of::<*const u8>().try_into().unwrap(),
221-
)
222-
.unwrap()
223-
.cast::<*const u8>();
224-
self.data_object_got_entries[id] = Some(NonNull::new(got_entry).unwrap());
225-
unsafe {
226-
std::ptr::write(got_entry, val);
227-
}
238+
let got_entry = self.new_got_entry(val);
239+
self.data_object_got_entries[id] = Some(got_entry);
228240
}
229241

230-
unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: *mut *const u8) {
242+
unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: NonNull<AtomicPtr<u8>>) {
231243
assert!(
232244
cfg!(target_arch = "x86_64"),
233245
"PLT is currently only supported on x86_64"
@@ -236,7 +248,7 @@ impl JITModule {
236248
let mut plt_val = [
237249
0xff, 0x25, 0, 0, 0, 0, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b, 0x0f, 0x0b,
238250
];
239-
let what = got_ptr as isize - 4;
251+
let what = got_ptr.as_ptr() as isize - 4;
240252
let at = plt_ptr as isize + 2;
241253
plt_val[2..6].copy_from_slice(&i32::to_ne_bytes(i32::try_from(what - at).unwrap()));
242254
std::ptr::write(plt_ptr, plt_val);
@@ -289,32 +301,25 @@ impl JITModule {
289301
///
290302
/// Panics if there's no entry in the table for the given function.
291303
pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 {
292-
unsafe { *self.function_got_entries[func_id].unwrap().as_ptr() }
304+
let got_entry = self.function_got_entries[func_id].unwrap();
305+
unsafe { got_entry.as_ref() }.load(Ordering::SeqCst)
293306
}
294307

295-
fn get_got_address(&self, name: &ir::ExternalName) -> *const u8 {
308+
fn get_got_address(&self, name: &ir::ExternalName) -> NonNull<AtomicPtr<u8>> {
296309
match *name {
297310
ir::ExternalName::User { .. } => {
298311
if ModuleDeclarations::is_function(name) {
299312
let func_id = FuncId::from_name(name);
300-
self.function_got_entries[func_id]
301-
.unwrap()
302-
.as_ptr()
303-
.cast::<u8>()
313+
self.function_got_entries[func_id].unwrap()
304314
} else {
305315
let data_id = DataId::from_name(name);
306-
self.data_object_got_entries[data_id]
307-
.unwrap()
308-
.as_ptr()
309-
.cast::<u8>()
316+
self.data_object_got_entries[data_id].unwrap()
310317
}
311318
}
312-
ir::ExternalName::LibCall(ref libcall) => self
319+
ir::ExternalName::LibCall(ref libcall) => *self
313320
.libcall_got_entries
314321
.get(libcall)
315-
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
316-
.as_ptr()
317-
.cast::<u8>(),
322+
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)),
318323
_ => panic!("invalid ExternalName {}", name),
319324
}
320325
}
@@ -406,7 +411,7 @@ impl JITModule {
406411
.expect("function must be compiled before it can be finalized");
407412
func.perform_relocations(
408413
|name| self.get_address(name),
409-
|name| self.get_got_address(name),
414+
|name| self.get_got_address(name).as_ptr().cast(),
410415
|name| self.get_plt_address(name),
411416
);
412417
}
@@ -419,14 +424,18 @@ impl JITModule {
419424
.expect("data object must be compiled before it can be finalized");
420425
data.perform_relocations(
421426
|name| self.get_address(name),
422-
|name| self.get_got_address(name),
427+
|name| self.get_got_address(name).as_ptr().cast(),
423428
|name| self.get_plt_address(name),
424429
);
425430
}
426431

427432
// Now that we're done patching, prepare the memory for execution!
428433
self.memory.readonly.set_readonly();
429434
self.memory.code.set_readable_and_executable();
435+
436+
for update in self.pending_got_updates.drain(..) {
437+
unsafe { update.entry.as_ref() }.store(update.ptr as *mut _, Ordering::SeqCst);
438+
}
430439
}
431440

432441
/// Create a new `JITModule`.
@@ -438,33 +447,38 @@ impl JITModule {
438447
);
439448
}
440449

441-
let mut memory = MemoryHandle {
442-
code: Memory::new(),
443-
readonly: Memory::new(),
444-
writable: Memory::new(),
450+
let mut module = Self {
451+
isa: builder.isa,
452+
hotswap_enabled: builder.hotswap_enabled,
453+
symbols: builder.symbols,
454+
libcall_names: builder.libcall_names,
455+
memory: MemoryHandle {
456+
code: Memory::new(),
457+
readonly: Memory::new(),
458+
writable: Memory::new(),
459+
},
460+
declarations: ModuleDeclarations::default(),
461+
function_got_entries: SecondaryMap::new(),
462+
function_plt_entries: SecondaryMap::new(),
463+
data_object_got_entries: SecondaryMap::new(),
464+
libcall_got_entries: HashMap::new(),
465+
libcall_plt_entries: HashMap::new(),
466+
compiled_functions: SecondaryMap::new(),
467+
compiled_data_objects: SecondaryMap::new(),
468+
functions_to_finalize: Vec::new(),
469+
data_objects_to_finalize: Vec::new(),
470+
pending_got_updates: Vec::new(),
445471
};
446472

447-
let mut libcall_got_entries = HashMap::new();
448-
let mut libcall_plt_entries = HashMap::new();
449-
450473
// Pre-create a GOT and PLT entry for each libcall.
451-
let all_libcalls = if builder.isa.flags().is_pic() {
474+
let all_libcalls = if module.isa.flags().is_pic() {
452475
ir::LibCall::all_libcalls()
453476
} else {
454477
&[] // Not PIC, so no GOT and PLT entries necessary
455478
};
456479
for &libcall in all_libcalls {
457-
let got_entry = memory
458-
.writable
459-
.allocate(
460-
std::mem::size_of::<*const u8>(),
461-
std::mem::align_of::<*const u8>().try_into().unwrap(),
462-
)
463-
.unwrap()
464-
.cast::<*const u8>();
465-
libcall_got_entries.insert(libcall, NonNull::new(got_entry).unwrap());
466-
let sym = (builder.libcall_names)(libcall);
467-
let addr = if let Some(addr) = builder
480+
let sym = (module.libcall_names)(libcall);
481+
let addr = if let Some(addr) = module
468482
.symbols
469483
.get(&sym)
470484
.copied()
@@ -474,37 +488,13 @@ impl JITModule {
474488
} else {
475489
continue;
476490
};
477-
unsafe {
478-
std::ptr::write(got_entry, addr);
479-
}
480-
let plt_entry = memory
481-
.code
482-
.allocate(std::mem::size_of::<[u8; 16]>(), EXECUTABLE_DATA_ALIGNMENT)
483-
.unwrap()
484-
.cast::<[u8; 16]>();
485-
libcall_plt_entries.insert(libcall, NonNull::new(plt_entry).unwrap());
486-
unsafe {
487-
Self::write_plt_entry_bytes(plt_entry, got_entry);
488-
}
491+
let got_entry = module.new_got_entry(addr);
492+
module.libcall_got_entries.insert(libcall, got_entry);
493+
let plt_entry = module.new_plt_entry(got_entry);
494+
module.libcall_plt_entries.insert(libcall, plt_entry);
489495
}
490496

491-
Self {
492-
isa: builder.isa,
493-
hotswap_enabled: builder.hotswap_enabled,
494-
symbols: builder.symbols,
495-
libcall_names: builder.libcall_names,
496-
memory,
497-
declarations: ModuleDeclarations::default(),
498-
function_got_entries: SecondaryMap::new(),
499-
function_plt_entries: SecondaryMap::new(),
500-
data_object_got_entries: SecondaryMap::new(),
501-
libcall_got_entries,
502-
libcall_plt_entries,
503-
compiled_functions: SecondaryMap::new(),
504-
compiled_data_objects: SecondaryMap::new(),
505-
functions_to_finalize: Vec::new(),
506-
data_objects_to_finalize: Vec::new(),
507-
}
497+
module
508498
}
509499

510500
/// Allow a single future `define_function` on a previously defined function. This allows for
@@ -682,9 +672,10 @@ impl Module for JITModule {
682672
});
683673

684674
if self.isa.flags().is_pic() {
685-
unsafe {
686-
std::ptr::write(self.function_got_entries[id].unwrap().as_ptr(), ptr);
687-
}
675+
self.pending_got_updates.push(GotUpdate {
676+
entry: self.function_got_entries[id].unwrap(),
677+
ptr,
678+
})
688679
}
689680

690681
if self.hotswap_enabled {
@@ -704,7 +695,7 @@ impl Module for JITModule {
704695
.cast::<u8>(),
705696
_ => panic!("invalid ExternalName {}", name),
706697
},
707-
|name| self.get_got_address(name),
698+
|name| self.get_got_address(name).as_ptr().cast(),
708699
|name| self.get_plt_address(name),
709700
);
710701
} else {
@@ -754,9 +745,10 @@ impl Module for JITModule {
754745
});
755746

756747
if self.isa.flags().is_pic() {
757-
unsafe {
758-
std::ptr::write(self.function_got_entries[id].unwrap().as_ptr(), ptr);
759-
}
748+
self.pending_got_updates.push(GotUpdate {
749+
entry: self.function_got_entries[id].unwrap(),
750+
ptr,
751+
})
760752
}
761753

762754
if self.hotswap_enabled {
@@ -765,7 +757,7 @@ impl Module for JITModule {
765757
.unwrap()
766758
.perform_relocations(
767759
|name| unreachable!("non GOT or PLT relocation in function {} to {}", id, name),
768-
|name| self.get_got_address(name),
760+
|name| self.get_got_address(name).as_ptr().cast(),
769761
|name| self.get_plt_address(name),
770762
);
771763
} else {
@@ -836,9 +828,10 @@ impl Module for JITModule {
836828
self.compiled_data_objects[id] = Some(CompiledBlob { ptr, size, relocs });
837829
self.data_objects_to_finalize.push(id);
838830
if self.isa.flags().is_pic() {
839-
unsafe {
840-
std::ptr::write(self.data_object_got_entries[id].unwrap().as_ptr(), ptr);
841-
}
831+
self.pending_got_updates.push(GotUpdate {
832+
entry: self.data_object_got_entries[id].unwrap(),
833+
ptr,
834+
})
842835
}
843836

844837
Ok(())

0 commit comments

Comments
 (0)