diff --git a/src/drivers/virtio/transport/pci.rs b/src/drivers/virtio/transport/pci.rs index f77c478706..961d25e093 100644 --- a/src/drivers/virtio/transport/pci.rs +++ b/src/drivers/virtio/transport/pci.rs @@ -702,9 +702,10 @@ impl NotifCfg { // Assumes the cap_len is a multiple of 8 // This read MIGHT be slow, as it does NOT ensure 32 bit alignment. - let notify_off_multiplier = cap.device.read_register( - u16::try_from(cap.origin.cfg_ptr).unwrap() + u16::from(cap.origin.cap_struct.cap_len), - ); + let notify_off_multiplier_ptr = + cap.origin.cfg_ptr + u32::try_from(mem::size_of::()).unwrap(); + let notify_off_multiplier_ptr = u16::try_from(notify_off_multiplier_ptr).unwrap(); + let notify_off_multiplier = cap.device.read_register(notify_off_multiplier_ptr); // define base memory address from which the actual Queue Notify address can be derived via // base_addr + queue_notify_off * notify_off_multiplier. @@ -921,8 +922,8 @@ impl PciCfgAlt { // #[repr(C)] // struct PciCap64 { // pci_cap: PciCap, -// offset_high: u32, -// length_high: u32 +// offset_hi: u32, +// length_hi: u32 pub struct ShMemCfg { mem_addr: VirtMemAddr, length: MemLen, @@ -947,24 +948,25 @@ impl ShMemCfg { // Assumes the cap_len is a multiple of 8 // This read MIGHT be slow, as it does NOT ensure 32 bit alignment. - let offset_high = cap.device.read_register( - u16::try_from(cap.origin.cfg_ptr).unwrap() + u16::from(cap.origin.cap_struct.cap_len), - ); + let offset_hi_ptr = + cap.origin.cfg_ptr + u32::try_from(mem::size_of::()).unwrap(); + let offset_hi_ptr = u16::try_from(offset_hi_ptr).unwrap(); + let offset_hi = cap.device.read_register(offset_hi_ptr); // Create 64 bit offset from high and low 32 bit values let offset = - MemOff::from((u64::from(offset_high) << 32) ^ u64::from(cap.origin.cap_struct.offset)); + MemOff::from((u64::from(offset_hi) << 32) ^ u64::from(cap.origin.cap_struct.offset)); // Assumes the cap_len is a multiple of 8 // This read MIGHT be slow, as it does NOT ensure 32 bit alignment. - let length_high = cap.device.read_register( - u16::try_from(cap.origin.cfg_ptr).unwrap() - + u16::from(cap.origin.cap_struct.cap_len + 4), - ); + let length_hi_ptr = cap.origin.cfg_ptr + + u32::try_from(mem::size_of::() + mem::size_of::()).unwrap(); + let length_hi_ptr = u16::try_from(length_hi_ptr).unwrap(); + let length_hi = cap.device.read_register(length_hi_ptr); // Create 64 bit length from high and low 32 bit values let length = - MemLen::from((u64::from(length_high) << 32) ^ u64::from(cap.origin.cap_struct.length)); + MemLen::from((u64::from(length_hi) << 32) ^ u64::from(cap.origin.cap_struct.length)); let virt_addr_raw = cap.bar.mem_addr + offset; let raw_ptr = ptr::with_exposed_provenance_mut::(virt_addr_raw.into()); @@ -1118,6 +1120,7 @@ fn read_caps( let mut iter = bars.iter(); + let cfg_ptr = next_ptr; // Set next pointer for next iteration of `caplist. next_ptr = u32::from(cap_raw.cap_next); @@ -1151,7 +1154,7 @@ fn read_caps( length: MemLen::from(cap_raw.length), device: *device, origin: Origin { - cfg_ptr: next_ptr, + cfg_ptr, dev_id: device_id, cap_struct: cap_raw, },