Create a type-safe(r) wrapper around page tables in stage0
The `PageTable` struct from `x86_64` doesn't do any semantic checking of
how you set up your page table hierarchy beyond verifying that the
physical address is valid.

This CR creates a specific `set_lower_level_table` method for setting up
page table hierarchies. It's still not _entirely_ type-safe -- we need to
ensure lifetimes make sense and wrap things in `Pin` -- but it's a step
in the direction of making page table ops safer.

Note that the TDX crate does a lot of ops on the `PageTable` structs
from the `x86_64` crate directly; that'll need to be cleaned up at some
point as well.
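
A rough sketch of how the new types are meant to be used (the `P: Platform`
type and the flag values here are assumptions for illustration, not part of
this CR):

```rust
// Sketch only -- assumes some `P: Platform` in scope.
let mut pml4: PageTable<page_table_level::PML4> = PageTable::new();
let pdpt: PageTable<page_table_level::PDPT> = PageTable::new();

// OK: a PML4 entry may only point at a PDPT, since PML4 implements
// Node<Nested = PDPT>.
pml4[0].set_lower_level_table::<P>(
    &pdpt,
    PageTableFlags::PRESENT | PageTableFlags::WRITABLE,
);

// Won't compile: PML4 has no Leaf impl, so its entries can't map memory
// directly.
// pml4[0].set_address::<P>(addr, flags, PageEncryption::Unset);
```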

Bug: 377899703
Change-Id: Ibdcaa1ad096b99aafcb4452c0366e9aa3a54aba6
andrisaar committed Nov 8, 2024
1 parent 3ccb4bf commit bf97164
Showing 4 changed files with 168 additions and 86 deletions.
4 changes: 2 additions & 2 deletions stage0/src/hal/base/mmio.rs
@@ -24,13 +24,13 @@ use x86_64::{
 };
 
 use super::Base;
-use crate::paging::{PageEncryption, PageTableEntry, PAGE_TABLE_REFS};
+use crate::paging::{page_table_level::PT, PageEncryption, PageTableEntry, PAGE_TABLE_REFS};
 
 pub struct Mmio<S: PageSize> {
     pub base_address: PhysAddr,
     layout: Layout,
     mmio_memory: VirtAddr,
-    old_pte: PageTableEntry,
+    old_pte: PageTableEntry<PT>,
     phantom: core::marker::PhantomData<S>,
 }
188 changes: 141 additions & 47 deletions stage0/src/paging.rs
@@ -16,6 +16,7 @@
 
 use alloc::boxed::Box;
 use core::{
+    marker::PhantomData,
     ops::{Index, IndexMut},
     ptr::addr_of_mut,
 };
@@ -35,143 +36,237 @@ use crate::{hal::PageAssignment, BootAllocator, Platform, BOOT_ALLOC};

 /// The root page-map level 4 table covering virtual memory ranges 0..128TiB
 /// and (16EiB-128TiB)..16EiB.
-pub static mut PML4: PageTable = PageTable::new();
+pub static mut PML4: PageTable<page_table_level::PML4> = PageTable::new();
 /// The page-directory pointer table covering virtual memory range 0..512GiB.
-pub static mut PDPT: PageTable = PageTable::new();
+pub static mut PDPT: PageTable<page_table_level::PDPT> = PageTable::new();
 /// The page directory covering virtual memory range 0..1GiB.
-pub static mut PD_0: PageTable = PageTable::new();
+pub static mut PD_0: PageTable<page_table_level::PD> = PageTable::new();
 /// The page directory covering virtual memory range 3..4GiB.
-pub static mut PD_3: PageTable = PageTable::new();
+pub static mut PD_3: PageTable<page_table_level::PD> = PageTable::new();
 
 /// Wrapper for the page table references so that we can access them via a mutex
 /// rather than directly via unsafe code.
 pub struct PageTableRefs {
     /// The root page-map level 4 table covering virtual memory ranges
     /// 0..128TiB and (16EiB-128TiB)..16EiB.
-    pub pml4: &'static mut PageTable,
+    pub pml4: &'static mut PageTable<page_table_level::PML4>,
 
     /// The page-directory pointer table covering virtual memory range
     /// 0..512GiB.
-    pub pdpt: &'static mut PageTable,
+    pub pdpt: &'static mut PageTable<page_table_level::PDPT>,
 
     /// The page directory covering virtual memory range 0..1GiB.
-    pub pd_0: &'static mut PageTable,
+    pub pd_0: &'static mut PageTable<page_table_level::PD>,
 
     /// The page directory covering virtual memory range 3..4GiB.
-    pub pd_3: &'static mut PageTable,
+    pub pd_3: &'static mut PageTable<page_table_level::PD>,
 
     /// The page table covering virtual memory range 0..2MiB where we want 4KiB
     /// pages.
-    pub pt_0: Box<PageTable, &'static BootAllocator>,
+    pub pt_0: Box<PageTable<page_table_level::PT>, &'static BootAllocator>,
 }
 
 /// References to all the page tables we care about.
 pub static PAGE_TABLE_REFS: OnceCell<Spinlock<PageTableRefs>> = OnceCell::new();
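
For illustration, a minimal sketch of the access pattern this enables -- lock
the refs instead of touching the mutable statics (assumes `PAGE_TABLE_REFS`
has already been initialized; mirrors the `get().unwrap().lock()` pattern used
later in this diff):

```rust
// Sketch: lock the refs rather than dereferencing the statics directly.
let mut page_tables = PAGE_TABLE_REFS.get().expect("page tables not initialized").lock();
// The level parameter now documents which table we're indexing into:
let entry_flags = page_tables.pd_0[0].flags();
```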

+pub mod page_table_level {
+    use x86_64::structures::paging::{PageSize, PageTableFlags, Size1GiB, Size2MiB, Size4KiB};
+
+    pub trait PageTableLevel {}
+
+    /// Marker trait for page tables that may have nested page tables.
+    pub trait Node: PageTableLevel {
+        /// Type of the lower level of page tables in the hierarchy.
+        type Nested: PageTableLevel;
+    }
+
+    /// Marker trait for page tables that can map memory.
+    pub trait Leaf: PageTableLevel {
+        /// Flags to be used (e.g. HUGE_PAGE for non-4K pages)
+        const FLAGS: PageTableFlags;
+
+        /// Size of the page.
+        type Size: PageSize;
+    }
+
+    /// Page Map Level 4
+    #[derive(Clone)]
+    pub enum PML4 {}
+    impl PageTableLevel for PML4 {}
+    impl Node for PML4 {
+        type Nested = PDPT;
+    }
+
+    /// Page Directory Pointer Table (Level 3)
+    #[derive(Clone)]
+    pub enum PDPT {}
+    impl PageTableLevel for PDPT {}
+    impl Node for PDPT {
+        type Nested = PD;
+    }
+    impl Leaf for PDPT {
+        const FLAGS: PageTableFlags = PageTableFlags::HUGE_PAGE;
+        type Size = Size1GiB;
+    }
+
+    /// Page Directory (Level 2)
+    #[derive(Clone)]
+    pub enum PD {}
+    impl PageTableLevel for PD {}
+    impl Node for PD {
+        type Nested = PT;
+    }
+    impl Leaf for PD {
+        const FLAGS: PageTableFlags = PageTableFlags::HUGE_PAGE;
+        type Size = Size2MiB;
+    }
+
+    /// Page Table (Level 1)
+    #[derive(Clone)]
+    pub enum PT {}
+    impl PageTableLevel for PT {}
+    impl Leaf for PT {
+        const FLAGS: PageTableFlags = PageTableFlags::empty();
+        type Size = Size4KiB;
+    }
+}
+pub use page_table_level::{Leaf, PageTableLevel};
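
One payoff of the markers: code can stay generic over the level without losing
level-safety. A hypothetical helper (not part of this diff) that counts
present entries at any level, relying only on the `iter()` and `flags()`
methods of the wrappers below:

```rust
use x86_64::structures::paging::PageTableFlags;

// Hypothetical: works for PML4, PDPT, PD and PT alike, because every
// level type implements PageTableLevel.
fn count_present_entries<L: PageTableLevel>(table: &PageTable<L>) -> usize {
    table.iter().filter(|entry| entry.flags().contains(PageTableFlags::PRESENT)).count()
}
```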

 /// Thin wrapper around x86_64::PageTable that uses our PageTableEntry type.
 #[repr(transparent)]
 #[derive(Default)]
-pub struct PageTable(BasePageTable);
+pub struct PageTable<PageTableLevel> {
+    inner: BasePageTable,
+    phantom: PhantomData<PageTableLevel>,
+}
 
-impl PageTable {
+impl<L: PageTableLevel> PageTable<L> {
     pub const fn new() -> Self {
-        Self(BasePageTable::new())
+        Self { inner: BasePageTable::new(), phantom: PhantomData }
     }
 
     pub fn zero(&mut self) {
-        self.0.zero()
+        self.inner.zero()
     }
 
-    pub fn iter(&self) -> impl Iterator<Item = &PageTableEntry> {
-        self.0.iter().map(|entry| entry.into())
+    pub fn iter(&self) -> impl Iterator<Item = &PageTableEntry<L>> {
+        self.inner.iter().map(|entry| entry.into())
     }
 
-    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut PageTableEntry> {
-        self.0.iter_mut().map(|entry| entry.into())
+    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut PageTableEntry<L>> {
+        self.inner.iter_mut().map(|entry| entry.into())
     }
 }

-impl Index<usize> for PageTable {
-    type Output = PageTableEntry;
+impl<L: PageTableLevel> Index<usize> for PageTable<L> {
+    type Output = PageTableEntry<L>;
 
     fn index(&self, index: usize) -> &Self::Output {
-        self.0.index(index).into()
+        self.inner.index(index).into()
     }
 }
 
-impl Index<PageTableIndex> for PageTable {
-    type Output = PageTableEntry;
+impl<L: PageTableLevel> Index<PageTableIndex> for PageTable<L> {
+    type Output = PageTableEntry<L>;
 
     fn index(&self, index: PageTableIndex) -> &Self::Output {
-        self.0.index(index).into()
+        self.inner.index(index).into()
     }
 }
 
-impl IndexMut<usize> for PageTable {
+impl<L: PageTableLevel> IndexMut<usize> for PageTable<L> {
     fn index_mut(&mut self, index: usize) -> &mut Self::Output {
-        self.0.index_mut(index).into()
+        self.inner.index_mut(index).into()
     }
 }
 
-impl IndexMut<PageTableIndex> for PageTable {
+impl<L: PageTableLevel> IndexMut<PageTableIndex> for PageTable<L> {
     fn index_mut(&mut self, index: PageTableIndex) -> &mut Self::Output {
-        self.0.index_mut(index).into()
+        self.inner.index_mut(index).into()
     }
 }

 /// Thin wrapper around x86_64::PageTableEntry that forces use of PageEncryption
 /// for addresses.
 #[repr(transparent)]
 #[derive(Clone)]
-pub struct PageTableEntry(BasePageTableEntry);
+pub struct PageTableEntry<PageTableLevel> {
+    inner: BasePageTableEntry,
+    phantom: PhantomData<PageTableLevel>,
+}
 
-impl PageTableEntry {
-    /// Map the entry to the specified address with the specified flags and
-    /// encryption state.
+impl<L: page_table_level::PageTableLevel, Ln: page_table_level::PageTableLevel> PageTableEntry<L>
+where
+    L: page_table_level::Node<Nested = Ln>,
+{
+    /// Sets the entry to point to a lower-level page table with the specified
+    /// flags.
+    ///
+    /// The encryption bit is never set, as it is required for neither SEV
+    /// nor TDX.
+    pub fn set_lower_level_table<P: Platform>(
+        &mut self,
+        pdpt: &PageTable<Ln>,
+        flags: PageTableFlags,
+    ) {
+        self.inner.set_addr(PhysAddr::new(pdpt as *const PageTable<Ln> as u64), flags)
+    }
+}
+
+impl<L: PageTableLevel + Leaf> PageTableEntry<L> {
+    /// Map the entry to the specified address in memory with the specified
+    /// flags and encryption state.
+    ///
+    /// You don't need to set the HUGE_PAGE flag; it will be added
+    /// automatically.
+    ///
+    /// Don't use this to set up nested page tables!
     pub fn set_address<P: Platform>(
         &mut self,
         addr: PhysAddr,
         flags: PageTableFlags,
         state: PageEncryption,
     ) {
         let addr = PhysAddr::new(addr.as_u64() | P::page_table_mask(state));
-        self.0.set_addr(addr, flags);
+        self.inner.set_addr(addr, flags | L::FLAGS);
     }
 }
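
For a `Leaf` level this means huge-page mappings need no explicit HUGE_PAGE
flag. A minimal sketch, assuming a `P: Platform` and the locked `page_tables`
refs from `PAGE_TABLE_REFS` (the address and flags are illustrative):

```rust
// Sketch: identity-map 3GiB..3GiB+2MiB as one 2MiB page through PD_3
// entry 0; HUGE_PAGE is ORed in automatically from <PD as Leaf>::FLAGS,
// and the platform's encryption mask comes from P::page_table_mask(state).
page_tables.pd_3[0].set_address::<P>(
    PhysAddr::new(0xC000_0000),
    PageTableFlags::PRESENT | PageTableFlags::WRITABLE,
    PageEncryption::Unset,
);
```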

+impl<L: PageTableLevel> PageTableEntry<L> {
     /// Returns the physical address mapped by this entry. May be zero.
     pub fn address<P: Platform>(&self) -> PhysAddr {
-        PhysAddr::new(self.0.addr().as_u64() & !P::encrypted())
+        PhysAddr::new(self.inner.addr().as_u64() & !P::encrypted())
     }
 
     /// Returns whether the entry is zero.
     pub const fn is_unused(&self) -> bool {
-        self.0.is_unused()
+        self.inner.is_unused()
     }
 
     /// Sets the entry to zero.
     pub fn set_unused(&mut self) {
-        self.0.set_unused()
+        self.inner.set_unused()
     }
 
     /// Returns the flags of this entry.
     pub const fn flags(&self) -> PageTableFlags {
-        self.0.flags()
+        self.inner.flags()
     }
 }

-impl From<&BasePageTableEntry> for &PageTableEntry {
+impl<L: PageTableLevel> From<&BasePageTableEntry> for &PageTableEntry<L> {
     fn from(value: &BasePageTableEntry) -> Self {
         // Safety: our PageTableEntry is a transparent wrapper so the memory layout is
         // the same and does not impose any extra restrictions on valid values.
-        unsafe { &*(value as *const BasePageTableEntry as *const PageTableEntry) }
+        unsafe { &*(value as *const BasePageTableEntry as *const PageTableEntry<L>) }
     }
 }
 
-impl From<&mut BasePageTableEntry> for &mut PageTableEntry {
+impl<L: PageTableLevel> From<&mut BasePageTableEntry> for &mut PageTableEntry<L> {
     fn from(value: &mut BasePageTableEntry) -> Self {
         // Safety: our PageTableEntry is a transparent wrapper so the memory layout is
         // the same and does not impose any extra restrictions on valid values.
-        unsafe { &mut *(value as *mut BasePageTableEntry as *mut PageTableEntry) }
+        unsafe { &mut *(value as *mut BasePageTableEntry as *mut PageTableEntry<L>) }
     }
 }

@@ -211,10 +306,10 @@ pub fn init_page_table_refs<P: Platform>() {
     // Safety: accessing the mutable statics here is safe since we only do it once
     // and protect the mutable references with a mutex. This function can only
     // be called once, since updating `PAGE_TABLE_REFS` twice will panic.
-    let pml4: &mut PageTable = unsafe { &mut *addr_of_mut!(PML4) };
-    let pdpt: &mut PageTable = unsafe { &mut *addr_of_mut!(PDPT) };
-    let pd_0: &mut PageTable = unsafe { &mut *addr_of_mut!(PD_0) };
-    let pd_3: &mut PageTable = unsafe { &mut *addr_of_mut!(PD_3) };
+    let pml4: &mut PageTable<page_table_level::PML4> = unsafe { &mut *addr_of_mut!(PML4) };
+    let pdpt: &mut PageTable<page_table_level::PDPT> = unsafe { &mut *addr_of_mut!(PDPT) };
+    let pd_0: &mut PageTable<page_table_level::PD> = unsafe { &mut *addr_of_mut!(PD_0) };
+    let pd_3: &mut PageTable<page_table_level::PD> = unsafe { &mut *addr_of_mut!(PD_3) };

     // Set up a new page table that maps the first 2MiB as 4KiB pages (except for
     // the lower 4KiB), so that we can share individual 4KiB pages with the
@@ -232,10 +327,9 @@ pub fn init_page_table_refs<P: Platform>() {
         );
     });
     // Let the first entry of PD_0 point to pt_0:
-    pd_0[0].set_address::<P>(
-        PhysAddr::new(pt_0.as_ref() as *const _ as usize as u64),
+    pd_0[0].set_lower_level_table::<P>(
+        pt_0.as_ref(),
         PageTableFlags::PRESENT | PageTableFlags::WRITABLE,
-        PageEncryption::Unset,
     );
 
     let page_tables = PageTableRefs { pml4, pdpt, pd_0, pd_3, pt_0 };
19 changes: 5 additions & 14 deletions stage0_bin_tdx/src/main.rs
@@ -34,8 +34,7 @@ use oak_linux_boot_params::{BootE820Entry, E820EntryType};
 use oak_stage0::{
     hal::PortFactory,
     mailbox::{FirmwareMailbox, OsMailbox},
-    paging::{self, PageEncryption},
-    BOOT_ALLOC,
+    paging, BOOT_ALLOC,
 };
 use oak_tdx_guest::{
     tdcall::get_td_info,
@@ -385,28 +384,20 @@ impl oak_stage0::Platform for Tdx {

info!("starting TDX memory acceptance");
let mut page_tables = paging::PAGE_TABLE_REFS.get().unwrap().lock();
let accept_pd_pt = PageTable::new();
let accept_pd_pt = oak_stage0::paging::PageTable::new();
if page_tables.pdpt[1].flags().contains(PageTableFlags::PRESENT) {
panic!("PDPT[1] is in use");
}

page_tables.pdpt[1].set_address::<Tdx>(
PhysAddr::new(&accept_pd_pt as *const _ as u64),
PageTableFlags::PRESENT,
PageEncryption::Unset,
);
page_tables.pdpt[1].set_lower_level_table::<Tdx>(&accept_pd_pt, PageTableFlags::PRESENT);
info!("added pdpt[1]");

info!("adding pd_0[1]");
let accept_pt_pt = PageTable::new();
let accept_pt_pt = oak_stage0::paging::PageTable::new();
if page_tables.pd_0[1].flags().contains(PageTableFlags::PRESENT) {
panic!("PD_0[1] is in use");
}
page_tables.pd_0[1].set_address::<Tdx>(
PhysAddr::new(&accept_pt_pt as *const _ as u64),
PageTableFlags::PRESENT,
PageEncryption::Unset,
);
page_tables.pd_0[1].set_lower_level_table::<Tdx>(&accept_pt_pt, PageTableFlags::PRESENT);
info!("added pd_0[1]");

let min_addr = 0xA0000;