Skip to content

Commit

Permalink
virtio-pci: fix potential misaligned issue
Browse files Browse the repository at this point in the history
Misaligned read or write to physical memory or MMIO may cause unexpected
errors.

Signed-off-by: Jiaqi Gao <[email protected]>
  • Loading branch information
gaojiaqi7 authored and jyao1 committed Jan 3, 2025
1 parent b546154 commit e5f308c
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 83 deletions.
138 changes: 81 additions & 57 deletions src/devices/pci/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,20 +95,20 @@ pub fn pci_cf8_read8(bus: u8, device: u8, fnc: u8, reg: u8) -> u8 {
}
}

fn get_device_details(bus: u8, device: u8, func: u8) -> (u16, u16) {
let config_data = ConfigSpacePciEx::read::<u32>(bus, device, func, 0);
(
fn get_device_details(bus: u8, device: u8, func: u8) -> Result<(u16, u16)> {
let config_data = ConfigSpacePciEx::read::<u32>(bus, device, func, 0)?;
Ok((
(config_data & 0xffff) as u16,
((config_data & 0xffff0000) >> 0x10) as u16,
)
))
}

pub fn find_device(vendor_id: u16, device_id: u16) -> Option<(u8, u8, u8)> {
const MAX_DEVICES: u8 = 32;
const INVALID_VENDOR_ID: u16 = 0xffff;

for device in 0..MAX_DEVICES {
if (vendor_id, device_id) == get_device_details(0, device, 0) {
if (vendor_id, device_id) == get_device_details(0, device, 0).ok()? {
return Some((0, device, 0));
}
if vendor_id == INVALID_VENDOR_ID {
Expand Down Expand Up @@ -191,12 +191,12 @@ impl ConfigSpace {
}

/// Get vendor_id and device_id
pub fn get_device_details(bus: u8, device: u8, func: u8) -> (u16, u16) {
let config_data = ConfigSpacePciEx::read::<u32>(bus, device, func, 0);
(
pub fn get_device_details(bus: u8, device: u8, func: u8) -> Result<(u16, u16)> {
let config_data = ConfigSpacePciEx::read::<u32>(bus, device, func, 0)?;
Ok((
(config_data & 0xffff) as u16,
((config_data & 0xffff0000) >> 0x10) as u16,
)
))
}

fn get_config_address(bus: u8, device: u8, func: u8, offset: u8) -> ConfigAddress {
Expand All @@ -215,49 +215,73 @@ impl ConfigSpace {
pub struct ConfigSpacePciEx;
impl ConfigSpacePciEx {
#[cfg(not(feature = "fuzz"))]
pub fn read<T: Copy + Clone>(bus: u8, device: u8, func: u8, offset: u16) -> T {
pub fn read<T: Copy + Clone>(bus: u8, device: u8, func: u8, offset: u16) -> Result<T> {
let addr = PCI_EX_BAR_BASE_ADDRESS
+ ((bus as u64) << 20)
+ ((device as u64) << 15)
+ ((func as u64) << 12)
+ offset as u64;

if addr % align_of::<T>() as u64 != 0 {
return Err(PciError::Misaligned);
}
#[cfg(feature = "iocall")]
unsafe {
core::ptr::read_volatile(addr as *const T)
Ok(core::ptr::read_volatile(addr as *const T))
}
#[cfg(feature = "tdcall")]
tdx_tdcall::tdx::tdvmcall_mmio_read(addr as usize)
Ok(tdx_tdcall::tdx::tdvmcall_mmio_read(addr as usize))
}
#[cfg(feature = "fuzz")]
pub fn read<T: Copy + Clone>(_bus: u8, _device: u8, _func: u8, offset: u16) -> T {
pub fn read<T: Copy + Clone>(_bus: u8, _device: u8, _func: u8, offset: u16) -> Result<T> {
let base_address = crate::get_fuzz_seed_address();
let address = base_address + offset as u64;
unsafe { core::ptr::read_volatile(address as *const T) }
if address % align_of::<T>() as u64 != 0 {
return Err(PciError::Misaligned);
}
unsafe { Ok(core::ptr::read_volatile(address as *const T)) }
}

#[cfg(not(feature = "fuzz"))]
pub fn write<T: Copy + Clone>(bus: u8, device: u8, func: u8, offset: u16, value: T) {
pub fn write<T: Copy + Clone>(
bus: u8,
device: u8,
func: u8,
offset: u16,
value: T,
) -> Result<()> {
let addr = PCI_EX_BAR_BASE_ADDRESS
+ ((bus as u64) << 20)
+ ((device as u64) << 15)
+ ((func as u64) << 12)
+ offset as u64;

if addr % align_of::<T>() as u64 != 0 {
return Err(PciError::Misaligned);
}
#[cfg(feature = "iocall")]
unsafe {
core::ptr::write_volatile(addr as *mut T, value)
core::ptr::write_volatile(addr as *mut T, value);
}
#[cfg(feature = "tdcall")]
tdx_tdcall::tdx::tdvmcall_mmio_write(addr as *mut T, value);
Ok(())
}

#[cfg(feature = "fuzz")]
pub fn write<T: Copy + Clone>(_bus: u8, _device: u8, _func: u8, offset: u16, value: T) {
unsafe {
let base_address = crate::get_fuzz_seed_address();
let address = base_address + offset as u64;
core::ptr::write_volatile(address as *mut T, value)
pub fn write<T: Copy + Clone>(
_bus: u8,
_device: u8,
_func: u8,
offset: u16,
value: T,
) -> Result<()> {
let base_address = crate::get_fuzz_seed_address();
let address = base_address + offset as u64;
if address % align_of::<T>() as u64 != 0 {
return Err(PciError::Misaligned);
}
unsafe { Ok(core::ptr::write_volatile(address as *mut T, value)) }
}
}

Expand Down Expand Up @@ -384,11 +408,11 @@ impl PciDevice {
#[cfg(not(feature = "fuzz"))]
pub fn init(&mut self) -> Result<()> {
let (vendor_id, device_id) =
ConfigSpace::get_device_details(self.bus, self.device, self.func);
ConfigSpace::get_device_details(self.bus, self.device, self.func)?;
self.common_header.vendor_id = vendor_id;
self.common_header.device_id = device_id;
let command = self.read_u16(0x4);
let status = self.read_u16(0x6);
let command = self.read_u16(0x4)?;
let status = self.read_u16(0x6)?;
log::info!(
"PCI Device: {}:{}.{} {:x}:{:x}\nbit \t fedcba9876543210\nstate\t {:016b}\ncommand\t {:016b}\n",
self.bus,
Expand All @@ -405,7 +429,7 @@ impl PciDevice {

//0x24 offset is last bar
while current_bar_offset <= 0x24 {
let bar = self.read_u32(current_bar_offset);
let bar = self.read_u32(current_bar_offset)?;

// lsb is 1 for I/O space bars
if bar & 1 == 1 {
Expand All @@ -415,11 +439,11 @@ impl PciDevice {
// bits 2-1 are the type 0 is 32-but, 2 is 64 bit
match bar >> 1 & 3 {
0 => {
let size = self.get_bar_size(current_bar_offset);
let size = self.get_bar_size(current_bar_offset)?;

let addr = if size > 0 {
let addr = alloc_mmio32(size)?;
self.set_bar_addr(current_bar_offset, addr);
self.set_bar_addr(current_bar_offset, addr)?;
addr
} else {
bar
Expand All @@ -432,15 +456,15 @@ impl PciDevice {
2 => {
self.bars[current_bar].bar_type = PciBarType::MemorySpace64;

let mut size = self.get_bar_size(current_bar_offset) as u64;
let mut size = self.get_bar_size(current_bar_offset)? as u64;
if size == 0 {
size = (self.get_bar_size(current_bar_offset + 4) as u64) << 32;
size = (self.get_bar_size(current_bar_offset + 4)? as u64) << 32;
}

let addr = if size > 0 {
let addr = alloc_mmio64(size)?;
self.set_bar_addr(current_bar_offset, addr as u32);
self.set_bar_addr(current_bar_offset + 4, (addr >> 32) as u32);
self.set_bar_addr(current_bar_offset, addr as u32)?;
self.set_bar_addr(current_bar_offset + 4, (addr >> 32) as u32)?;
addr
} else {
bar as u64
Expand All @@ -461,7 +485,7 @@ impl PciDevice {
self.write_u16(
0x4,
(PciCommand::IO_SPACE | PciCommand::MEMORY_SPACE | PciCommand::BUS_MASTER).bits(),
);
)?;
for bar in &self.bars {
log::info!("Bar: type={:?} address={:x}\n", bar.bar_type, bar.address);
}
Expand All @@ -472,18 +496,18 @@ impl PciDevice {
#[cfg(feature = "fuzz")]
pub fn init(&mut self) -> Result<()> {
let (vendor_id, device_id) =
ConfigSpace::get_device_details(self.bus, self.device, self.func);
ConfigSpace::get_device_details(self.bus, self.device, self.func)?;
self.common_header.vendor_id = vendor_id;
self.common_header.device_id = device_id;
let command = self.read_u16(0x4);
let status = self.read_u16(0x6);
let command = self.read_u16(0x4)?;
let status = self.read_u16(0x6)?;

let mut current_bar_offset = 0x10;
let mut current_bar = 0;

//0x24 offset is last bar
while current_bar_offset <= 0x24 {
let bar = self.read_u32(current_bar_offset);
let bar = self.read_u32(current_bar_offset)?;

// lsb is 1 for I/O space bars
if bar & 1 == 1 {
Expand All @@ -493,11 +517,11 @@ impl PciDevice {
// bits 2-1 are the type 0 is 32-but, 2 is 64 bit
match bar >> 1 & 3 {
0 => {
let size = self.read_u32(current_bar_offset);
let size = self.read_u32(current_bar_offset)?;

let addr = if size > 0 {
let addr = alloc_mmio32(size)?;
self.set_bar_addr(current_bar_offset, addr);
self.set_bar_addr(current_bar_offset, addr)?;
addr
} else {
bar
Expand All @@ -510,11 +534,11 @@ impl PciDevice {
2 => {
self.bars[current_bar].bar_type = PciBarType::MemorySpace64;

let mut size = self.read_u64(current_bar_offset);
let mut size = self.read_u64(current_bar_offset)?;
let addr = if size > 0 {
let addr = alloc_mmio64(size)?;
self.set_bar_addr(current_bar_offset, addr as u32);
self.set_bar_addr(current_bar_offset + 4, (addr >> 32) as u32);
self.set_bar_addr(current_bar_offset, addr as u32)?;
self.set_bar_addr(current_bar_offset + 4, (addr >> 32) as u32)?;
addr
} else {
bar as u64
Expand All @@ -540,56 +564,56 @@ impl PciDevice {
Ok(())
}

fn set_bar_addr(&self, offset: u8, addr: u32) {
self.write_u32(offset, addr);
fn set_bar_addr(&self, offset: u8, addr: u32) -> Result<()> {
self.write_u32(offset, addr)
}

fn get_bar_size(&self, offset: u8) -> u32 {
let restore = self.read_u32(offset);
self.write_u32(offset, u32::MAX);
let size = self.read_u32(offset);
self.write_u32(offset, restore);
fn get_bar_size(&self, offset: u8) -> Result<u32> {
let restore = self.read_u32(offset)?;
self.write_u32(offset, u32::MAX)?;
let size = self.read_u32(offset)?;
self.write_u32(offset, restore)?;

if size == 0 {
Ok(if size == 0 {
size
} else {
!(size & 0xFFFF_FFF0) + 1
}
})
}

pub fn read_u64(&self, offset: u8) -> u64 {
pub fn read_u64(&self, offset: u8) -> Result<u64> {
ConfigSpacePciEx::read::<u64>(self.bus, self.device, self.func, offset as u16)
// let low = ConfigSpace::read32(self.bus, self.device, self.func, offset);
// let high = ConfigSpace::read32(self.bus, self.device, self.func, offset + 8);
// (low as u64) & ((high as u64) << 8)
}

pub fn read_u32(&self, offset: u8) -> u32 {
pub fn read_u32(&self, offset: u8) -> Result<u32> {
ConfigSpacePciEx::read::<u32>(self.bus, self.device, self.func, offset as u16)
// ConfigSpace::read32(self.bus, self.device, self.func, offset)
}

pub fn read_u16(&self, offset: u8) -> u16 {
pub fn read_u16(&self, offset: u8) -> Result<u16> {
ConfigSpacePciEx::read::<u16>(self.bus, self.device, self.func, offset as u16)
// ConfigSpace::read16(self.bus, self.device, self.func, offset)
}

pub fn read_u8(&self, offset: u8) -> u8 {
pub fn read_u8(&self, offset: u8) -> Result<u8> {
ConfigSpacePciEx::read::<u8>(self.bus, self.device, self.func, offset as u16)
// ConfigSpace::read8(self.bus, self.device, self.func, offset)
}

pub fn write_u32(&self, offset: u8, value: u32) {
pub fn write_u32(&self, offset: u8, value: u32) -> Result<()> {
ConfigSpacePciEx::write::<u32>(self.bus, self.device, self.func, offset as u16, value)
// ConfigSpace::write32(self.bus, self.device, self.func, offset, value)
}

pub fn write_u16(&self, offset: u8, value: u16) {
pub fn write_u16(&self, offset: u8, value: u16) -> Result<()> {
ConfigSpacePciEx::write::<u16>(self.bus, self.device, self.func, offset as u16, value)
// ConfigSpace::write16(self.bus, self.device, self.func, offset, value)
}

pub fn write_u8(&self, offset: u8, value: u8) {
pub fn write_u8(&self, offset: u8, value: u8) -> Result<()> {
ConfigSpacePciEx::write::<u8>(self.bus, self.device, self.func, offset as u16, value)
// ConfigSpace::write8(self.bus, self.device, self.func, offset, value)
}
Expand Down
2 changes: 2 additions & 0 deletions src/devices/pci/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ pub fn get_fuzz_seed_address() -> u64 {

pub type Result<T> = core::result::Result<T, PciError>;

#[derive(Debug)]
pub enum PciError {
InvalidParameter,
MmioOutofResource,
InvalidBarType,
Misaligned,
}
10 changes: 10 additions & 0 deletions src/devices/virtio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
extern crate alloc;
use core::fmt::Display;
use mem::MemoryRegionError;
use pci::PciError;

pub mod consts;
mod mem;
Expand Down Expand Up @@ -47,6 +48,8 @@ pub enum VirtioError {
InvalidRingIndex,
/// Invalid index for ring
InvalidDescriptor,
/// Pci related error
Pci(PciError),
}

impl Display for VirtioError {
Expand All @@ -69,6 +72,7 @@ impl Display for VirtioError {
VirtioError::InvalidDescriptorIndex => write!(f, "InvalidDescriptorIndex"),
VirtioError::InvalidRingIndex => write!(f, "InvalidRingIndex"),
VirtioError::InvalidDescriptor => write!(f, "InvalidDescriptor"),
VirtioError::Pci(_) => write!(f, "Pci"),
}
}
}
Expand All @@ -79,6 +83,12 @@ impl From<mem::MemoryRegionError> for VirtioError {
}
}

impl From<PciError> for VirtioError {
fn from(e: PciError) -> Self {
VirtioError::Pci(e)
}
}

pub type Result<T = ()> = core::result::Result<T, VirtioError>;

/// Trait to allow separation of transport from block driver
Expand Down
Loading

0 comments on commit e5f308c

Please sign in to comment.