diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index bea91bf428..cd3a0aef64 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -55,7 +55,15 @@ use crate::{ stats::{Stats, StatsCell}, stream_id::StreamType, streams::{SendOrder, Streams}, - tparams::{self, TransportParameters, TransportParametersHandler}, + tparams::{ + self, + TransportParameterId::{ + self, AckDelayExponent, ActiveConnectionIdLimit, DisableMigration, GreaseQuicBit, + InitialSourceConnectionId, MaxAckDelay, MaxDatagramFrameSize, MinAckDelay, + OriginalDestinationConnectionId, RetrySourceConnectionId, StatelessResetToken, + }, + TransportParameters, TransportParametersHandler, + }, tracking::{AckTracker, PacketNumberSpace, RecvdPackets}, version::{Version, WireVersion}, AppError, CloseReason, Error, Res, StreamId, @@ -373,10 +381,8 @@ impl Connection { let mut cid_manager = ConnectionIdManager::new(cid_generator, local_initial_source_cid.clone()); let mut tps = conn_params.create_transport_parameter(role, &mut cid_manager)?; - tps.local.set_bytes( - tparams::INITIAL_SOURCE_CONNECTION_ID, - local_initial_source_cid.to_vec(), - ); + tps.local + .set_bytes(InitialSourceConnectionId, local_initial_source_cid.to_vec()); let tphandler = Rc::new(RefCell::new(tps)); let crypto = Crypto::new( @@ -498,7 +504,7 @@ impl Connection { #[cfg(test)] pub fn set_local_tparam( &self, - tp: tparams::TransportParameterId, + tp: TransportParameterId, value: tparams::TransportParameter, ) -> Res<()> { if *self.state() == State::Init { @@ -525,8 +531,8 @@ impl Connection { qtrace!("[{self}] Retry CIDs: odcid={odcid} remote={remote_cid} retry={retry_cid}"); // We advertise "our" choices in transport parameters. let local_tps = &mut self.tps.borrow_mut().local; - local_tps.set_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID, odcid.to_vec()); - local_tps.set_bytes(tparams::RETRY_SOURCE_CONNECTION_ID, retry_cid.to_vec()); + local_tps.set_bytes(OriginalDestinationConnectionId, odcid.to_vec()); + local_tps.set_bytes(RetrySourceConnectionId, retry_cid.to_vec()); // ...and save their choices for later validation. self.remote_initial_source_cid = Some(remote_cid); @@ -536,7 +542,7 @@ impl Connection { self.tps .borrow() .local - .get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID) + .get_bytes(RetrySourceConnectionId) .is_some() } @@ -1407,10 +1413,10 @@ impl Connection { // This has to happen prior to processing the packet so that // the TLS handshake has all it needs. if !self.retry_sent() { - self.tps.borrow_mut().local.set_bytes( - tparams::ORIGINAL_DESTINATION_CONNECTION_ID, - packet.dcid().to_vec(), - ); + self.tps + .borrow_mut() + .local + .set_bytes(OriginalDestinationConnectionId, packet.dcid().to_vec()); } } (PacketType::VersionNegotiation, State::WaitInitial, Role::Client) => { @@ -1890,12 +1896,7 @@ impl Connection { if !matches!(self.state(), State::Confirmed) { return Err(Error::InvalidMigration); } - if self - .tps - .borrow() - .remote() - .get_empty(tparams::DISABLE_MIGRATION) - { + if self.tps.borrow().remote().get_empty(DisableMigration) { return Err(Error::InvalidMigration); } @@ -2113,9 +2114,9 @@ impl Connection { fn can_grease_quic_bit(&self) -> bool { let tph = self.tps.borrow(); if let Some(r) = &tph.remote { - r.get_empty(tparams::GREASE_QUIC_BIT) + r.get_empty(GreaseQuicBit) } else if let Some(r) = &tph.remote_0rtt { - r.get_empty(tparams::GREASE_QUIC_BIT) + r.get_empty(GreaseQuicBit) } else { false } @@ -2621,18 +2622,14 @@ impl Connection { .tps .borrow() .remote() - .get_integer(tparams::IDLE_TIMEOUT); + .get_integer(TransportParameterId::IdleTimeout); if peer_timeout > 0 { self.idle_timeout .set_peer_timeout(Duration::from_millis(peer_timeout)); } - self.quic_datagrams.set_remote_datagram_size( - self.tps - .borrow() - .remote() - .get_integer(tparams::MAX_DATAGRAM_FRAME_SIZE), - ); + self.quic_datagrams + .set_remote_datagram_size(self.tps.borrow().remote().get_integer(MaxDatagramFrameSize)); } #[must_use] @@ -2661,18 +2658,16 @@ impl Connection { return Err(Error::TransportParameterError); } - let reset_token = remote - .get_bytes(tparams::STATELESS_RESET_TOKEN) - .map_or_else( - || Ok(ConnectionIdEntry::random_srt()), - |token| <[u8; 16]>::try_from(token).map_err(|_| Error::TransportParameterError), - )?; + let reset_token = remote.get_bytes(StatelessResetToken).map_or_else( + || Ok(ConnectionIdEntry::random_srt()), + |token| <[u8; 16]>::try_from(token).map_err(|_| Error::TransportParameterError), + )?; let path = self.paths.primary().ok_or(Error::NoAvailablePath)?; path.borrow_mut().set_reset_token(reset_token); - let max_ad = Duration::from_millis(remote.get_integer(tparams::MAX_ACK_DELAY)); - let min_ad = if remote.has_value(tparams::MIN_ACK_DELAY) { - let min_ad = Duration::from_micros(remote.get_integer(tparams::MIN_ACK_DELAY)); + let max_ad = Duration::from_millis(remote.get_integer(MaxAckDelay)); + let min_ad = if remote.has_value(MinAckDelay) { + let min_ad = Duration::from_micros(remote.get_integer(MinAckDelay)); if min_ad > max_ad { return Err(Error::TransportParameterError); } @@ -2683,7 +2678,7 @@ impl Connection { path.borrow_mut() .set_ack_delay(max_ad, min_ad, self.conn_params.get_ack_ratio()); - let max_active_cids = remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT); + let max_active_cids = remote.get_integer(ActiveConnectionIdLimit); self.cid_manager.set_limit(max_active_cids); } self.set_initial_limits(); @@ -2695,7 +2690,7 @@ impl Connection { let tph = self.tps.borrow(); let remote_tps = tph.remote.as_ref().ok_or(Error::TransportParameterError)?; - let tp = remote_tps.get_bytes(tparams::INITIAL_SOURCE_CONNECTION_ID); + let tp = remote_tps.get_bytes(InitialSourceConnectionId); if self .remote_initial_source_cid .as_ref() @@ -2711,7 +2706,7 @@ impl Connection { } if self.role == Role::Client { - let tp = remote_tps.get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID); + let tp = remote_tps.get_bytes(OriginalDestinationConnectionId); if self .original_destination_cid .as_ref() @@ -2726,7 +2721,7 @@ impl Connection { return Err(Error::ProtocolViolation); } - let tp = remote_tps.get_bytes(tparams::RETRY_SOURCE_CONNECTION_ID); + let tp = remote_tps.get_bytes(RetrySourceConnectionId); let expected = if let AddressValidationInfo::Retry { retry_source_cid, .. } = &self.address_validation @@ -3120,7 +3115,7 @@ impl Connection { self.tps.borrow().remote.as_ref().map_or_else( || Ok(Duration::default()), |r| { - let exponent = u32::try_from(r.get_integer(tparams::ACK_DELAY_EXPONENT))?; + let exponent = u32::try_from(r.get_integer(AckDelayExponent))?; // ACK_DELAY_EXPONENT > 20 is invalid per RFC9000. We already checked that in // TransportParameter::decode. let corrected = if v.leading_zeros() >= exponent { diff --git a/neqo-transport/src/connection/params.rs b/neqo-transport/src/connection/params.rs index 7adfadc11e..8578baa2cd 100644 --- a/neqo-transport/src/connection/params.rs +++ b/neqo-transport/src/connection/params.rs @@ -12,7 +12,16 @@ use crate::{ recv_stream::RECV_BUFFER_SIZE, rtt::GRANULARITY, stream_id::StreamType, - tparams::{self, PreferredAddress, TransportParameter, TransportParametersHandler}, + tparams::{ + PreferredAddress, TransportParameter, + TransportParameterId::{ + self, ActiveConnectionIdLimit, DisableMigration, GreaseQuicBit, InitialMaxData, + InitialMaxStreamDataBidiLocal, InitialMaxStreamDataBidiRemote, InitialMaxStreamDataUni, + InitialMaxStreamsBidi, InitialMaxStreamsUni, MaxAckDelay, MaxDatagramFrameSize, + MinAckDelay, + }, + TransportParametersHandler, + }, tracking::DEFAULT_ACK_DELAY, version::{Version, VersionConfig}, CongestionControlAlgorithm, Res, @@ -408,52 +417,45 @@ impl ConnectionParameters { let mut tps = TransportParametersHandler::new(role, self.versions.clone()); // default parameters tps.local.set_integer( - tparams::ACTIVE_CONNECTION_ID_LIMIT, + ActiveConnectionIdLimit, u64::try_from(LOCAL_ACTIVE_CID_LIMIT)?, ); if self.disable_migration { - tps.local.set_empty(tparams::DISABLE_MIGRATION); + tps.local.set_empty(DisableMigration); } if self.grease { - tps.local.set_empty(tparams::GREASE_QUIC_BIT); + tps.local.set_empty(GreaseQuicBit); } - tps.local.set_integer( - tparams::MAX_ACK_DELAY, - u64::try_from(DEFAULT_ACK_DELAY.as_millis())?, - ); - tps.local.set_integer( - tparams::MIN_ACK_DELAY, - u64::try_from(GRANULARITY.as_micros())?, - ); + tps.local + .set_integer(MaxAckDelay, u64::try_from(DEFAULT_ACK_DELAY.as_millis())?); + tps.local + .set_integer(MinAckDelay, u64::try_from(GRANULARITY.as_micros())?); // set configurable parameters - tps.local - .set_integer(tparams::INITIAL_MAX_DATA, self.max_data); + tps.local.set_integer(InitialMaxData, self.max_data); tps.local.set_integer( - tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL, + InitialMaxStreamDataBidiLocal, self.max_stream_data_bidi_local, ); tps.local.set_integer( - tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, + InitialMaxStreamDataBidiRemote, self.max_stream_data_bidi_remote, ); - tps.local.set_integer( - tparams::INITIAL_MAX_STREAM_DATA_UNI, - self.max_stream_data_uni, - ); tps.local - .set_integer(tparams::INITIAL_MAX_STREAMS_BIDI, self.max_streams_bidi); + .set_integer(InitialMaxStreamDataUni, self.max_stream_data_uni); + tps.local + .set_integer(InitialMaxStreamsBidi, self.max_streams_bidi); tps.local - .set_integer(tparams::INITIAL_MAX_STREAMS_UNI, self.max_streams_uni); + .set_integer(InitialMaxStreamsUni, self.max_streams_uni); tps.local.set_integer( - tparams::IDLE_TIMEOUT, + TransportParameterId::IdleTimeout, u64::try_from(self.idle_timeout.as_millis()).unwrap_or(0), ); if let PreferredAddressConfig::Address(preferred) = &self.preferred_address { if role == Role::Server { let (cid, srt) = cid_manager.preferred_address_cid()?; tps.local.set( - tparams::PREFERRED_ADDRESS, + TransportParameterId::PreferredAddress, TransportParameter::PreferredAddress { v4: preferred.ipv4(), v6: preferred.ipv6(), @@ -464,7 +466,7 @@ impl ConnectionParameters { } } tps.local - .set_integer(tparams::MAX_DATAGRAM_FRAME_SIZE, self.datagram_size); + .set_integer(MaxDatagramFrameSize, self.datagram_size); Ok(tps) } } diff --git a/neqo-transport/src/connection/tests/close.rs b/neqo-transport/src/connection/tests/close.rs index 4f7ef59171..98f19e7537 100644 --- a/neqo-transport/src/connection/tests/close.rs +++ b/neqo-transport/src/connection/tests/close.rs @@ -13,7 +13,7 @@ use super::{ connect, connect_force_idle, default_client, default_server, send_something, }; use crate::{ - tparams::{self, TransportParameter}, + tparams::{TransportParameter, TransportParameterId::StatelessResetToken}, AppError, CloseReason, Error, ERROR_APPLICATION_CLOSE, }; @@ -213,10 +213,7 @@ fn stateless_reset_client() { let mut client = default_client(); let mut server = default_server(); server - .set_local_tparam( - tparams::STATELESS_RESET_TOKEN, - TransportParameter::Bytes(vec![77; 16]), - ) + .set_local_tparam(StatelessResetToken, TransportParameter::Bytes(vec![77; 16])) .unwrap(); connect_force_idle(&mut client, &mut server); diff --git a/neqo-transport/src/connection/tests/handshake.rs b/neqo-transport/src/connection/tests/handshake.rs index 9e9a40edad..fa1620f968 100644 --- a/neqo-transport/src/connection/tests/handshake.rs +++ b/neqo-transport/src/connection/tests/handshake.rs @@ -35,7 +35,7 @@ use crate::{ events::ConnectionEvent, server::ValidateAddress, stats::FrameStats, - tparams::{self, TransportParameter, MIN_ACK_DELAY}, + tparams::{TransportParameter, TransportParameterId::*}, tracking::DEFAULT_ACK_DELAY, CloseReason, ConnectionParameters, Error, Pmtud, StreamType, Version, }; @@ -840,7 +840,9 @@ fn anti_amplification() { // With a gigantic transport parameter, the server is unable to complete // the handshake within the amplification limit. let very_big = TransportParameter::Bytes(vec![0; Pmtud::default_plpmtu(DEFAULT_ADDR.ip()) * 3]); - server.set_local_tparam(0xce16, very_big).unwrap(); + server + .set_local_tparam(TestTransportParameter, very_big) + .unwrap(); let c_init = client.process_output(now).dgram(); now += DEFAULT_RTT / 2; @@ -1089,7 +1091,7 @@ fn bad_min_ack_delay() { let mut server = default_server(); let max_ad = u64::try_from(DEFAULT_ACK_DELAY.as_micros()).unwrap(); server - .set_local_tparam(MIN_ACK_DELAY, TransportParameter::Integer(max_ad + 1)) + .set_local_tparam(MinAckDelay, TransportParameter::Integer(max_ad + 1)) .unwrap(); let mut client = default_client(); @@ -1371,7 +1373,7 @@ fn grease_quic_bit_transport_parameter() { .remote .as_ref() .unwrap() - .get_empty(tparams::GREASE_QUIC_BIT) + .get_empty(GreaseQuicBit) } for client_grease in [true, false] { diff --git a/neqo-transport/src/connection/tests/idle.rs b/neqo-transport/src/connection/tests/idle.rs index 8e7568304d..bf11bc6401 100644 --- a/neqo-transport/src/connection/tests/idle.rs +++ b/neqo-transport/src/connection/tests/idle.rs @@ -19,7 +19,7 @@ use crate::{ packet::PacketBuilder, stats::FrameStats, stream_id::{StreamId, StreamType}, - tparams::{self, TransportParameter}, + tparams::{TransportParameter, TransportParameterId}, tracking::PacketNumberSpace, }; @@ -97,7 +97,7 @@ fn asymmetric_idle_timeout() { .tps .borrow_mut() .local - .set_integer(tparams::IDLE_TIMEOUT, LOWER_TIMEOUT_MS); + .set_integer(TransportParameterId::IdleTimeout, LOWER_TIMEOUT_MS); server.idle_timeout = IdleTimeout::new(LOWER_TIMEOUT); // Now connect and force idleness manually. @@ -135,7 +135,7 @@ fn tiny_idle_timeout() { // Overwrite the default at the server. server .set_local_tparam( - tparams::IDLE_TIMEOUT, + TransportParameterId::IdleTimeout, TransportParameter::Integer(LOWER_TIMEOUT_MS), ) .unwrap(); diff --git a/neqo-transport/src/connection/tests/migration.rs b/neqo-transport/src/connection/tests/migration.rs index 8c01c037e6..7fa12eb840 100644 --- a/neqo-transport/src/connection/tests/migration.rs +++ b/neqo-transport/src/connection/tests/migration.rs @@ -33,7 +33,7 @@ use crate::{ path::MAX_PATH_PROBES, pmtud::Pmtud, stats::FrameStats, - tparams::{self, PreferredAddress, TransportParameter}, + tparams::{PreferredAddress, TransportParameter, TransportParameterId}, CloseReason, ConnectionId, ConnectionIdDecoder as _, ConnectionIdGenerator, ConnectionIdRef, ConnectionParameters, EmptyConnectionIdGenerator, Error, MIN_INITIAL_PACKET_SIZE, }; @@ -904,7 +904,7 @@ fn preferred_address_server_empty_cid() { server .set_local_tparam( - tparams::PREFERRED_ADDRESS, + TransportParameterId::PreferredAddress, TransportParameter::Bytes(SAMPLE_PREFERRED_ADDRESS.to_vec()), ) .unwrap(); @@ -925,7 +925,7 @@ fn preferred_address_client() { client .set_local_tparam( - tparams::PREFERRED_ADDRESS, + TransportParameterId::PreferredAddress, TransportParameter::Bytes(SAMPLE_PREFERRED_ADDRESS.to_vec()), ) .unwrap(); diff --git a/neqo-transport/src/connection/tests/mod.rs b/neqo-transport/src/connection/tests/mod.rs index 79369e885b..c249909fa3 100644 --- a/neqo-transport/src/connection/tests/mod.rs +++ b/neqo-transport/src/connection/tests/mod.rs @@ -29,7 +29,7 @@ use crate::{ pmtud::Pmtud, recovery::ACK_ONLY_SIZE_LIMIT, stats::{FrameStats, Stats, MAX_PTO_COUNTS}, - tparams::{DISABLE_MIGRATION, GREASE_QUIC_BIT}, + tparams::TransportParameterId::*, ConnectionIdDecoder, ConnectionIdGenerator, ConnectionParameters, EmptyConnectionIdGenerator, Error, StreamId, StreamType, Version, MIN_INITIAL_PACKET_SIZE, }; @@ -729,7 +729,7 @@ fn create_server() { fn tp_grease() { for enable in [true, false] { let client = new_client(ConnectionParameters::default().grease(enable)); - let grease = client.tps.borrow_mut().local.get_empty(GREASE_QUIC_BIT); + let grease = client.tps.borrow_mut().local.get_empty(GreaseQuicBit); assert_eq!(enable, grease); } } @@ -738,7 +738,7 @@ fn tp_grease() { fn tp_disable_migration() { for disable in [true, false] { let client = new_client(ConnectionParameters::default().disable_migration(disable)); - let disable_migration = client.tps.borrow_mut().local.get_empty(DISABLE_MIGRATION); + let disable_migration = client.tps.borrow_mut().local.get_empty(DisableMigration); assert_eq!(disable, disable_migration); } } diff --git a/neqo-transport/src/connection/tests/recovery.rs b/neqo-transport/src/connection/tests/recovery.rs index 8b28712bb2..c1558fce1d 100644 --- a/neqo-transport/src/connection/tests/recovery.rs +++ b/neqo-transport/src/connection/tests/recovery.rs @@ -29,7 +29,7 @@ use crate::{ }, rtt::GRANULARITY, stats::MAX_PTO_COUNTS, - tparams::TransportParameter, + tparams::{TransportParameter, TransportParameterId::*}, tracking::DEFAULT_ACK_DELAY, CloseReason, Error, Pmtud, StreamType, }; @@ -404,7 +404,9 @@ fn handshake_ack_pto() { // This is a greasing transport parameter, and large enough that the // server needs to send two Handshake packets. let big = TransportParameter::Bytes(vec![0; Pmtud::default_plpmtu(DEFAULT_ADDR.ip())]); - server.set_local_tparam(0xce16, big).unwrap(); + server + .set_local_tparam(TestTransportParameter, big) + .unwrap(); let c1 = client.process_output(now).dgram(); diff --git a/neqo-transport/src/connection/tests/stream.rs b/neqo-transport/src/connection/tests/stream.rs index 2921603471..e775cbb8ec 100644 --- a/neqo-transport/src/connection/tests/stream.rs +++ b/neqo-transport/src/connection/tests/stream.rs @@ -24,7 +24,7 @@ use crate::{ recv_stream::RECV_BUFFER_SIZE, send_stream::{OrderGroup, SendStreamState, SEND_BUFFER_SIZE}, streams::{SendOrder, StreamOrder}, - tparams::{self, TransportParameter}, + tparams::{TransportParameter, TransportParameterId::*}, CloseReason, Connection, ConnectionParameters, Error, StreamId, StreamType, }; @@ -397,7 +397,7 @@ fn max_data() { server .set_local_tparam( - tparams::INITIAL_MAX_DATA, + InitialMaxData, TransportParameter::Integer(u64::try_from(SMALL_MAX_DATA).unwrap()), ) .unwrap(); diff --git a/neqo-transport/src/connection/tests/vn.rs b/neqo-transport/src/connection/tests/vn.rs index 7d3e6dc6d7..753baa4b54 100644 --- a/neqo-transport/src/connection/tests/vn.rs +++ b/neqo-transport/src/connection/tests/vn.rs @@ -16,7 +16,7 @@ use super::{ }; use crate::{ packet::PACKET_BIT_LONG, - tparams::{self, TransportParameter}, + tparams::{TransportParameter, TransportParameterId::*}, ConnectionParameters, Error, Version, MIN_INITIAL_PACKET_SIZE, }; @@ -240,7 +240,7 @@ fn compatible_upgrade_large_initial() { let mut client = new_client(params.clone()); client .set_local_tparam( - 0x0845_de37_00ac_a5f9, + TestTransportParameter, TransportParameter::Bytes(vec![0; 2048]), ) .unwrap(); @@ -368,7 +368,7 @@ fn invalid_current_version_client() { assert_ne!(OTHER_VERSION, client.version()); client .set_local_tparam( - tparams::VERSION_INFORMATION, + VersionInformation, TransportParameter::Versions { current: OTHER_VERSION.wire_version(), other: Version::all() @@ -404,7 +404,7 @@ fn invalid_current_version_server() { assert!(!Version::default().is_compatible(OTHER_VERSION)); server .set_local_tparam( - tparams::VERSION_INFORMATION, + VersionInformation, TransportParameter::Versions { current: OTHER_VERSION.wire_version(), other: vec![OTHER_VERSION.wire_version()], @@ -430,7 +430,7 @@ fn no_compatible_version() { assert_ne!(OTHER_VERSION, client.version()); client .set_local_tparam( - tparams::VERSION_INFORMATION, + VersionInformation, TransportParameter::Versions { current: Version::default().wire_version(), other: vec![OTHER_VERSION.wire_version()], diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index 44981462e9..361b3b7104 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -9,13 +9,16 @@ use std::{ cell::RefCell, cmp::{max, min}, - collections::HashMap, mem, ops::{Index, IndexMut, Range}, rc::Rc, time::Instant, }; +#[cfg(not(feature = "disable-encryption"))] +#[cfg(test)] +use enum_map::enum_map; +use enum_map::EnumMap; use neqo_common::{hex, hex_snip_middle, qdebug, qinfo, qtrace, Encoder, Role}; pub use neqo_crypto::Epoch; use neqo_crypto::{ @@ -793,7 +796,7 @@ impl CryptoDxAppData { /// get other keys, so those have fixed versions. #[derive(Debug, Default)] pub struct CryptoStates { - initials: HashMap, + initials: EnumMap>, handshake: Option, zero_rtt: Option, // One direction only! cipher: Cipher, @@ -806,6 +809,10 @@ pub struct CryptoStates { } impl CryptoStates { + fn initials_is_empty(&self) -> bool { + self.initials.values().flatten().count() == 0 + } + /// Select a `CryptoDxState` and `CryptoSpace` for the given `PacketNumberSpace`. /// This selects 0-RTT keys for `PacketNumberSpace::ApplicationData` if 1-RTT keys are /// not yet available. @@ -838,7 +845,7 @@ impl CryptoStates { ) -> Option<&'a mut CryptoDxState> { let tx = |k: Option<&'a mut CryptoState>| k.map(|dx| &mut dx.tx); match epoch { - Epoch::Initial => tx(self.initials.get_mut(&version)), + Epoch::Initial => tx(self.initials[version].as_mut()), Epoch::ZeroRtt => self .zero_rtt .as_mut() @@ -851,7 +858,7 @@ impl CryptoStates { pub fn tx<'a>(&'a self, version: Version, epoch: Epoch) -> Option<&'a CryptoDxState> { let tx = |k: Option<&'a CryptoState>| k.map(|dx| &dx.tx); match epoch { - Epoch::Initial => tx(self.initials.get(&version)), + Epoch::Initial => tx(self.initials[version].as_ref()), Epoch::ZeroRtt => self .zero_rtt .as_ref() @@ -896,7 +903,7 @@ impl CryptoStates { ) -> Option<&'a mut CryptoDxState> { let rx = |x: Option<&'a mut CryptoState>| x.map(|dx| &mut dx.rx); match epoch { - Epoch::Initial => rx(self.initials.get_mut(&version)), + Epoch::Initial => rx(self.initials[version].as_mut()), Epoch::ZeroRtt => self .zero_rtt .as_mut() @@ -925,7 +932,7 @@ impl CryptoStates { pub fn rx_pending(&self, space: Epoch) -> bool { match space { Epoch::Initial | Epoch::ZeroRtt => false, - Epoch::Handshake => self.handshake.is_none() && !self.initials.is_empty(), + Epoch::Handshake => self.handshake.is_none() && !self.initials_is_empty(), Epoch::ApplicationData => self.app_read.is_none(), } } @@ -954,14 +961,14 @@ impl CryptoStates { tx: CryptoDxState::new_initial(*v, CryptoDxDirection::Write, write, dcid)?, rx: CryptoDxState::new_initial(*v, CryptoDxDirection::Read, read, dcid)?, }; - if let Some(prev) = self.initials.get(v) { + if let Some(prev) = &self.initials[*v] { qinfo!( "[{self}] Continue packet numbers for initial after retry (write is {:?})", prev.rx.used_pn, ); initial.tx.continuation(&prev.tx)?; } - self.initials.insert(*v, initial); + self.initials[*v] = Some(initial); } Ok(()) } @@ -973,7 +980,7 @@ impl CryptoStates { /// not need the send keys if the packet is subsequently discarded, but /// the overall effort is small enough to write off. pub fn init_server(&mut self, version: Version, dcid: &[u8]) -> Res<()> { - if !self.initials.contains_key(&version) { + if self.initials[version].is_none() { self.init(&[version], Role::Server, dcid)?; } Ok(()) @@ -985,13 +992,12 @@ impl CryptoStates { // appease the borrow checker. // Note that on the server, we might not have initials for |orig| if it // was configured for |orig| and only |confirmed| Initial packets arrived. - if let Some(prev) = self.initials.remove(&orig) { - let next = self - .initials - .get_mut(&confirmed) + if let Some(prev) = self.initials[orig].take() { + let next = self.initials[confirmed] + .as_mut() .ok_or(Error::VersionNegotiation)?; next.tx.continuation(&prev.tx)?; - self.initials.insert(orig, prev); + self.initials[orig] = Some(prev); } } Ok(()) @@ -1019,7 +1025,7 @@ impl CryptoStates { pub fn discard(&mut self, space: PacketNumberSpace) -> bool { match space { PacketNumberSpace::Initial => { - let empty = self.initials.is_empty(); + let empty = self.initials_is_empty(); self.initials.clear(); !empty } @@ -1243,14 +1249,14 @@ impl CryptoStates { cipher: TLS_AES_128_GCM_SHA256, next_secret: hkdf::import_key(TLS_VERSION_1_3, &[0xaa; 32]).unwrap(), }; - let mut initials = HashMap::new(); - initials.insert( - Version::Version1, - CryptoState { - tx: CryptoDxState::test_default(), - rx: read(0), - }, - ); + let initials = enum_map! { + Version::Version1 => Some(CryptoState { + tx: CryptoDxState::test_default(), + rx: read(0), + }), + Version::Version2 => None, + Version::Draft29 => None, + }; Self { initials, handshake: None, @@ -1301,7 +1307,7 @@ impl CryptoStates { next_secret: secret.clone(), }; Self { - initials: HashMap::new(), + initials: EnumMap::default(), handshake: None, zero_rtt: None, cipher: TLS_CHACHA20_POLY1305_SHA256, diff --git a/neqo-transport/src/lib.rs b/neqo-transport/src/lib.rs index e9c8ac2dee..a5f3e8ac7f 100644 --- a/neqo-transport/src/lib.rs +++ b/neqo-transport/src/lib.rs @@ -141,6 +141,7 @@ pub enum Error { UnknownFrameType, VersionNegotiation, WrongRole, + UnknownTransportParameter, } impl Error { diff --git a/neqo-transport/src/qlog.rs b/neqo-transport/src/qlog.rs index 6ed4b510e0..273b3bb8a6 100644 --- a/neqo-transport/src/qlog.rs +++ b/neqo-transport/src/qlog.rs @@ -31,53 +31,70 @@ use crate::{ path::PathRef, recovery::SentPacket, stream_id::StreamType as NeqoStreamType, - tparams::{self, TransportParametersHandler}, + tparams::{ + TransportParameterId::{ + self, AckDelayExponent, ActiveConnectionIdLimit, DisableMigration, InitialMaxData, + InitialMaxStreamDataBidiLocal, InitialMaxStreamDataBidiRemote, InitialMaxStreamDataUni, + InitialMaxStreamsBidi, InitialMaxStreamsUni, MaxAckDelay, MaxUdpPayloadSize, + OriginalDestinationConnectionId, StatelessResetToken, + }, + TransportParametersHandler, + }, version::{Version, VersionConfig, WireVersion}, }; pub fn connection_tparams_set(qlog: &NeqoQlog, tph: &TransportParametersHandler, now: Instant) { - qlog.add_event_data_with_instant(|| { - let remote = tph.remote(); - #[allow(clippy::cast_possible_truncation)] // Nope. - let ev_data = EventData::TransportParametersSet( - qlog::events::quic::TransportParametersSet { - owner: None, - resumption_allowed: None, - early_data_enabled: None, - tls_cipher: None, - aead_tag_length: None, - original_destination_connection_id: remote - .get_bytes(tparams::ORIGINAL_DESTINATION_CONNECTION_ID) - .map(hex), - initial_source_connection_id: None, - retry_source_connection_id: None, - stateless_reset_token: remote.get_bytes(tparams::STATELESS_RESET_TOKEN).map(hex), - disable_active_migration: remote.get_empty(tparams::DISABLE_MIGRATION).then_some(true), - max_idle_timeout: Some(remote.get_integer(tparams::IDLE_TIMEOUT)), - max_udp_payload_size: Some(remote.get_integer(tparams::MAX_UDP_PAYLOAD_SIZE) as u32), - ack_delay_exponent: Some(remote.get_integer(tparams::ACK_DELAY_EXPONENT) as u16), - max_ack_delay: Some(remote.get_integer(tparams::MAX_ACK_DELAY) as u16), - active_connection_id_limit: Some(remote.get_integer(tparams::ACTIVE_CONNECTION_ID_LIMIT) as u32), - initial_max_data: Some(remote.get_integer(tparams::INITIAL_MAX_DATA)), - initial_max_stream_data_bidi_local: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL)), - initial_max_stream_data_bidi_remote: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE)), - initial_max_stream_data_uni: Some(remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI)), - initial_max_streams_bidi: Some(remote.get_integer(tparams::INITIAL_MAX_STREAMS_BIDI)), - initial_max_streams_uni: Some(remote.get_integer(tparams::INITIAL_MAX_STREAMS_UNI)), - preferred_address: remote.get_preferred_address().and_then(|(paddr, cid)| { - Some(qlog::events::quic::PreferredAddress { - ip_v4: paddr.ipv4()?.ip().to_string(), - ip_v6: paddr.ipv6()?.ip().to_string(), - port_v4: paddr.ipv4()?.port(), - port_v6: paddr.ipv6()?.port(), - connection_id: cid.connection_id().to_string(), - stateless_reset_token: hex(cid.reset_token()), - }) - }), - }); + qlog.add_event_data_with_instant( + || { + let remote = tph.remote(); + #[allow(clippy::cast_possible_truncation)] // Nope. + let ev_data = + EventData::TransportParametersSet(qlog::events::quic::TransportParametersSet { + owner: None, + resumption_allowed: None, + early_data_enabled: None, + tls_cipher: None, + aead_tag_length: None, + original_destination_connection_id: remote + .get_bytes(OriginalDestinationConnectionId) + .map(hex), + initial_source_connection_id: None, + retry_source_connection_id: None, + stateless_reset_token: remote.get_bytes(StatelessResetToken).map(hex), + disable_active_migration: remote.get_empty(DisableMigration).then_some(true), + max_idle_timeout: Some(remote.get_integer(TransportParameterId::IdleTimeout)), + max_udp_payload_size: Some(remote.get_integer(MaxUdpPayloadSize) as u32), + ack_delay_exponent: Some(remote.get_integer(AckDelayExponent) as u16), + max_ack_delay: Some(remote.get_integer(MaxAckDelay) as u16), + active_connection_id_limit: Some( + remote.get_integer(ActiveConnectionIdLimit) as u32 + ), + initial_max_data: Some(remote.get_integer(InitialMaxData)), + initial_max_stream_data_bidi_local: Some( + remote.get_integer(InitialMaxStreamDataBidiLocal), + ), + initial_max_stream_data_bidi_remote: Some( + remote.get_integer(InitialMaxStreamDataBidiRemote), + ), + initial_max_stream_data_uni: Some(remote.get_integer(InitialMaxStreamDataUni)), + initial_max_streams_bidi: Some(remote.get_integer(InitialMaxStreamsBidi)), + initial_max_streams_uni: Some(remote.get_integer(InitialMaxStreamsUni)), + preferred_address: remote.get_preferred_address().and_then(|(paddr, cid)| { + Some(qlog::events::quic::PreferredAddress { + ip_v4: paddr.ipv4()?.ip().to_string(), + ip_v6: paddr.ipv6()?.ip().to_string(), + port_v4: paddr.ipv4()?.port(), + port_v6: paddr.ipv6()?.port(), + connection_id: cid.connection_id().to_string(), + stateless_reset_token: hex(cid.reset_token()), + }) + }), + }); - Some(ev_data) - }, now); + Some(ev_data) + }, + now, + ); } pub fn server_connection_started(qlog: &NeqoQlog, path: &PathRef, now: Instant) { diff --git a/neqo-transport/src/send_stream.rs b/neqo-transport/src/send_stream.rs index fb62b3d047..142eeea3f4 100644 --- a/neqo-transport/src/send_stream.rs +++ b/neqo-transport/src/send_stream.rs @@ -32,7 +32,10 @@ use crate::{ stats::FrameStats, stream_id::StreamId, streams::SendOrder, - tparams::{self, TransportParameters}, + tparams::{ + TransportParameterId::{InitialMaxStreamDataBidiRemote, InitialMaxStreamDataUni}, + TransportParameters, + }, AppError, Error, Res, }; @@ -1761,9 +1764,9 @@ impl SendStreams { for (id, ss) in &mut self.map { let limit = if id.is_bidi() { assert!(!id.is_remote_initiated(Role::Client)); - remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE) + remote.get_integer(InitialMaxStreamDataBidiRemote) } else { - remote.get_integer(tparams::INITIAL_MAX_STREAM_DATA_UNI) + remote.get_integer(InitialMaxStreamDataUni) }; ss.set_max_stream_data(limit); } diff --git a/neqo-transport/src/streams.rs b/neqo-transport/src/streams.rs index a79195ac72..1b581bd895 100644 --- a/neqo-transport/src/streams.rs +++ b/neqo-transport/src/streams.rs @@ -18,7 +18,13 @@ use crate::{ send_stream::{SendStream, SendStreams, TransmissionPriority}, stats::FrameStats, stream_id::{StreamId, StreamType}, - tparams::{self, TransportParametersHandler}, + tparams::{ + TransportParameterId::{ + InitialMaxData, InitialMaxStreamDataBidiLocal, InitialMaxStreamDataBidiRemote, + InitialMaxStreamDataUni, InitialMaxStreamsBidi, InitialMaxStreamsUni, + }, + TransportParametersHandler, + }, ConnectionEvents, Error, Res, }; @@ -73,15 +79,9 @@ impl Streams { role: Role, events: ConnectionEvents, ) -> Self { - let limit_bidi = tps - .borrow() - .local - .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI); - let limit_uni = tps - .borrow() - .local - .get_integer(tparams::INITIAL_MAX_STREAMS_UNI); - let max_data = tps.borrow().local.get_integer(tparams::INITIAL_MAX_DATA); + let limit_bidi = tps.borrow().local.get_integer(InitialMaxStreamsBidi); + let limit_uni = tps.borrow().local.get_integer(InitialMaxStreamsUni); + let max_data = tps.borrow().local.get_integer(InitialMaxData); Self { role, tps, @@ -104,17 +104,11 @@ impl Streams { self.clear_streams(); debug_assert_eq!( self.remote_stream_limits[StreamType::BiDi].max_active(), - self.tps - .borrow() - .local - .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI) + self.tps.borrow().local.get_integer(InitialMaxStreamsBidi) ); debug_assert_eq!( self.remote_stream_limits[StreamType::UniDi].max_active(), - self.tps - .borrow() - .local - .get_integer(tparams::INITIAL_MAX_STREAMS_UNI) + self.tps.borrow().local.get_integer(InitialMaxStreamsUni) ); self.local_stream_limits = LocalStreamLimits::new(self.role); } @@ -356,8 +350,8 @@ impl Streams { // look at the local transport parameters for the // INITIAL_MAX_STREAM_DATA_BIDI_REMOTE value to decide how much this endpoint // will allow its peer to send. - StreamType::BiDi => tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, - StreamType::UniDi => tparams::INITIAL_MAX_STREAM_DATA_UNI, + StreamType::BiDi => InitialMaxStreamDataBidiRemote, + StreamType::UniDi => InitialMaxStreamDataUni, }; let recv_initial_max_stream_data = self.tps.borrow().local.get_integer(tp); @@ -386,7 +380,7 @@ impl Streams { .tps .borrow() .remote() - .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL); + .get_integer(InitialMaxStreamDataBidiLocal); self.send.insert( next_stream_id, SendStream::new( @@ -447,8 +441,8 @@ impl Streams { None => Err(Error::StreamLimitError), Some(new_id) => { let send_limit_tp = match st { - StreamType::UniDi => tparams::INITIAL_MAX_STREAM_DATA_UNI, - StreamType::BiDi => tparams::INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, + StreamType::UniDi => InitialMaxStreamDataUni, + StreamType::BiDi => InitialMaxStreamDataBidiRemote, }; let send_limit = self.tps.borrow().remote().get_integer(send_limit_tp); let stream = SendStream::new( @@ -469,7 +463,7 @@ impl Streams { .tps .borrow() .local - .get_integer(tparams::INITIAL_MAX_STREAM_DATA_BIDI_LOCAL); + .get_integer(InitialMaxStreamDataBidiLocal); self.recv.insert( new_id, @@ -506,14 +500,10 @@ impl Streams { self.tps .borrow() .remote() - .get_integer(tparams::INITIAL_MAX_STREAMS_BIDI), - ); - _ = self.local_stream_limits[StreamType::UniDi].update( - self.tps - .borrow() - .remote() - .get_integer(tparams::INITIAL_MAX_STREAMS_UNI), + .get_integer(InitialMaxStreamsBidi), ); + _ = self.local_stream_limits[StreamType::UniDi] + .update(self.tps.borrow().remote().get_integer(InitialMaxStreamsUni)); // As a client, there are two sets of initial limits for sending stream data. // If the second limit is higher and streams have been created, then @@ -522,12 +512,9 @@ impl Streams { self.send.update_initial_limit(self.tps.borrow().remote()); } - self.sender_fc.borrow_mut().update( - self.tps - .borrow() - .remote() - .get_integer(tparams::INITIAL_MAX_DATA), - ); + self.sender_fc + .borrow_mut() + .update(self.tps.borrow().remote().get_integer(InitialMaxData)); if self.local_stream_limits[StreamType::BiDi].available() > 0 { self.events.send_stream_creatable(StreamType::BiDi); diff --git a/neqo-transport/src/tparams.rs b/neqo-transport/src/tparams.rs index 57188ade84..956efe7d2d 100644 --- a/neqo-transport/src/tparams.rs +++ b/neqo-transport/src/tparams.rs @@ -8,17 +8,26 @@ use std::{ cell::RefCell, - collections::HashMap, + fmt::Display, net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}, rc::Rc, }; +use enum_map::{Enum, EnumMap}; use neqo_common::{hex, qdebug, qinfo, qtrace, Decoder, Encoder, Role}; use neqo_crypto::{ constants::{TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS}, ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult}, random, HandshakeMessage, ZeroRttCheckResult, ZeroRttChecker, }; +use TransportParameterId::{ + AckDelayExponent, ActiveConnectionIdLimit, DisableMigration, GreaseQuicBit, IdleTimeout, + InitialMaxData, InitialMaxStreamDataBidiLocal, InitialMaxStreamDataBidiRemote, + InitialMaxStreamDataUni, InitialMaxStreamsBidi, InitialMaxStreamsUni, + InitialSourceConnectionId, MaxAckDelay, MaxDatagramFrameSize, MaxUdpPayloadSize, MinAckDelay, + OriginalDestinationConnectionId, RetrySourceConnectionId, StatelessResetToken, + VersionInformation, +}; use crate::{ cid::{ConnectionId, ConnectionIdEntry, CONNECTION_ID_SEQNO_PREFERRED, MAX_CONNECTION_ID_LEN}, @@ -27,28 +36,78 @@ use crate::{ Error, Res, }; -pub type TransportParameterId = u64; -pub const ORIGINAL_DESTINATION_CONNECTION_ID: TransportParameterId = 0x00; -pub const IDLE_TIMEOUT: TransportParameterId = 0x01; -pub const STATELESS_RESET_TOKEN: TransportParameterId = 0x02; -pub const MAX_UDP_PAYLOAD_SIZE: TransportParameterId = 0x03; -pub const INITIAL_MAX_DATA: TransportParameterId = 0x04; -pub const INITIAL_MAX_STREAM_DATA_BIDI_LOCAL: TransportParameterId = 0x05; -pub const INITIAL_MAX_STREAM_DATA_BIDI_REMOTE: TransportParameterId = 0x06; -pub const INITIAL_MAX_STREAM_DATA_UNI: TransportParameterId = 0x07; -pub const INITIAL_MAX_STREAMS_BIDI: TransportParameterId = 0x08; -pub const INITIAL_MAX_STREAMS_UNI: TransportParameterId = 0x09; -pub const ACK_DELAY_EXPONENT: TransportParameterId = 0x0a; -pub const MAX_ACK_DELAY: TransportParameterId = 0x0b; -pub const DISABLE_MIGRATION: TransportParameterId = 0x0c; -pub const PREFERRED_ADDRESS: TransportParameterId = 0x0d; -pub const ACTIVE_CONNECTION_ID_LIMIT: TransportParameterId = 0x0e; -pub const INITIAL_SOURCE_CONNECTION_ID: TransportParameterId = 0x0f; -pub const RETRY_SOURCE_CONNECTION_ID: TransportParameterId = 0x10; -pub const VERSION_INFORMATION: TransportParameterId = 0x11; -pub const GREASE_QUIC_BIT: TransportParameterId = 0x2ab2; -pub const MIN_ACK_DELAY: TransportParameterId = 0xff02_de1a; -pub const MAX_DATAGRAM_FRAME_SIZE: TransportParameterId = 0x0020; +#[derive(Debug, Clone, Enum, PartialEq, Eq, Copy)] +#[repr(u64)] +pub enum TransportParameterId { + OriginalDestinationConnectionId = 0x00, + IdleTimeout = 0x01, + StatelessResetToken = 0x02, + MaxUdpPayloadSize = 0x03, + InitialMaxData = 0x04, + InitialMaxStreamDataBidiLocal = 0x05, + InitialMaxStreamDataBidiRemote = 0x06, + InitialMaxStreamDataUni = 0x07, + InitialMaxStreamsBidi = 0x08, + InitialMaxStreamsUni = 0x09, + AckDelayExponent = 0x0a, + MaxAckDelay = 0x0b, + DisableMigration = 0x0c, + PreferredAddress = 0x0d, + ActiveConnectionIdLimit = 0x0e, + InitialSourceConnectionId = 0x0f, + RetrySourceConnectionId = 0x10, + VersionInformation = 0x11, + GreaseQuicBit = 0x2ab2, + MinAckDelay = 0xff02_de1a, + MaxDatagramFrameSize = 0x0020, + #[cfg(test)] + TestTransportParameter = 0xce16, +} + +impl Display for TransportParameterId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + format!("{self:?}((0x{:02x}))", u64::from(*self)).fmt(f) + } +} + +impl From for u64 { + fn from(val: TransportParameterId) -> Self { + val as Self + } +} + +impl TryFrom for TransportParameterId { + type Error = Error; + + fn try_from(value: u64) -> Result { + match value { + 0x00 => Ok(Self::OriginalDestinationConnectionId), + 0x01 => Ok(Self::IdleTimeout), + 0x02 => Ok(Self::StatelessResetToken), + 0x03 => Ok(Self::MaxUdpPayloadSize), + 0x04 => Ok(Self::InitialMaxData), + 0x05 => Ok(Self::InitialMaxStreamDataBidiLocal), + 0x06 => Ok(Self::InitialMaxStreamDataBidiRemote), + 0x07 => Ok(Self::InitialMaxStreamDataUni), + 0x08 => Ok(Self::InitialMaxStreamsBidi), + 0x09 => Ok(Self::InitialMaxStreamsUni), + 0x0a => Ok(Self::AckDelayExponent), + 0x0b => Ok(Self::MaxAckDelay), + 0x0c => Ok(Self::DisableMigration), + 0x0d => Ok(Self::PreferredAddress), + 0x0e => Ok(Self::ActiveConnectionIdLimit), + 0x0f => Ok(Self::InitialSourceConnectionId), + 0x10 => Ok(Self::RetrySourceConnectionId), + 0x11 => Ok(Self::VersionInformation), + 0x2ab2 => Ok(Self::GreaseQuicBit), + 0xff02_de1a => Ok(Self::MinAckDelay), + 0x0020 => Ok(Self::MaxDatagramFrameSize), + #[cfg(test)] + 0xce16 => Ok(Self::TestTransportParameter), + _ => Err(Error::UnknownTransportParameter), + } + } +} #[derive(Clone, Debug)] pub struct PreferredAddress { @@ -128,7 +187,7 @@ pub enum TransportParameter { impl TransportParameter { fn encode(&self, enc: &mut Encoder, tp: TransportParameterId) { - qtrace!("TP encoded; type 0x{tp:02x} val {self:?}"); + qtrace!("TP encoded; type {tp}) val {self:?}"); enc.encode_varint(tp); match self { Self::Bytes(a) => { @@ -240,83 +299,82 @@ impl TransportParameter { let tp = dec.decode_varint().ok_or(Error::NoMoreData)?; let content = dec.decode_vvec().ok_or(Error::NoMoreData)?; qtrace!("TP {tp:x} length {:x}", content.len()); + let tp = match tp.try_into() { + Ok(tp) => tp, + Err(Error::UnknownTransportParameter) => return Ok(None), // Skip + Err(e) => return Err(e), + }; let mut d = Decoder::from(content); let value = match tp { - ORIGINAL_DESTINATION_CONNECTION_ID - | INITIAL_SOURCE_CONNECTION_ID - | RETRY_SOURCE_CONNECTION_ID => Self::Bytes(d.decode_remainder().to_vec()), - STATELESS_RESET_TOKEN => { + OriginalDestinationConnectionId + | InitialSourceConnectionId + | RetrySourceConnectionId => Self::Bytes(d.decode_remainder().to_vec()), + StatelessResetToken => { if d.remaining() != 16 { return Err(Error::TransportParameterError); } Self::Bytes(d.decode_remainder().to_vec()) } - IDLE_TIMEOUT - | INITIAL_MAX_DATA - | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL - | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE - | INITIAL_MAX_STREAM_DATA_UNI - | MAX_ACK_DELAY - | MAX_DATAGRAM_FRAME_SIZE => match d.decode_varint() { + IdleTimeout + | InitialMaxData + | InitialMaxStreamDataBidiLocal + | InitialMaxStreamDataBidiRemote + | InitialMaxStreamDataUni + | MaxAckDelay + | MaxDatagramFrameSize => match d.decode_varint() { Some(v) => Self::Integer(v), None => return Err(Error::TransportParameterError), }, - - INITIAL_MAX_STREAMS_BIDI | INITIAL_MAX_STREAMS_UNI => match d.decode_varint() { + InitialMaxStreamsBidi | InitialMaxStreamsUni => match d.decode_varint() { Some(v) if v <= (1 << 60) => Self::Integer(v), _ => return Err(Error::StreamLimitError), }, - - MAX_UDP_PAYLOAD_SIZE => match d.decode_varint() { + MaxUdpPayloadSize => match d.decode_varint() { Some(v) if v >= MIN_INITIAL_PACKET_SIZE.try_into()? => Self::Integer(v), _ => return Err(Error::TransportParameterError), }, - - ACK_DELAY_EXPONENT => match d.decode_varint() { + AckDelayExponent => match d.decode_varint() { Some(v) if v <= 20 => Self::Integer(v), _ => return Err(Error::TransportParameterError), }, - ACTIVE_CONNECTION_ID_LIMIT => match d.decode_varint() { + ActiveConnectionIdLimit => match d.decode_varint() { Some(v) if v >= 2 => Self::Integer(v), _ => return Err(Error::TransportParameterError), }, - - DISABLE_MIGRATION | GREASE_QUIC_BIT => Self::Empty, - - PREFERRED_ADDRESS => Self::decode_preferred_address(&mut d)?, - - MIN_ACK_DELAY => match d.decode_varint() { + DisableMigration | GreaseQuicBit => Self::Empty, + TransportParameterId::PreferredAddress => Self::decode_preferred_address(&mut d)?, + MinAckDelay => match d.decode_varint() { Some(v) if v < (1 << 24) => Self::Integer(v), _ => return Err(Error::TransportParameterError), }, - - VERSION_INFORMATION => Self::decode_versions(&mut d)?, - - // Skip. - _ => return Ok(None), + VersionInformation => Self::decode_versions(&mut d)?, + #[cfg(test)] + TransportParameterId::TestTransportParameter => { + Self::Bytes(d.decode_remainder().to_vec()) + } }; if d.remaining() > 0 { return Err(Error::TooMuchData); } - qtrace!("TP decoded; type 0x{tp:02x} val {value:?}"); + qtrace!("TP decoded; type {tp} val {value:?}"); Ok(Some((tp, value))) } } #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct TransportParameters { - params: HashMap, + params: EnumMap>, } impl TransportParameters { /// Set a value. pub fn set(&mut self, k: TransportParameterId, v: TransportParameter) { - self.params.insert(k, v); + self.params[k] = Some(v); } /// Clear a key. pub fn remove(&mut self, k: TransportParameterId) { - self.params.remove(&k); + self.params[k].take(); } /// Decode is a static function that parses transport parameters @@ -339,7 +397,9 @@ impl TransportParameters { pub(crate) fn encode(&self, enc: &mut Encoder) { for (tipe, tp) in &self.params { - tp.encode(enc, *tipe); + if let Some(tp) = tp { + tp.encode(enc, tipe); + } } } @@ -349,24 +409,24 @@ impl TransportParameters { #[must_use] pub fn get_integer(&self, tp: TransportParameterId) -> u64 { let default = match tp { - IDLE_TIMEOUT - | INITIAL_MAX_DATA - | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL - | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE - | INITIAL_MAX_STREAM_DATA_UNI - | INITIAL_MAX_STREAMS_BIDI - | INITIAL_MAX_STREAMS_UNI - | MIN_ACK_DELAY - | MAX_DATAGRAM_FRAME_SIZE => 0, - MAX_UDP_PAYLOAD_SIZE => 65527, - ACK_DELAY_EXPONENT => 3, - MAX_ACK_DELAY => 25, - ACTIVE_CONNECTION_ID_LIMIT => 2, + IdleTimeout + | InitialMaxData + | InitialMaxStreamDataBidiLocal + | InitialMaxStreamDataBidiRemote + | InitialMaxStreamDataUni + | InitialMaxStreamsBidi + | InitialMaxStreamsUni + | MinAckDelay + | MaxDatagramFrameSize => 0, + MaxUdpPayloadSize => 65527, + AckDelayExponent => 3, + MaxAckDelay => 25, + ActiveConnectionIdLimit => 2, _ => panic!("Transport parameter not known or not an Integer"), }; - match self.params.get(&tp) { + match self.params[tp] { None => default, - Some(TransportParameter::Integer(x)) => *x, + Some(TransportParameter::Integer(x)) => x, _ => panic!("Internal error"), } } @@ -376,19 +436,19 @@ impl TransportParameters { /// When the transport parameter isn't recognized as being an integer. pub fn set_integer(&mut self, tp: TransportParameterId, value: u64) { match tp { - IDLE_TIMEOUT - | INITIAL_MAX_DATA - | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL - | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE - | INITIAL_MAX_STREAM_DATA_UNI - | INITIAL_MAX_STREAMS_BIDI - | INITIAL_MAX_STREAMS_UNI - | MAX_UDP_PAYLOAD_SIZE - | ACK_DELAY_EXPONENT - | MAX_ACK_DELAY - | ACTIVE_CONNECTION_ID_LIMIT - | MIN_ACK_DELAY - | MAX_DATAGRAM_FRAME_SIZE => { + IdleTimeout + | InitialMaxData + | InitialMaxStreamDataBidiLocal + | InitialMaxStreamDataBidiRemote + | InitialMaxStreamDataUni + | InitialMaxStreamsBidi + | InitialMaxStreamsUni + | MaxUdpPayloadSize + | AckDelayExponent + | MaxAckDelay + | ActiveConnectionIdLimit + | MinAckDelay + | MaxDatagramFrameSize => { self.set(tp, TransportParameter::Integer(value)); } _ => panic!("Transport parameter not known"), @@ -400,14 +460,14 @@ impl TransportParameters { #[must_use] pub fn get_bytes(&self, tp: TransportParameterId) -> Option<&[u8]> { match tp { - ORIGINAL_DESTINATION_CONNECTION_ID - | INITIAL_SOURCE_CONNECTION_ID - | RETRY_SOURCE_CONNECTION_ID - | STATELESS_RESET_TOKEN => {} + OriginalDestinationConnectionId + | InitialSourceConnectionId + | RetrySourceConnectionId + | StatelessResetToken => {} _ => panic!("Transport parameter not known or not type bytes"), } - match self.params.get(&tp) { + match &self.params[tp] { None => None, Some(TransportParameter::Bytes(x)) => Some(x), _ => panic!("Internal error"), @@ -418,10 +478,10 @@ impl TransportParameters { /// When the transport parameter isn't recognized as containing bytes. pub fn set_bytes(&mut self, tp: TransportParameterId, value: Vec) { match tp { - ORIGINAL_DESTINATION_CONNECTION_ID - | INITIAL_SOURCE_CONNECTION_ID - | RETRY_SOURCE_CONNECTION_ID - | STATELESS_RESET_TOKEN => { + OriginalDestinationConnectionId + | InitialSourceConnectionId + | RetrySourceConnectionId + | StatelessResetToken => { self.set(tp, TransportParameter::Bytes(value)); } _ => panic!("Transport parameter not known or not type bytes"), @@ -432,7 +492,7 @@ impl TransportParameters { /// When the transport parameter isn't recognized as being empty. pub fn set_empty(&mut self, tp: TransportParameterId) { match tp { - DISABLE_MIGRATION | GREASE_QUIC_BIT => { + DisableMigration | GreaseQuicBit => { self.set(tp, TransportParameter::Empty); } _ => panic!("Transport parameter not known or not type empty"), @@ -452,7 +512,7 @@ impl TransportParameters { } let current = versions.initial().wire_version(); self.set( - VERSION_INFORMATION, + VersionInformation, TransportParameter::Versions { current, other }, ); } @@ -460,7 +520,7 @@ impl TransportParameters { fn compatible_upgrade(&mut self, v: Version) { if let Some(TransportParameter::Versions { ref mut current, .. - }) = self.params.get_mut(&VERSION_INFORMATION) + }) = self.params[VersionInformation] { *current = v.wire_version(); } else { @@ -473,7 +533,7 @@ impl TransportParameters { /// This should not happen if the parsing code in `TransportParameter::decode` is correct. #[must_use] pub fn get_empty(&self, tipe: TransportParameterId) -> bool { - match self.params.get(&tipe) { + match self.params[tipe] { None => false, Some(TransportParameter::Empty) => true, _ => panic!("Internal error"), @@ -486,26 +546,31 @@ impl TransportParameters { pub(crate) fn ok_for_0rtt(&self, remembered: &Self) -> bool { for (k, v_rem) in &remembered.params { // Skip checks for these, which don't affect 0-RTT. - if matches!( - *k, - ORIGINAL_DESTINATION_CONNECTION_ID - | INITIAL_SOURCE_CONNECTION_ID - | RETRY_SOURCE_CONNECTION_ID - | STATELESS_RESET_TOKEN - | IDLE_TIMEOUT - | ACK_DELAY_EXPONENT - | MAX_ACK_DELAY - | ACTIVE_CONNECTION_ID_LIMIT - | PREFERRED_ADDRESS - ) { + if v_rem.is_none() + || matches!( + k, + OriginalDestinationConnectionId + | InitialSourceConnectionId + | RetrySourceConnectionId + | StatelessResetToken + | IdleTimeout + | AckDelayExponent + | MaxAckDelay + | ActiveConnectionIdLimit + | TransportParameterId::PreferredAddress + ) + { continue; } - let ok = self - .params - .get(k) + + let ok = self.params[k] + .as_ref() .is_some_and(|v_self| match (v_self, v_rem) { - (TransportParameter::Integer(i_self), TransportParameter::Integer(i_rem)) => { - if *k == MIN_ACK_DELAY { + ( + TransportParameter::Integer(i_self), + Some(TransportParameter::Integer(i_rem)), + ) => { + if k == MinAckDelay { // MIN_ACK_DELAY is backwards: // it can only be reduced safely. *i_self <= *i_rem @@ -513,12 +578,12 @@ impl TransportParameters { *i_self >= *i_rem } } - (TransportParameter::Empty, TransportParameter::Empty) => true, + (TransportParameter::Empty, Some(TransportParameter::Empty)) => true, ( TransportParameter::Versions { current: v_self, .. }, - TransportParameter::Versions { current: v_rem, .. }, + Some(TransportParameter::Versions { current: v_rem, .. }), ) => v_self == v_rem, _ => false, }); @@ -533,7 +598,7 @@ impl TransportParameters { #[must_use] pub fn get_preferred_address(&self) -> Option<(PreferredAddress, ConnectionIdEntry<[u8; 16]>)> { if let Some(TransportParameter::PreferredAddress { v4, v6, cid, srt }) = - self.params.get(&PREFERRED_ADDRESS) + &self.params[TransportParameterId::PreferredAddress] { Some(( PreferredAddress::new(*v4, *v6), @@ -548,7 +613,7 @@ impl TransportParameters { #[must_use] pub fn get_versions(&self) -> Option<(WireVersion, &[WireVersion])> { if let Some(TransportParameter::Versions { current, other }) = - self.params.get(&VERSION_INFORMATION) + &self.params[VersionInformation] { Some((*current, other)) } else { @@ -558,7 +623,7 @@ impl TransportParameters { #[must_use] pub fn has_value(&self, tp: TransportParameterId) -> bool { - self.params.contains_key(&tp) + self.params[tp].is_some() } } @@ -755,20 +820,12 @@ where mod tests { use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; - use neqo_common::{Decoder, Encoder}; + use neqo_common::{qdebug, Decoder, Encoder}; + use TransportParameterId::*; use super::PreferredAddress; use crate::{ - tparams::{ - TransportParameter, TransportParameterId, TransportParameters, - ACTIVE_CONNECTION_ID_LIMIT, IDLE_TIMEOUT, INITIAL_MAX_DATA, INITIAL_MAX_STREAMS_BIDI, - INITIAL_MAX_STREAMS_UNI, INITIAL_MAX_STREAM_DATA_BIDI_LOCAL, - INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, INITIAL_MAX_STREAM_DATA_UNI, - INITIAL_SOURCE_CONNECTION_ID, MAX_ACK_DELAY, MAX_DATAGRAM_FRAME_SIZE, - MAX_UDP_PAYLOAD_SIZE, MIN_ACK_DELAY, ORIGINAL_DESTINATION_CONNECTION_ID, - PREFERRED_ADDRESS, RETRY_SOURCE_CONNECTION_ID, STATELESS_RESET_TOKEN, - VERSION_INFORMATION, - }, + tparams::{TransportParameter, TransportParameterId, TransportParameters}, ConnectionId, Error, Version, }; @@ -777,11 +834,10 @@ mod tests { const RESET_TOKEN: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8]; let mut tps = TransportParameters::default(); tps.set( - STATELESS_RESET_TOKEN, + StatelessResetToken, TransportParameter::Bytes(RESET_TOKEN.to_vec()), ); - tps.params - .insert(INITIAL_MAX_STREAMS_BIDI, TransportParameter::Integer(10)); + tps.params[InitialMaxStreamsBidi] = Some(TransportParameter::Integer(10)); let mut enc = Encoder::default(); tps.encode(&mut enc); @@ -790,18 +846,18 @@ mod tests { assert_eq!(tps, tps2); println!("TPS = {tps:?}"); - assert_eq!(tps2.get_integer(IDLE_TIMEOUT), 0); // Default - assert_eq!(tps2.get_integer(MAX_ACK_DELAY), 25); // Default - assert_eq!(tps2.get_integer(ACTIVE_CONNECTION_ID_LIMIT), 2); // Default - assert_eq!(tps2.get_integer(INITIAL_MAX_STREAMS_BIDI), 10); // Sent - assert_eq!(tps2.get_bytes(STATELESS_RESET_TOKEN), Some(RESET_TOKEN)); - assert_eq!(tps2.get_bytes(ORIGINAL_DESTINATION_CONNECTION_ID), None); - assert_eq!(tps2.get_bytes(INITIAL_SOURCE_CONNECTION_ID), None); - assert_eq!(tps2.get_bytes(RETRY_SOURCE_CONNECTION_ID), None); - assert!(!tps2.has_value(ORIGINAL_DESTINATION_CONNECTION_ID)); - assert!(!tps2.has_value(INITIAL_SOURCE_CONNECTION_ID)); - assert!(!tps2.has_value(RETRY_SOURCE_CONNECTION_ID)); - assert!(tps2.has_value(STATELESS_RESET_TOKEN)); + assert_eq!(tps2.get_integer(IdleTimeout), 0); // Default + assert_eq!(tps2.get_integer(MaxAckDelay), 25); // Default + assert_eq!(tps2.get_integer(ActiveConnectionIdLimit), 2); // Default + assert_eq!(tps2.get_integer(InitialMaxStreamsBidi), 10); // Sent + assert_eq!(tps2.get_bytes(StatelessResetToken), Some(RESET_TOKEN)); + assert_eq!(tps2.get_bytes(OriginalDestinationConnectionId), None); + assert_eq!(tps2.get_bytes(InitialSourceConnectionId), None); + assert_eq!(tps2.get_bytes(RetrySourceConnectionId), None); + assert!(!tps2.has_value(OriginalDestinationConnectionId)); + assert!(!tps2.has_value(InitialSourceConnectionId)); + assert!(!tps2.has_value(RetrySourceConnectionId)); + assert!(tps2.has_value(StatelessResetToken)); let mut enc = Encoder::default(); tps.encode(&mut enc); @@ -833,12 +889,12 @@ mod tests { ]; let spa = make_spa(); let mut enc = Encoder::new(); - spa.encode(&mut enc, PREFERRED_ADDRESS); + spa.encode(&mut enc, PreferredAddress); assert_eq!(enc.as_ref(), ENCODED); let mut dec = enc.as_decoder(); let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap(); - assert_eq!(id, PREFERRED_ADDRESS); + assert_eq!(id, PreferredAddress); assert_eq!(decoded, spa); } @@ -866,7 +922,7 @@ mod tests { /// doesn't care about validity, and decodes it. The result should be failure. fn assert_invalid_spa(spa: &TransportParameter) { let mut enc = Encoder::new(); - spa.encode(&mut enc, PREFERRED_ADDRESS); + spa.encode(&mut enc, PreferredAddress); assert_eq!( TransportParameter::decode(&mut enc.as_decoder()).unwrap_err(), Error::TransportParameterError @@ -876,10 +932,10 @@ mod tests { /// This is for those rare mutations that are acceptable. fn assert_valid_spa(spa: &TransportParameter) { let mut enc = Encoder::new(); - spa.encode(&mut enc, PREFERRED_ADDRESS); + spa.encode(&mut enc, PreferredAddress); let mut dec = enc.as_decoder(); let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap(); - assert_eq!(id, PREFERRED_ADDRESS); + assert_eq!(id, PreferredAddress); assert_eq!(&decoded, spa); } @@ -927,7 +983,7 @@ mod tests { fn preferred_address_truncated() { let spa = make_spa(); let mut enc = Encoder::new(); - spa.encode(&mut enc, PREFERRED_ADDRESS); + spa.encode(&mut enc, PreferredAddress); let mut dec = Decoder::from(&enc.as_ref()[..enc.len() - 1]); assert_eq!( TransportParameter::decode(&mut dec).unwrap_err(), @@ -972,24 +1028,24 @@ mod tests { fn compatible_0rtt_ignored_values() { let mut tps_a = TransportParameters::default(); tps_a.set( - STATELESS_RESET_TOKEN, + StatelessResetToken, TransportParameter::Bytes(vec![1, 2, 3]), ); - tps_a.set(IDLE_TIMEOUT, TransportParameter::Integer(10)); - tps_a.set(MAX_ACK_DELAY, TransportParameter::Integer(22)); - tps_a.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(33)); + tps_a.set(IdleTimeout, TransportParameter::Integer(10)); + tps_a.set(MaxAckDelay, TransportParameter::Integer(22)); + tps_a.set(ActiveConnectionIdLimit, TransportParameter::Integer(33)); let mut tps_b = TransportParameters::default(); assert!(tps_a.ok_for_0rtt(&tps_b)); assert!(tps_b.ok_for_0rtt(&tps_a)); tps_b.set( - STATELESS_RESET_TOKEN, + StatelessResetToken, TransportParameter::Bytes(vec![8, 9, 10]), ); - tps_b.set(IDLE_TIMEOUT, TransportParameter::Integer(100)); - tps_b.set(MAX_ACK_DELAY, TransportParameter::Integer(2)); - tps_b.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(44)); + tps_b.set(IdleTimeout, TransportParameter::Integer(100)); + tps_b.set(MaxAckDelay, TransportParameter::Integer(2)); + tps_b.set(ActiveConnectionIdLimit, TransportParameter::Integer(44)); assert!(tps_a.ok_for_0rtt(&tps_b)); assert!(tps_b.ok_for_0rtt(&tps_a)); } @@ -997,15 +1053,15 @@ mod tests { #[test] fn compatible_0rtt_integers() { const INTEGER_KEYS: &[TransportParameterId] = &[ - INITIAL_MAX_DATA, - INITIAL_MAX_STREAM_DATA_BIDI_LOCAL, - INITIAL_MAX_STREAM_DATA_BIDI_REMOTE, - INITIAL_MAX_STREAM_DATA_UNI, - INITIAL_MAX_STREAMS_BIDI, - INITIAL_MAX_STREAMS_UNI, - MAX_UDP_PAYLOAD_SIZE, - MIN_ACK_DELAY, - MAX_DATAGRAM_FRAME_SIZE, + InitialMaxData, + InitialMaxStreamDataBidiLocal, + InitialMaxStreamDataBidiRemote, + InitialMaxStreamDataUni, + InitialMaxStreamsBidi, + InitialMaxStreamsUni, + MaxUdpPayloadSize, + MinAckDelay, + MaxDatagramFrameSize, ]; let mut tps_a = TransportParameters::default(); @@ -1021,7 +1077,7 @@ mod tests { for i in INTEGER_KEYS { let mut tps_b = tps_a.clone(); // Set a safe new value; reducing MIN_ACK_DELAY instead. - let safe_value = if *i == MIN_ACK_DELAY { 11 } else { 13 }; + let safe_value = if *i == MinAckDelay { 11 } else { 13 }; tps_b.set(*i, TransportParameter::Integer(safe_value)); // If the new value is not safe relative to the remembered value, // then we can't attempt 0-RTT with these parameters. @@ -1048,8 +1104,7 @@ mod tests { // Intentionally set an invalid value for the ACTIVE_CONNECTION_ID_LIMIT transport // parameter. - tps.params - .insert(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(1)); + tps.params[ActiveConnectionIdLimit] = Some(TransportParameter::Integer(1)); let mut enc = Encoder::default(); tps.encode(&mut enc); @@ -1071,12 +1126,12 @@ mod tests { }; let mut enc = Encoder::new(); - vn.encode(&mut enc, VERSION_INFORMATION); + vn.encode(&mut enc, VersionInformation); assert_eq!(enc.as_ref(), ENCODED); let mut dec = enc.as_decoder(); let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap(); - assert_eq!(id, VERSION_INFORMATION); + assert_eq!(id, VersionInformation); assert_eq!(decoded, vn); } @@ -1113,8 +1168,9 @@ mod tests { #[test] fn versions_equal_0rtt() { let mut current = TransportParameters::default(); + qdebug!("Current = {:?}", current); current.set( - VERSION_INFORMATION, + VersionInformation, TransportParameter::Versions { current: Version::Version1.wire_version(), other: vec![0x1a2a_3a4a], @@ -1129,7 +1185,7 @@ mod tests { // If the version matches, it's OK to use 0-RTT. remembered.set( - VERSION_INFORMATION, + VersionInformation, TransportParameter::Versions { current: Version::Version1.wire_version(), other: vec![0x5a6a_7a8a, 0x9aaa_baca], @@ -1140,7 +1196,7 @@ mod tests { // An apparent "upgrade" is still cause to reject 0-RTT. remembered.set( - VERSION_INFORMATION, + VersionInformation, TransportParameter::Versions { current: Version::Version1.wire_version() + 1, other: vec![], diff --git a/neqo-transport/src/version.rs b/neqo-transport/src/version.rs index cf2a5a74cb..8d6066951d 100644 --- a/neqo-transport/src/version.rs +++ b/neqo-transport/src/version.rs @@ -6,13 +6,14 @@ #![allow(clippy::module_name_repetitions)] +use enum_map::Enum; use neqo_common::qdebug; use crate::{Error, Res}; pub type WireVersion = u32; -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Enum)] pub enum Version { Version2, #[default]