From a4613f44435625a13d4a03030083a564ee63391f Mon Sep 17 00:00:00 2001 From: Jiaqi Gao Date: Thu, 21 Nov 2024 04:07:13 -0500 Subject: [PATCH] migtd: create a async task for waiting request MigTD should wait until the wait_for_request command get response. The buffer used for last command should be kept to receive the VMM response. Signed-off-by: Jiaqi Gao --- src/migtd/src/bin/migtd/main.rs | 75 ++- src/migtd/src/migration/session.rs | 711 ++++++++++------------- tests/test-td-payload/src/testservice.rs | 4 +- 3 files changed, 361 insertions(+), 429 deletions(-) diff --git a/src/migtd/src/bin/migtd/main.rs b/src/migtd/src/bin/migtd/main.rs index b407cbff..83c100e8 100644 --- a/src/migtd/src/bin/migtd/main.rs +++ b/src/migtd/src/bin/migtd/main.rs @@ -7,9 +7,14 @@ extern crate alloc; +use core::future::poll_fn; +use core::task::Poll; + use log::info; -use migtd::migration::{session::MigrationSession, MigrationResult}; +use migtd::migration::session::*; +use migtd::migration::MigrationResult; use migtd::{config, event_log, migration}; +use spin::Mutex; const MIGTD_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -49,7 +54,7 @@ pub fn runtime_main() { migration::event::register_callback(); // Query the capability of VMM - if MigrationSession::query().is_err() { + if query().is_err() { panic!("Migration is not supported by VMM"); } @@ -92,37 +97,55 @@ fn get_ca_and_measure(event_log: &mut [u8]) { } fn handle_pre_mig() { - use migtd::migration::session::REQUESTS; #[cfg(feature = "vmcall-interrupt")] const MAX_CONCURRENCY_REQUESTS: usize = 16; #[cfg(not(feature = "vmcall-interrupt"))] const MAX_CONCURRENCY_REQUESTS: usize = 1; + // Set by `wait_for_request` async task when getting new request from VMM. + static PENDING_REQUEST: Mutex> = Mutex::new(None); + + async_runtime::add_task(async move { + loop { + poll_fn(|_cx| { + // Wait until the pending request is taken by a new task + if PENDING_REQUEST.lock().is_none() { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + + if let Ok(request) = wait_for_request().await { + *PENDING_REQUEST.lock() = Some(request); + } + } + }); + let mut queued = async_runtime::poll_tasks(); + loop { - if queued < MAX_CONCURRENCY_REQUESTS { - let mut session = MigrationSession::new(); - if let Ok(info) = session.wait_for_request() { - if let Some(request_id) = info { - async_runtime::add_task(async move { - #[cfg(feature = "vmcall-vsock")] - { - // Safe to unwrap because we have got the request information - let info = session.info().unwrap(); - migtd::driver::vsock::vmcall_vsock_device_init( - info.mig_info.mig_request_id, - info.mig_socket_info.mig_td_cid, - ); - } - let status = session - .op() - .await - .map(|_| MigrationResult::Success) - .unwrap_or_else(|e| e); - let _ = session.report_status(status as u8); - REQUESTS.lock().remove(&request_id); - }); - } + // The async task waiting for VMM response is always in the queue + if queued < MAX_CONCURRENCY_REQUESTS + 1 { + let new_request = PENDING_REQUEST.lock().take(); + + if let Some(request) = new_request { + async_runtime::add_task(async move { + #[cfg(feature = "vmcall-vsock")] + { + migtd::driver::vsock::vmcall_vsock_device_init( + request.mig_info.mig_request_id, + request.mig_socket_info.mig_td_cid, + ); + } + let status = exchange_msk(&request) + .await + .map(|_| MigrationResult::Success) + .unwrap_or_else(|e| e); + let _ = report_status(status as u8, request.mig_info.mig_request_id); + REQUESTS.lock().remove(&request.mig_info.mig_request_id); + }); } } queued = async_runtime::poll_tasks(); diff --git a/src/migtd/src/migration/session.rs b/src/migtd/src/migration/session.rs index 07245465..312defed 100644 --- a/src/migtd/src/migration/session.rs +++ b/src/migtd/src/migration/session.rs @@ -3,7 +3,10 @@ // SPDX-License-Identifier: BSD-2-Clause-Patent use alloc::{collections::BTreeSet, vec::Vec}; -use core::{mem::size_of, sync::atomic::Ordering}; +#[cfg(feature = "vmcall-interrupt")] +use core::sync::atomic::Ordering; +use core::{future::poll_fn, mem::size_of, task::Poll}; +#[cfg(feature = "vmcall-interrupt")] use event::VMCALL_SERVICE_FLAG; use lazy_static::lazy_static; use scroll::Pread; @@ -32,7 +35,6 @@ const GSM_FIELD_MAX_EXPORT_VERSION: u64 = 0x2000000100000002; const GSM_FIELD_MIN_IMPORT_VERSION: u64 = 0x2000000100000003; const GSM_FIELD_MAX_IMPORT_VERSION: u64 = 0x2000000100000004; -// #[cfg(feature = "async")] lazy_static! { pub static ref REQUESTS: Mutex> = Mutex::new(BTreeSet::new()); } @@ -49,12 +51,6 @@ impl MigrationInformation { } } -#[derive(Debug, Clone, Copy)] -struct RequestInformation { - request_id: u64, - operation: u8, -} - struct ExchangeInformation { min_ver: u16, max_ver: u16, @@ -81,458 +77,371 @@ impl ExchangeInformation { } } -struct WaitForRequestContext { - rsp_mem: Option, - pending_response: bool, -} +pub fn query() -> Result<()> { + // Allocate one shared page for command and response buffer + let mut cmd_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + let mut rsp_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + + // Set Migration query command buffer + let mut cmd = VmcallServiceCommand::new(cmd_mem.as_mut_bytes(), VMCALL_SERVICE_COMMON_GUID) + .ok_or(MigrationResult::InvalidParameter)?; + let query = ServiceMigWaitForReqCommand { + version: 0, + command: QUERY_COMMAND, + reserved: [0; 2], + }; + cmd.write(query.as_bytes())?; + cmd.write(VMCALL_SERVICE_MIGTD_GUID.as_bytes())?; + let _ = VmcallServiceResponse::new(rsp_mem.as_mut_bytes(), VMCALL_SERVICE_COMMON_GUID) + .ok_or(MigrationResult::InvalidParameter)?; -enum MigrationState { - WaitForRequest(WaitForRequestContext), - Operate(MigrationOperation), - Complete(RequestInformation), -} + #[cfg(feature = "vmcall-interrupt")] + { + tdx::tdvmcall_service( + cmd_mem.as_bytes(), + rsp_mem.as_mut_bytes(), + event::VMCALL_SERVICE_VECTOR as u64, + 0, + )?; + event::wait_for_event(&event::VMCALL_SERVICE_FLAG); + } + #[cfg(not(feature = "vmcall-interrupt"))] + tdx::tdvmcall_service(cmd_mem.as_bytes(), rsp_mem.as_mut_bytes(), 0, 0)?; -enum MigrationOperation { - Migrate(MigrationInformation), -} + let private_mem = rsp_mem.copy_to_private_shadow(); -pub struct MigrationSession { - state: MigrationState, -} + // Parse the response data + // Check the GUID of the reponse + let rsp = + VmcallServiceResponse::try_read(private_mem).ok_or(MigrationResult::InvalidParameter)?; + if rsp.read_guid() != VMCALL_SERVICE_COMMON_GUID.as_bytes() { + return Err(MigrationResult::InvalidParameter); + } + let query = rsp + .read_data::(0) + .ok_or(MigrationResult::InvalidParameter)?; -impl Default for MigrationSession { - fn default() -> Self { - Self::new() + if query.command != QUERY_COMMAND || &query.guid != VMCALL_SERVICE_MIGTD_GUID.as_bytes() { + return Err(MigrationResult::InvalidParameter); + } + if query.status != 0 { + return Err(MigrationResult::Unsupported); } + + log::info!("Migration is supported by VMM\n"); + Ok(()) } -impl MigrationSession { - pub fn new() -> Self { - MigrationSession { - state: MigrationState::WaitForRequest(WaitForRequestContext { - rsp_mem: None, - pending_response: false, - }), - } - } +pub async fn wait_for_request() -> Result { + // Allocate shared page for command and response buffer + let mut cmd_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + let mut rsp_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + + // Set Migration wait for request command buffer + let mut cmd = VmcallServiceCommand::new(cmd_mem.as_mut_bytes(), VMCALL_SERVICE_MIGTD_GUID) + .ok_or(MigrationResult::InvalidParameter)?; + let wfr = ServiceMigWaitForReqCommand { + version: 0, + command: MIG_COMMAND_WAIT, + reserved: [0; 2], + }; + cmd.write(wfr.as_bytes())?; + let _ = VmcallServiceResponse::new(rsp_mem.as_mut_bytes(), VMCALL_SERVICE_MIGTD_GUID) + .ok_or(MigrationResult::InvalidParameter)?; - pub fn query() -> Result<()> { - // Allocate one shared page for command and response buffer - let mut cmd_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; - let mut rsp_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + #[cfg(feature = "vmcall-interrupt")] + { + tdx::tdvmcall_service( + cmd_mem.as_bytes(), + rsp_mem.as_mut_bytes(), + event::VMCALL_SERVICE_VECTOR as u64, + 0, + )?; + } - // Set Migration query command buffer - let mut cmd = VmcallServiceCommand::new(cmd_mem.as_mut_bytes(), VMCALL_SERVICE_COMMON_GUID) - .ok_or(MigrationResult::InvalidParameter)?; - let query = ServiceMigWaitForReqCommand { - version: 0, - command: QUERY_COMMAND, - reserved: [0; 2], - }; - cmd.write(query.as_bytes())?; - cmd.write(VMCALL_SERVICE_MIGTD_GUID.as_bytes())?; - let _ = VmcallServiceResponse::new(rsp_mem.as_mut_bytes(), VMCALL_SERVICE_COMMON_GUID) - .ok_or(MigrationResult::InvalidParameter)?; + poll_fn(|_cx| { + #[cfg(not(feature = "vmcall-interrupt"))] + tdx::tdvmcall_service(cmd_mem.as_bytes(), rsp_mem.as_mut_bytes(), 0, 0)?; #[cfg(feature = "vmcall-interrupt")] - { - tdx::tdvmcall_service( - cmd_mem.as_bytes(), - rsp_mem.as_mut_bytes(), - event::VMCALL_SERVICE_VECTOR as u64, - 0, - )?; - event::wait_for_event(&event::VMCALL_SERVICE_FLAG); + if VMCALL_SERVICE_FLAG.load(Ordering::SeqCst) { + VMCALL_SERVICE_FLAG.store(false, Ordering::SeqCst); + } else { + return Poll::Pending; } - #[cfg(not(feature = "vmcall-interrupt"))] - tdx::tdvmcall_service(cmd_mem.as_bytes(), rsp_mem.as_mut_bytes(), 0, 0)?; let private_mem = rsp_mem.copy_to_private_shadow(); - // Parse the response data - // Check the GUID of the reponse + // Parse out the response data let rsp = VmcallServiceResponse::try_read(private_mem) .ok_or(MigrationResult::InvalidParameter)?; - if rsp.read_guid() != VMCALL_SERVICE_COMMON_GUID.as_bytes() { - return Err(MigrationResult::InvalidParameter); + // Check the GUID of the reponse + if rsp.read_guid() != VMCALL_SERVICE_MIGTD_GUID.as_bytes() { + return Poll::Ready(Err(MigrationResult::InvalidParameter)); } - let query = rsp - .read_data::(0) + let wfr = rsp + .read_data::(0) .ok_or(MigrationResult::InvalidParameter)?; - - if query.command != QUERY_COMMAND || &query.guid != VMCALL_SERVICE_MIGTD_GUID.as_bytes() { - return Err(MigrationResult::InvalidParameter); + if wfr.command != MIG_COMMAND_WAIT { + return Poll::Ready(Err(MigrationResult::InvalidParameter)); } - if query.status != 0 { - return Err(MigrationResult::Unsupported); - } - - log::info!("Migration is supported by VMM\n"); - Ok(()) - } - - pub fn wait_for_request(&mut self) -> Result> { - match &mut self.state { - MigrationState::WaitForRequest(context) => { - if !context.pending_response || cfg!(not(feature = "vmcall-interrupt")) { - // Allocate shared page for command and response buffer - let mut cmd_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; - let mut rsp_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; - - // Set Migration wait for request command buffer - let mut cmd = VmcallServiceCommand::new( - cmd_mem.as_mut_bytes(), - VMCALL_SERVICE_MIGTD_GUID, - ) - .ok_or(MigrationResult::InvalidParameter)?; - let wfr = ServiceMigWaitForReqCommand { - version: 0, - command: MIG_COMMAND_WAIT, - reserved: [0; 2], - }; - cmd.write(wfr.as_bytes())?; - let _ = VmcallServiceResponse::new( - rsp_mem.as_mut_bytes(), - VMCALL_SERVICE_MIGTD_GUID, - ) + if wfr.operation == 1 { + let mig_info = + read_mig_info(&private_mem[24 + size_of::()..]) .ok_or(MigrationResult::InvalidParameter)?; + let request_id = mig_info.mig_info.mig_request_id; - #[cfg(not(feature = "vmcall-interrupt"))] - tdx::tdvmcall_service(cmd_mem.as_bytes(), rsp_mem.as_mut_bytes(), 0, 0)?; - - #[cfg(feature = "vmcall-interrupt")] - { - tdx::tdvmcall_service( - cmd_mem.as_bytes(), - rsp_mem.as_mut_bytes(), - event::VMCALL_SERVICE_VECTOR as u64, - 0, - )?; - } - - context.rsp_mem = Some(rsp_mem); - context.pending_response = true; - } - - #[cfg(feature = "vmcall-interrupt")] - if VMCALL_SERVICE_FLAG.load(Ordering::SeqCst) { - VMCALL_SERVICE_FLAG.store(false, Ordering::SeqCst); - } else { - return Ok(None); - } - - let private_mem = context - .rsp_mem - .as_mut() - .ok_or(MigrationResult::InvalidParameter)? - .copy_to_private_shadow(); - - // Parse out the response data - let rsp = VmcallServiceResponse::try_read(private_mem) - .ok_or(MigrationResult::InvalidParameter)?; - // Check the GUID of the reponse - if rsp.read_guid() != VMCALL_SERVICE_MIGTD_GUID.as_bytes() { - return Err(MigrationResult::InvalidParameter); - } - let wfr = rsp - .read_data::(0) - .ok_or(MigrationResult::InvalidParameter)?; - if wfr.command != MIG_COMMAND_WAIT { - return Err(MigrationResult::InvalidParameter); - } - if wfr.operation == 1 { - let mig_info = Self::read_mig_info( - &private_mem[24 + size_of::()..], - ) - .ok_or(MigrationResult::InvalidParameter)?; - let request_id = mig_info.mig_info.mig_request_id; - - if REQUESTS.lock().contains(&request_id) { - Ok(None) - } else { - self.state = MigrationState::Operate(MigrationOperation::Migrate(mig_info)); - REQUESTS.lock().insert(request_id); - Ok(Some(request_id)) - } - } else if wfr.operation == 0 { - Ok(None) - } else { - Err(MigrationResult::InvalidParameter) - } + if REQUESTS.lock().contains(&request_id) { + Poll::Pending + } else { + REQUESTS.lock().insert(request_id); + Poll::Ready(Ok(mig_info)) } - _ => Err(MigrationResult::InvalidParameter), + } else if wfr.operation == 0 { + Poll::Pending + } else { + Poll::Ready(Err(MigrationResult::InvalidParameter)) } - } + }) + .await +} - pub fn info(&self) -> Option<&MigrationInformation> { - match &self.state { - MigrationState::Operate(operation) => match operation { - MigrationOperation::Migrate(info) => Some(info), - }, - _ => None, - } - } +pub fn shutdown() -> Result<()> { + // Allocate shared page for command and response buffer + let mut cmd_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + let mut rsp_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; - #[cfg(feature = "main")] - pub async fn op(&mut self) -> Result<()> { - match &self.state { - MigrationState::Operate(operation) => match operation { - MigrationOperation::Migrate(info) => { - let state = Self::migrate(info).await; - self.state = MigrationState::Complete(RequestInformation { - request_id: info.mig_info.mig_request_id, - operation: 1, - }); - - state - } - }, - _ => Err(MigrationResult::InvalidParameter), - } - } + // Set Command + let mut cmd = VmcallServiceCommand::new(cmd_mem.as_mut_bytes(), VMCALL_SERVICE_MIGTD_GUID) + .ok_or(MigrationResult::InvalidParameter)?; - pub fn shutdown() -> Result<()> { - // Allocate shared page for command and response buffer - let mut cmd_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; - let mut rsp_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + let sd = ServiceMigWaitForReqShutdown { + version: 0, + command: MIG_COMMAND_SHUT_DOWN, + reserved: [0; 2], + }; + cmd.write(sd.as_bytes())?; + tdx::tdvmcall_service(cmd_mem.as_bytes(), rsp_mem.as_mut_bytes(), 0, 0)?; + Ok(()) +} - // Set Command - let mut cmd = VmcallServiceCommand::new(cmd_mem.as_mut_bytes(), VMCALL_SERVICE_MIGTD_GUID) - .ok_or(MigrationResult::InvalidParameter)?; +pub fn report_status(status: u8, request_id: u64) -> Result<()> { + // Allocate shared page for command and response buffer + let mut cmd_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + let mut rsp_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; + + // Set Command + let mut cmd = VmcallServiceCommand::new(cmd_mem.as_mut_bytes(), VMCALL_SERVICE_MIGTD_GUID) + .ok_or(MigrationResult::InvalidParameter)?; + + let rs = ServiceMigReportStatusCommand { + version: 0, + command: MIG_COMMAND_REPORT_STATUS, + operation: 1, + status, + mig_request_id: request_id, + }; - let sd = ServiceMigWaitForReqShutdown { - version: 0, - command: MIG_COMMAND_SHUT_DOWN, - reserved: [0; 2], - }; - cmd.write(sd.as_bytes())?; - tdx::tdvmcall_service(cmd_mem.as_bytes(), rsp_mem.as_mut_bytes(), 0, 0)?; - Ok(()) + cmd.write(rs.as_bytes())?; + + let _ = VmcallServiceResponse::new(rsp_mem.as_mut_bytes(), VMCALL_SERVICE_MIGTD_GUID) + .ok_or(MigrationResult::InvalidParameter)?; + + tdx::tdvmcall_service(cmd_mem.as_bytes(), rsp_mem.as_mut_bytes(), 0, 0)?; + + let private_mem = rsp_mem.copy_to_private_shadow(); + + // Parse the response data + // Check the GUID of the reponse + let rsp = + VmcallServiceResponse::try_read(private_mem).ok_or(MigrationResult::InvalidParameter)?; + if rsp.read_guid() != VMCALL_SERVICE_MIGTD_GUID.as_bytes() { + return Err(MigrationResult::InvalidParameter); } + let query = rsp + .read_data::(0) + .ok_or(MigrationResult::InvalidParameter)?; - pub fn report_status(&self, status: u8) -> Result<()> { - let request = match &self.state { - MigrationState::Complete(request) => *request, - _ => return Err(MigrationResult::InvalidParameter), - }; + // Ensure the response matches the command + if query.command != MIG_COMMAND_REPORT_STATUS { + return Err(MigrationResult::InvalidParameter); + } + Ok(()) +} - // Allocate shared page for command and response buffer - let mut cmd_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; - let mut rsp_mem = SharedMemory::new(1).ok_or(MigrationResult::OutOfResource)?; +#[cfg(feature = "main")] +pub async fn exchange_msk(info: &MigrationInformation) -> Result<()> { + use crate::driver::ticks::with_timeout; + use core::time::Duration; - // Set Command - let mut cmd = VmcallServiceCommand::new(cmd_mem.as_mut_bytes(), VMCALL_SERVICE_MIGTD_GUID) - .ok_or(MigrationResult::InvalidParameter)?; + const TLS_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds - let rs = ServiceMigReportStatusCommand { - version: 0, - command: MIG_COMMAND_REPORT_STATUS, - operation: request.operation, - status, - mig_request_id: request.request_id, - }; + let transport; + #[cfg(feature = "virtio-serial")] + { + use virtio_serial::VirtioSerialPort; + const VIRTIO_SERIAL_PORT_ID: u32 = 1; - cmd.write(rs.as_bytes())?; + let port = VirtioSerialPort::new(VIRTIO_SERIAL_PORT_ID); + port.open()?; + transport = port; + }; - let _ = VmcallServiceResponse::new(rsp_mem.as_mut_bytes(), VMCALL_SERVICE_MIGTD_GUID) - .ok_or(MigrationResult::InvalidParameter)?; + #[cfg(not(feature = "virtio-serial"))] + { + use vsock::{stream::VsockStream, VsockAddr}; + // Establish the vsock connection with host + let mut vsock = VsockStream::new()?; + vsock + .connect(&VsockAddr::new( + info.mig_socket_info.mig_td_cid as u32, + info.mig_socket_info.mig_channel_port, + )) + .await?; + transport = vsock; + }; - tdx::tdvmcall_service(cmd_mem.as_bytes(), rsp_mem.as_mut_bytes(), 0, 0)?; + let mut remote_information = ExchangeInformation::default(); + let mut exchange_information = exchange_info(&info)?; - let private_mem = rsp_mem.copy_to_private_shadow(); + // Establish TLS layer connection and negotiate the MSK + if info.is_src() { + // TLS client + let mut ratls_client = + ratls::client(transport).map_err(|_| MigrationResult::SecureSessionError)?; - // Parse the response data - // Check the GUID of the reponse - let rsp = VmcallServiceResponse::try_read(private_mem) - .ok_or(MigrationResult::InvalidParameter)?; - if rsp.read_guid() != VMCALL_SERVICE_MIGTD_GUID.as_bytes() { - return Err(MigrationResult::InvalidParameter); + // MigTD-S send Migration Session Forward key to peer + with_timeout( + TLS_TIMEOUT, + ratls_client.write(exchange_information.as_bytes()), + ) + .await??; + let size = with_timeout( + TLS_TIMEOUT, + ratls_client.read(remote_information.as_bytes_mut()), + ) + .await??; + if size < size_of::() { + return Err(MigrationResult::NetworkError); } - let query = rsp - .read_data::(0) - .ok_or(MigrationResult::InvalidParameter)?; + } else { + // TLS server + let mut ratls_server = + ratls::server(transport).map_err(|_| MigrationResult::SecureSessionError)?; - // Ensure the response matches the command - if query.command != MIG_COMMAND_REPORT_STATUS { - return Err(MigrationResult::InvalidParameter); + with_timeout( + TLS_TIMEOUT, + ratls_server.write(exchange_information.as_bytes()), + ) + .await??; + let size = with_timeout( + TLS_TIMEOUT, + ratls_server.read(remote_information.as_bytes_mut()), + ) + .await??; + if size < size_of::() { + return Err(MigrationResult::NetworkError); } - Ok(()) } - #[cfg(feature = "main")] - async fn migrate(info: &MigrationInformation) -> Result<()> { - use crate::driver::ticks::with_timeout; - use core::time::Duration; - - const TLS_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds - - let transport; - #[cfg(feature = "virtio-serial")] - { - use virtio_serial::VirtioSerialPort; - const VIRTIO_SERIAL_PORT_ID: u32 = 1; - - let port = VirtioSerialPort::new(VIRTIO_SERIAL_PORT_ID); - port.open()?; - transport = port; - }; - - #[cfg(not(feature = "virtio-serial"))] - { - use vsock::{stream::VsockStream, VsockAddr}; - // Establish the vsock connection with host - let mut vsock = VsockStream::new()?; - vsock - .connect(&VsockAddr::new( - info.mig_socket_info.mig_td_cid as u32, - info.mig_socket_info.mig_channel_port, - )) - .await?; - transport = vsock; - }; - - let mut remote_information = ExchangeInformation::default(); - let mut exchange_information = MigrationSession::exchange_info(&info)?; - - // Establish TLS layer connection and negotiate the MSK - if info.is_src() { - // TLS client - let mut ratls_client = - ratls::client(transport).map_err(|_| MigrationResult::SecureSessionError)?; - - // MigTD-S send Migration Session Forward key to peer - with_timeout( - TLS_TIMEOUT, - ratls_client.write(exchange_information.as_bytes()), - ) - .await??; - let size = with_timeout( - TLS_TIMEOUT, - ratls_client.read(remote_information.as_bytes_mut()), - ) - .await??; - if size < size_of::() { - return Err(MigrationResult::NetworkError); - } - } else { - // TLS server - let mut ratls_server = - ratls::server(transport).map_err(|_| MigrationResult::SecureSessionError)?; - - with_timeout( - TLS_TIMEOUT, - ratls_server.write(exchange_information.as_bytes()), - ) - .await??; - let size = with_timeout( - TLS_TIMEOUT, - ratls_server.read(remote_information.as_bytes_mut()), - ) - .await??; - if size < size_of::() { - return Err(MigrationResult::NetworkError); - } - } + let mig_ver = cal_mig_version(info.is_src(), &exchange_information, &remote_information)?; + set_mig_version(info, mig_ver)?; + write_msk(&info.mig_info, &remote_information.key)?; - let mig_ver = cal_mig_version(info.is_src(), &exchange_information, &remote_information)?; - set_mig_version(info, mig_ver)?; - MigrationSession::write_msk(&info.mig_info, &remote_information.key)?; + log::info!("Set MSK and report status\n"); + exchange_information.key.clear(); + remote_information.key.clear(); - log::info!("Set MSK and report status\n"); - exchange_information.key.clear(); - remote_information.key.clear(); + Ok(()) +} - Ok(()) - } +fn exchange_info(info: &MigrationInformation) -> Result { + let mut exchange_info = ExchangeInformation::default(); + read_msk(&info.mig_info, &mut exchange_info.key)?; - fn exchange_info(info: &MigrationInformation) -> Result { - let mut exchange_info = ExchangeInformation::default(); - MigrationSession::read_msk(&info.mig_info, &mut exchange_info.key)?; + let (field_min, field_max) = if info.is_src() { + (GSM_FIELD_MIN_EXPORT_VERSION, GSM_FIELD_MAX_EXPORT_VERSION) + } else { + (GSM_FIELD_MIN_IMPORT_VERSION, GSM_FIELD_MAX_IMPORT_VERSION) + }; + let min_version = tdcall_sys_rd(field_min)?.1; + let max_version = tdcall_sys_rd(field_max)?.1; + if min_version > u16::MAX as u64 || max_version > u16::MAX as u64 { + return Err(MigrationResult::InvalidParameter); + } + exchange_info.min_ver = min_version as u16; + exchange_info.max_ver = max_version as u16; - let (field_min, field_max) = if info.is_src() { - (GSM_FIELD_MIN_EXPORT_VERSION, GSM_FIELD_MAX_EXPORT_VERSION) - } else { - (GSM_FIELD_MIN_IMPORT_VERSION, GSM_FIELD_MAX_IMPORT_VERSION) - }; - let min_version = tdcall_sys_rd(field_min)?.1; - let max_version = tdcall_sys_rd(field_max)?.1; - if min_version > u16::MAX as u64 || max_version > u16::MAX as u64 { - return Err(MigrationResult::InvalidParameter); - } - exchange_info.min_ver = min_version as u16; - exchange_info.max_ver = max_version as u16; + Ok(exchange_info) +} - Ok(exchange_info) +fn read_msk(mig_info: &MigtdMigrationInformation, msk: &mut MigrationSessionKey) -> Result<()> { + for idx in 0..msk.fields.len() { + let ret = tdx::tdcall_servtd_rd( + mig_info.binding_handle, + TDCS_FIELD_MIG_ENC_KEY + idx as u64, + &mig_info.target_td_uuid, + )?; + msk.fields[idx] = ret.content; } + Ok(()) +} - fn read_msk(mig_info: &MigtdMigrationInformation, msk: &mut MigrationSessionKey) -> Result<()> { - for idx in 0..msk.fields.len() { - let ret = tdx::tdcall_servtd_rd( - mig_info.binding_handle, - TDCS_FIELD_MIG_ENC_KEY + idx as u64, - &mig_info.target_td_uuid, - )?; - msk.fields[idx] = ret.content; - } - Ok(()) +fn write_msk(mig_info: &MigtdMigrationInformation, msk: &MigrationSessionKey) -> Result<()> { + for idx in 0..msk.fields.len() { + tdx::tdcall_servtd_wr( + mig_info.binding_handle, + TDCS_FIELD_MIG_DEC_KEY + idx as u64, + msk.fields[idx], + &mig_info.target_td_uuid, + ) + .map_err(|_| MigrationResult::TdxModuleError)?; } - fn write_msk(mig_info: &MigtdMigrationInformation, msk: &MigrationSessionKey) -> Result<()> { - for idx in 0..msk.fields.len() { - tdx::tdcall_servtd_wr( - mig_info.binding_handle, - TDCS_FIELD_MIG_DEC_KEY + idx as u64, - msk.fields[idx], - &mig_info.target_td_uuid, - ) - .map_err(|_| MigrationResult::TdxModuleError)?; - } + Ok(()) +} - Ok(()) - } +fn read_mig_info(hob: &[u8]) -> Option { + let mig_info_hob = + hob_lib::get_next_extension_guid_hob(hob, MIGRATION_INFORMATION_HOB_GUID.as_bytes())?; - fn read_mig_info(hob: &[u8]) -> Option { - let mig_info_hob = - hob_lib::get_next_extension_guid_hob(hob, MIGRATION_INFORMATION_HOB_GUID.as_bytes())?; - - let mig_info = hob_lib::get_guid_data(mig_info_hob)? - .pread::(0) - .ok()?; - - let mig_socket_hob = - hob_lib::get_next_extension_guid_hob(hob, STREAM_SOCKET_INFO_HOB_GUID.as_bytes())?; - - let mig_socket_info = hob_lib::get_guid_data(mig_socket_hob)? - .pread::(0) - .ok()?; - - // Migration Information is optional here - let mut mig_policy = None; - if let Some(policy_info_hob) = - hob_lib::get_next_extension_guid_hob(hob, MIGPOLICY_HOB_GUID.as_bytes()) - { - if let Some(policy_raw) = hob_lib::get_guid_data(policy_info_hob) { - let policy_header = policy_raw.pread::(0).ok()?; - let mut policy_data: Vec = Vec::new(); - let offset = size_of::(); - policy_data.extend_from_slice( - &policy_raw[offset..offset + policy_header.mig_policy_size as usize], - ); - mig_policy = Some(MigtdMigpolicy { - header: policy_header, - mig_policy: policy_data, - }); - } - } + let mig_info = hob_lib::get_guid_data(mig_info_hob)? + .pread::(0) + .ok()?; - let mig_info = MigrationInformation { - mig_info, - mig_socket_info, - mig_policy, - }; + let mig_socket_hob = + hob_lib::get_next_extension_guid_hob(hob, STREAM_SOCKET_INFO_HOB_GUID.as_bytes())?; - Some(mig_info) + let mig_socket_info = hob_lib::get_guid_data(mig_socket_hob)? + .pread::(0) + .ok()?; + + // Migration Information is optional here + let mut mig_policy = None; + if let Some(policy_info_hob) = + hob_lib::get_next_extension_guid_hob(hob, MIGPOLICY_HOB_GUID.as_bytes()) + { + if let Some(policy_raw) = hob_lib::get_guid_data(policy_info_hob) { + let policy_header = policy_raw.pread::(0).ok()?; + let mut policy_data: Vec = Vec::new(); + let offset = size_of::(); + policy_data.extend_from_slice( + &policy_raw[offset..offset + policy_header.mig_policy_size as usize], + ); + mig_policy = Some(MigtdMigpolicy { + header: policy_header, + mig_policy: policy_data, + }); + } } + + let mig_info = MigrationInformation { + mig_info, + mig_socket_info, + mig_policy, + }; + + Some(mig_info) } /// Used to read a TDX Module global-scope metadata field. diff --git a/tests/test-td-payload/src/testservice.rs b/tests/test-td-payload/src/testservice.rs index 9e541b47..c938d71d 100644 --- a/tests/test-td-payload/src/testservice.rs +++ b/tests/test-td-payload/src/testservice.rs @@ -10,7 +10,7 @@ use core::ffi::c_void; use td_payload::print; use test_td_payload::{TestCase, TestResult}; -use migtd::migration::session::MigrationSession; +use migtd::migration::session::query; use serde::{Deserialize, Serialize}; @@ -27,7 +27,7 @@ pub struct Tdservice { impl Tdservice { fn test_query(&mut self) -> TestResult { // Query the capability of VMM - if MigrationSession::query().is_err() { + if query().is_err() { print!("Migration is not supported by VMM"); return TestResult::Fail; }