Skip to content

Commit

Permalink
wrap JitState in Mutex and fix an incorrect use
Browse files Browse the repository at this point in the history
  • Loading branch information
copy committed Dec 20, 2024
1 parent 92126bf commit 20dfa31
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 46 deletions.
36 changes: 22 additions & 14 deletions src/rust/cpu/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2073,7 +2073,15 @@ pub unsafe fn do_page_walk(
}

let is_in_mapped_range = in_mapped_range(high);
let has_code = !is_in_mapped_range && jit::jit_page_has_code(Page::page_of(high));
let has_code = if side_effects {
!is_in_mapped_range && jit::jit_page_has_code(Page::page_of(high))
}
else {
// If side_effects is false, don't call into jit::jit_page_has_code. This value is not used
// anyway (we only get here by translate_address_read_no_side_effects, which only uses the
// address part)
true
};
let info_bits = TLB_VALID
| if for_writing { 0 } else { TLB_READONLY }
| if allow_user { 0 } else { TLB_NO_USER }
Expand Down Expand Up @@ -3597,7 +3605,7 @@ pub unsafe fn safe_write_slow_jit(
}
else {
if !can_skip_dirty_page {
jit::jit_dirty_page(jit::get_jit_state(), Page::page_of(addr_low));
jit::jit_dirty_page(Page::page_of(addr_low));
}
((addr_low as i32 + memory::mem8 as i32) ^ addr) & !0xFFF
}
Expand Down Expand Up @@ -3636,7 +3644,7 @@ pub unsafe fn safe_write8(addr: i32, value: i32) -> OrPageFault<()> {
}
else {
if !can_skip_dirty_page {
jit::jit_dirty_page(jit::get_jit_state(), Page::page_of(phys_addr));
jit::jit_dirty_page(Page::page_of(phys_addr));
}
else {
dbg_assert!(!jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
Expand All @@ -3656,7 +3664,7 @@ pub unsafe fn safe_write16(addr: i32, value: i32) -> OrPageFault<()> {
}
else {
if !can_skip_dirty_page {
jit::jit_dirty_page(jit::get_jit_state(), Page::page_of(phys_addr));
jit::jit_dirty_page(Page::page_of(phys_addr));
}
else {
dbg_assert!(!jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
Expand All @@ -3680,7 +3688,7 @@ pub unsafe fn safe_write32(addr: i32, value: i32) -> OrPageFault<()> {
}
else {
if !can_skip_dirty_page {
jit::jit_dirty_page(jit::get_jit_state(), Page::page_of(phys_addr));
jit::jit_dirty_page(Page::page_of(phys_addr));
}
else {
dbg_assert!(!jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
Expand All @@ -3703,7 +3711,7 @@ pub unsafe fn safe_write64(addr: i32, value: u64) -> OrPageFault<()> {
}
else {
if !can_skip_dirty_page {
jit::jit_dirty_page(jit::get_jit_state(), Page::page_of(phys_addr));
jit::jit_dirty_page(Page::page_of(phys_addr));
}
else {
dbg_assert!(!jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
Expand All @@ -3727,7 +3735,7 @@ pub unsafe fn safe_write128(addr: i32, value: reg128) -> OrPageFault<()> {
}
else {
if !can_skip_dirty_page {
jit::jit_dirty_page(jit::get_jit_state(), Page::page_of(phys_addr));
jit::jit_dirty_page(Page::page_of(phys_addr));
}
else {
dbg_assert!(!jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
Expand All @@ -3749,10 +3757,10 @@ pub unsafe fn safe_read_write8(addr: i32, instruction: &dyn Fn(i32) -> i32) {
}
else {
if !can_skip_dirty_page {
::jit::jit_dirty_page(::jit::get_jit_state(), Page::page_of(phys_addr));
jit::jit_dirty_page(Page::page_of(phys_addr));
}
else {
dbg_assert!(!::jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
dbg_assert!(!jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
}
memory::write8_no_mmap_or_dirty_check(phys_addr, value);
}
Expand All @@ -3775,10 +3783,10 @@ pub unsafe fn safe_read_write16(addr: i32, instruction: &dyn Fn(i32) -> i32) {
}
else {
if !can_skip_dirty_page {
::jit::jit_dirty_page(::jit::get_jit_state(), Page::page_of(phys_addr));
jit::jit_dirty_page(Page::page_of(phys_addr));
}
else {
dbg_assert!(!::jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
dbg_assert!(!jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
}
memory::write16_no_mmap_or_dirty_check(phys_addr, value);
};
Expand All @@ -3803,10 +3811,10 @@ pub unsafe fn safe_read_write32(addr: i32, instruction: &dyn Fn(i32) -> i32) {
}
else {
if !can_skip_dirty_page {
::jit::jit_dirty_page(::jit::get_jit_state(), Page::page_of(phys_addr));
jit::jit_dirty_page(Page::page_of(phys_addr));
}
else {
dbg_assert!(!::jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
dbg_assert!(!jit::jit_page_has_code(Page::page_of(phys_addr as u32)));
}
memory::write32_no_mmap_or_dirty_check(phys_addr, value);
};
Expand Down Expand Up @@ -4381,7 +4389,7 @@ pub unsafe fn reset_cpu() {

update_state_flags();

jit::jit_clear_cache(jit::get_jit_state());
jit::jit_clear_cache_js();
}

#[no_mangle]
Expand Down
7 changes: 4 additions & 3 deletions src/rust/cpu/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod ext {
use cpu::cpu::reg128;
use cpu::global_pointers::memory_size;
use cpu::vga;
use jit;
use page::Page;

use std::alloc;
Expand Down Expand Up @@ -172,7 +173,7 @@ pub unsafe fn write8(addr: u32, value: i32) {
mmap_write8(addr, value & 0xFF);
}
else {
::jit::jit_dirty_page(::jit::get_jit_state(), Page::page_of(addr));
jit::jit_dirty_page(Page::page_of(addr));
write8_no_mmap_or_dirty_check(addr, value);
};
}
Expand All @@ -187,7 +188,7 @@ pub unsafe fn write16(addr: u32, value: i32) {
mmap_write16(addr, value & 0xFFFF);
}
else {
::jit::jit_dirty_cache_small(addr, addr + 2);
jit::jit_dirty_cache_small(addr, addr + 2);
write16_no_mmap_or_dirty_check(addr, value);
};
}
Expand All @@ -201,7 +202,7 @@ pub unsafe fn write32(addr: u32, value: i32) {
mmap_write32(addr, value);
}
else {
::jit::jit_dirty_cache_small(addr, addr + 4);
jit::jit_dirty_cache_small(addr, addr + 4);
write32_no_mmap_or_dirty_check(addr, value);
};
}
Expand Down
3 changes: 2 additions & 1 deletion src/rust/cpu/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use cpu::memory::{
memset_no_mmap_or_dirty_check, read16_no_mmap_check, read32_no_mmap_check, read8_no_mmap_check,
write16_no_mmap_or_dirty_check, write32_no_mmap_or_dirty_check, write8_no_mmap_or_dirty_check,
};
use jit;
use page::Page;

fn count_until_end_of_page(direction: i32, size: i32, addr: u32) -> u32 {
Expand Down Expand Up @@ -248,7 +249,7 @@ unsafe fn string_instruction(
dbg_assert!(count_until_end_of_page > 0);

if !skip_dirty_page {
::jit::jit_dirty_page(::jit::get_jit_state(), Page::page_of(phys_dst));
jit::jit_dirty_page(Page::page_of(phys_dst));
}

let mut rep_cmp_finished = false;
Expand Down
71 changes: 43 additions & 28 deletions src/rust/jit.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::iter::FromIterator;
use std::mem;
use std::mem::{self, MaybeUninit};
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
use std::sync::{Mutex, MutexGuard};

use analysis::AnalysisType;
use codegen;
Expand Down Expand Up @@ -82,18 +84,27 @@ pub const CHECK_JIT_STATE_INVARIANTS: bool = false;

const MAX_INSTRUCTION_LENGTH: u32 = 16;

#[allow(non_upper_case_globals)]
static mut jit_state: NonNull<JitState> =
unsafe { NonNull::new_unchecked(mem::align_of::<JitState>() as *mut _) };
static JIT_STATE: Mutex<MaybeUninit<JitState>> = Mutex::new(MaybeUninit::uninit());
fn get_jit_state() -> JitStateRef { JitStateRef(JIT_STATE.try_lock().unwrap()) }

pub fn get_jit_state() -> &'static mut JitState { unsafe { jit_state.as_mut() } }
struct JitStateRef(MutexGuard<'static, MaybeUninit<JitState>>);

impl Deref for JitStateRef {
type Target = JitState;
fn deref(&self) -> &Self::Target { unsafe { self.0.assume_init_ref() } }
}
impl DerefMut for JitStateRef {
fn deref_mut(&mut self) -> &mut Self::Target { unsafe { self.0.assume_init_mut() } }
}

#[no_mangle]
pub fn rust_init() {
dbg_assert!(std::mem::size_of::<[Option<NonNull<cpu::Code>>; 0x100000]>() == 0x100000 * 4);

let x = Box::new(JitState::create_and_initialise());
unsafe { jit_state = NonNull::new(Box::into_raw(x)).unwrap() }
let _ = JIT_STATE
.try_lock()
.unwrap()
.write(JitState::create_and_initialise());

use std::panic;

Expand All @@ -114,7 +125,7 @@ enum CompilingPageState {
CompilingWritten,
}

pub struct JitState {
struct JitState {
wasm_builder: WasmBuilder,

// as an alternative to HashSet, we could use a bitmap of 4096 bits here
Expand All @@ -127,7 +138,7 @@ pub struct JitState {
compiling: Option<(WasmTableIndex, CompilingPageState)>,
}

pub fn check_jit_state_invariants(ctx: &mut JitState) {
fn check_jit_state_invariants(ctx: &mut JitState) {
if !CHECK_JIT_STATE_INVARIANTS {
return;
}
Expand Down Expand Up @@ -1011,7 +1022,7 @@ pub fn codegen_finalize_finished(
phys_addr: u32,
state_flags: CachedStateFlags,
) {
let ctx = get_jit_state();
let mut ctx = get_jit_state();

dbg_assert!(wasm_table_index != WasmTableIndex(0));

Expand All @@ -1029,8 +1040,8 @@ pub fn codegen_finalize_finished(
dbg_assert!(wasm_table_index == in_progress_wasm_table_index);

profiler::stat_increment(stat::INVALIDATE_MODULE_WRITTEN_WHILE_COMPILED);
free_wasm_table_index(ctx, wasm_table_index);
check_jit_state_invariants(ctx);
free_wasm_table_index(&mut ctx, wasm_table_index);
check_jit_state_invariants(&mut ctx);
return;
},
Some((in_progress_wasm_table_index, CompilingPageState::Compiling { pages })) => {
Expand Down Expand Up @@ -1083,10 +1094,10 @@ pub fn codegen_finalize_finished(

dbg_log!("unused after overwrite {}", index.to_u16());
profiler::stat_increment(stat::INVALIDATE_MODULE_UNUSED_AFTER_OVERWRITE);
free_wasm_table_index(ctx, index);
free_wasm_table_index(&mut ctx, index);
}

check_jit_state_invariants(ctx);
check_jit_state_invariants(&mut ctx);
}

pub fn update_tlb_code(virt_page: Page, phys_page: Page) {
Expand Down Expand Up @@ -2090,7 +2101,8 @@ pub fn jit_increase_hotness_and_maybe_compile(
return;
}

let ctx = get_jit_state();
let mut ctx = get_jit_state();
let is_compiling = ctx.compiling.is_some();
let page = Page::page_of(phys_address);
let (hotness, entry_points) = ctx.entry_points.entry(page).or_insert_with(|| {
cpu::tlb_set_has_code(page, true);
Expand All @@ -2104,18 +2116,18 @@ pub fn jit_increase_hotness_and_maybe_compile(

*hotness += heat;
if *hotness >= JIT_THRESHOLD {
if ctx.compiling.is_some() {
if is_compiling {
return;
}
// only try generating if we're in the correct address space
if cpu::translate_address_read_no_side_effects(virt_address) == Ok(phys_address) {
*hotness = 0;
jit_analyze_and_generate(ctx, virt_address, phys_address, cs_offset, state_flags)
jit_analyze_and_generate(&mut ctx, virt_address, phys_address, cs_offset, state_flags)
}
else {
profiler::stat_increment(stat::COMPILE_WRONG_ADDRESS_SPACE);
}
};
}
}

fn free_wasm_table_index(ctx: &mut JitState, wasm_table_index: WasmTableIndex) {
Expand Down Expand Up @@ -2164,7 +2176,7 @@ fn free_wasm_table_index(ctx: &mut JitState, wasm_table_index: WasmTableIndex) {
}

/// Register a write in this page: Delete all present code
pub fn jit_dirty_page(ctx: &mut JitState, page: Page) {
fn jit_dirty_page_ctx(ctx: &mut JitState, page: Page) {
let mut did_have_code = false;

if let Some(PageInfo {
Expand Down Expand Up @@ -2269,32 +2281,35 @@ pub fn jit_dirty_cache(start_addr: u32, end_addr: u32) {
let end_page = Page::page_of(end_addr - 1);

for page in start_page.to_u32()..end_page.to_u32() + 1 {
jit_dirty_page(get_jit_state(), Page::page_of(page << 12));
jit_dirty_page_ctx(&mut get_jit_state(), Page::page_of(page << 12));
}
}

#[no_mangle]
pub fn jit_dirty_page(page: Page) { jit_dirty_page_ctx(&mut get_jit_state(), page) }

/// dirty pages in the range of start_addr and end_addr, which must span at most two pages
pub fn jit_dirty_cache_small(start_addr: u32, end_addr: u32) {
dbg_assert!(start_addr < end_addr);

let start_page = Page::page_of(start_addr);
let end_page = Page::page_of(end_addr - 1);

let ctx = get_jit_state();
jit_dirty_page(ctx, start_page);
let mut ctx = get_jit_state();
jit_dirty_page_ctx(&mut ctx, start_page);

// Note: This can't happen when paging is enabled, as writes across
// boundaries are split up on two pages
if start_page != end_page {
dbg_assert!(start_page.to_u32() + 1 == end_page.to_u32());
jit_dirty_page(ctx, end_page);
jit_dirty_page_ctx(&mut ctx, end_page);
}
}

#[no_mangle]
pub fn jit_clear_cache_js() { jit_clear_cache(get_jit_state()) }
pub fn jit_clear_cache_js() { jit_clear_cache(&mut get_jit_state()) }

pub fn jit_clear_cache(ctx: &mut JitState) {
fn jit_clear_cache(ctx: &mut JitState) {
let mut pages_with_code = HashSet::new();

for &p in ctx.entry_points.keys() {
Expand All @@ -2305,13 +2320,13 @@ pub fn jit_clear_cache(ctx: &mut JitState) {
}

for page in pages_with_code {
jit_dirty_page(ctx, page);
jit_dirty_page_ctx(ctx, page);
}
}

pub fn jit_page_has_code(page: Page) -> bool { jit_page_has_code_ctx(get_jit_state(), page) }
pub fn jit_page_has_code(page: Page) -> bool { jit_page_has_code_ctx(&mut get_jit_state(), page) }

pub fn jit_page_has_code_ctx(ctx: &mut JitState, page: Page) -> bool {
fn jit_page_has_code_ctx(ctx: &mut JitState, page: Page) -> bool {
ctx.pages.contains_key(&page) || ctx.entry_points.contains_key(&page)
}

Expand Down

0 comments on commit 20dfa31

Please sign in to comment.