From 269127daa5aac10799a5c1fa740d8f41fa68a74d Mon Sep 17 00:00:00 2001 From: minghuaw Date: Sun, 22 Dec 2024 19:18:36 -0800 Subject: [PATCH] experiment: Improve serialize performance (#297) * pre-alloc buffer for seq * fix some elidible lifetimes * suppress one clippy warning * add more bench cases (#298) * pre-alloc buffer for seq * fix some elidible lifetimes * suppress one clippy warning * fix elidible lifetime * cargo fmt * fix elidible lifetime * remove random lengh elements * pre-alloc for MapSerializer * cargo fmt * bump version and updated changelog --- fe2o3-amqp-cbs/src/put_token.rs | 2 +- .../src/operations/entity/create.rs | 2 +- .../src/operations/entity/delete.rs | 2 +- .../src/operations/entity/read.rs | 2 +- .../src/operations/entity/update.rs | 2 +- .../src/operations/node/get_annotations.rs | 2 +- .../src/operations/node/get_attributes.rs | 2 +- .../src/operations/node/get_mgmt_nodes.rs | 2 +- .../src/operations/node/get_operations.rs | 2 +- .../src/operations/node/get_types.rs | 2 +- .../src/operations/node/query.rs | 2 +- .../src/operations/node/register.rs | 2 +- fe2o3-amqp/src/connection/builder.rs | 9 +- fe2o3-amqp/src/util/mod.rs | 6 +- serde_amqp/Cargo.toml | 2 +- serde_amqp/Changelog.md | 6 + serde_amqp/benches/serialize.rs | 84 ------- serde_amqp/src/ser.rs | 213 ++++++++++++++++-- 18 files changed, 221 insertions(+), 123 deletions(-) diff --git a/fe2o3-amqp-cbs/src/put_token.rs b/fe2o3-amqp-cbs/src/put_token.rs index 1b4690ff..6776814b 100644 --- a/fe2o3-amqp-cbs/src/put_token.rs +++ b/fe2o3-amqp-cbs/src/put_token.rs @@ -50,7 +50,7 @@ impl<'a> PutTokenRequest<'a> { } } -impl<'a> Request for PutTokenRequest<'a> { +impl Request for PutTokenRequest<'_> { const OPERATION: &'static str = PUT_TOKEN; type Response = PutTokenResponse; diff --git a/fe2o3-amqp-management/src/operations/entity/create.rs b/fe2o3-amqp-management/src/operations/entity/create.rs index 3687a68b..b68d65f3 100644 --- a/fe2o3-amqp-management/src/operations/entity/create.rs +++ b/fe2o3-amqp-management/src/operations/entity/create.rs @@ -84,7 +84,7 @@ impl<'a> CreateRequest<'a> { } } -impl<'a> Request for CreateRequest<'a> { +impl Request for CreateRequest<'_> { const OPERATION: &'static str = CREATE; type Response = CreateResponse; diff --git a/fe2o3-amqp-management/src/operations/entity/delete.rs b/fe2o3-amqp-management/src/operations/entity/delete.rs index 5fa25835..27308a4c 100644 --- a/fe2o3-amqp-management/src/operations/entity/delete.rs +++ b/fe2o3-amqp-management/src/operations/entity/delete.rs @@ -79,7 +79,7 @@ impl<'a> DeleteRequest<'a> { } } -impl<'a> Request for DeleteRequest<'a> { +impl Request for DeleteRequest<'_> { const OPERATION: &'static str = DELETE; type Response = DeleteResponse; diff --git a/fe2o3-amqp-management/src/operations/entity/read.rs b/fe2o3-amqp-management/src/operations/entity/read.rs index ff5afc37..f9e2d88a 100644 --- a/fe2o3-amqp-management/src/operations/entity/read.rs +++ b/fe2o3-amqp-management/src/operations/entity/read.rs @@ -75,7 +75,7 @@ impl<'a> ReadRequest<'a> { } } -impl<'a> Request for ReadRequest<'a> { +impl Request for ReadRequest<'_> { const OPERATION: &'static str = READ; type Response = ReadResponse; diff --git a/fe2o3-amqp-management/src/operations/entity/update.rs b/fe2o3-amqp-management/src/operations/entity/update.rs index a60c9022..63a197aa 100644 --- a/fe2o3-amqp-management/src/operations/entity/update.rs +++ b/fe2o3-amqp-management/src/operations/entity/update.rs @@ -94,7 +94,7 @@ impl<'a> UpdateRequest<'a> { } } -impl<'a> Request for UpdateRequest<'a> { +impl Request for UpdateRequest<'_> { const OPERATION: &'static str = UPDATE; type Response = UpdateResponse; diff --git a/fe2o3-amqp-management/src/operations/node/get_annotations.rs b/fe2o3-amqp-management/src/operations/node/get_annotations.rs index 3458228e..713aa908 100644 --- a/fe2o3-amqp-management/src/operations/node/get_annotations.rs +++ b/fe2o3-amqp-management/src/operations/node/get_annotations.rs @@ -35,7 +35,7 @@ impl<'a> GetAnnotationsRequest<'a> { } } -impl<'a> Request for GetAnnotationsRequest<'a> { +impl Request for GetAnnotationsRequest<'_> { const OPERATION: &'static str = GET_ANNOTATIONS; type Response = GetAnnotationsResponse; diff --git a/fe2o3-amqp-management/src/operations/node/get_attributes.rs b/fe2o3-amqp-management/src/operations/node/get_attributes.rs index aa2ea11e..c59920f7 100644 --- a/fe2o3-amqp-management/src/operations/node/get_attributes.rs +++ b/fe2o3-amqp-management/src/operations/node/get_attributes.rs @@ -35,7 +35,7 @@ impl<'a> GetAttributesRequest<'a> { } } -impl<'a> Request for GetAttributesRequest<'a> { +impl Request for GetAttributesRequest<'_> { const OPERATION: &'static str = GET_ATTRIBUTES; type Response = GetAttributesResponse; diff --git a/fe2o3-amqp-management/src/operations/node/get_mgmt_nodes.rs b/fe2o3-amqp-management/src/operations/node/get_mgmt_nodes.rs index 6fa59a85..b07d0d91 100644 --- a/fe2o3-amqp-management/src/operations/node/get_mgmt_nodes.rs +++ b/fe2o3-amqp-management/src/operations/node/get_mgmt_nodes.rs @@ -33,7 +33,7 @@ impl<'a> GetMgmtNodesRequest<'a> { } } -impl<'a> Request for GetMgmtNodesRequest<'a> { +impl Request for GetMgmtNodesRequest<'_> { const OPERATION: &'static str = GET_MGMT_NODES; type Response = GetMgmtNodesResponse; diff --git a/fe2o3-amqp-management/src/operations/node/get_operations.rs b/fe2o3-amqp-management/src/operations/node/get_operations.rs index 784f28ce..e8d0b1c8 100644 --- a/fe2o3-amqp-management/src/operations/node/get_operations.rs +++ b/fe2o3-amqp-management/src/operations/node/get_operations.rs @@ -38,7 +38,7 @@ impl<'a> GetOperationsRequest<'a> { } } -impl<'a> Request for GetOperationsRequest<'a> { +impl Request for GetOperationsRequest<'_> { const OPERATION: &'static str = GET_OPERATIONS; type Response = GetOperationsResponse; diff --git a/fe2o3-amqp-management/src/operations/node/get_types.rs b/fe2o3-amqp-management/src/operations/node/get_types.rs index 93e605b6..91012cc7 100644 --- a/fe2o3-amqp-management/src/operations/node/get_types.rs +++ b/fe2o3-amqp-management/src/operations/node/get_types.rs @@ -35,7 +35,7 @@ impl<'a> GetTypesRequest<'a> { } } -impl<'a> Request for GetTypesRequest<'a> { +impl Request for GetTypesRequest<'_> { const OPERATION: &'static str = GET_TYPES; type Response = GetTypesResponse; diff --git a/fe2o3-amqp-management/src/operations/node/query.rs b/fe2o3-amqp-management/src/operations/node/query.rs index c5e2fa69..3299bc0d 100644 --- a/fe2o3-amqp-management/src/operations/node/query.rs +++ b/fe2o3-amqp-management/src/operations/node/query.rs @@ -82,7 +82,7 @@ impl<'a> QueryRequest<'a> { } } -impl<'a> Request for QueryRequest<'a> { +impl Request for QueryRequest<'_> { const OPERATION: &'static str = QUERY; type Response = QueryResponse; diff --git a/fe2o3-amqp-management/src/operations/node/register.rs b/fe2o3-amqp-management/src/operations/node/register.rs index 78c12632..d59471ba 100644 --- a/fe2o3-amqp-management/src/operations/node/register.rs +++ b/fe2o3-amqp-management/src/operations/node/register.rs @@ -47,7 +47,7 @@ impl<'a> RegisterRequest<'a> { } } -impl<'a> Request for RegisterRequest<'a> { +impl Request for RegisterRequest<'_> { const OPERATION: &'static str = REGISTER; type Response = RegisterResponse; diff --git a/fe2o3-amqp/src/connection/builder.rs b/fe2o3-amqp/src/connection/builder.rs index 4f450daa..4b55bf6d 100644 --- a/fe2o3-amqp/src/connection/builder.rs +++ b/fe2o3-amqp/src/connection/builder.rs @@ -174,7 +174,7 @@ impl<'a, Tls> From> for Open { } } -impl<'a, Mode: std::fmt::Debug> std::fmt::Debug for Builder<'a, Mode, ()> { +impl std::fmt::Debug for Builder<'_, Mode, ()> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Builder") .field("container_id", &self.container_id) @@ -254,13 +254,13 @@ cfg_not_wasm32! { } } -impl<'a, Mode> Default for Builder<'a, Mode, ()> { +impl Default for Builder<'_, Mode, ()> { fn default() -> Self { Self::new() } } -impl<'a, Mode> Builder<'a, Mode, ()> { +impl Builder<'_, Mode, ()> { /// Creates a new builder for [`crate::Connection`] pub fn new() -> Self { Self { @@ -323,6 +323,7 @@ impl<'a, Tls> Builder<'a, mode::ConnectorNoId, Tls> { } } +#[allow(clippy::needless_lifetimes)] impl<'a, Mode, Tls> Builder<'a, Mode, Tls> { /// Alias for [`rustls_connector`](#method.rustls_connector) if only `"rustls"` is enabled #[cfg_attr(docsrs, doc(cfg(all(feature = "rustls", not(feature = "native-tls")))))] @@ -753,7 +754,7 @@ impl<'a, Tls> Builder<'a, mode::ConnectorWithId, Tls> { /* Without TLS */ /* -------------------------------------------------------------------------- */ -impl<'a> Builder<'a, mode::ConnectorWithId, ()> { +impl Builder<'_, mode::ConnectorWithId, ()> { #[cfg(all(feature = "rustls", not(feature = "native-tls")))] async fn connect_tls_with_rustls_default( self, diff --git a/fe2o3-amqp/src/util/mod.rs b/fe2o3-amqp/src/util/mod.rs index 165c9a99..5f712ee9 100644 --- a/fe2o3-amqp/src/util/mod.rs +++ b/fe2o3-amqp/src/util/mod.rs @@ -180,7 +180,7 @@ impl AsByteIterator for Payload { } } -impl<'a> AsByteIterator for &'a Payload { +impl AsByteIterator for &Payload { type IterImpl<'i> = std::slice::Iter<'i, u8> where @@ -271,13 +271,13 @@ impl<'a> Iterator for ByteReaderIter<'a> { } } -impl<'a> ExactSizeIterator for ByteReaderIter<'a> { +impl ExactSizeIterator for ByteReaderIter<'_> { fn len(&self) -> usize { self.inner.iter().map(|iter| iter.len()).sum() } } -impl<'a> DoubleEndedIterator for ByteReaderIter<'a> { +impl DoubleEndedIterator for ByteReaderIter<'_> { fn next_back(&mut self) -> Option { self.inner .iter_mut() diff --git a/serde_amqp/Cargo.toml b/serde_amqp/Cargo.toml index e2637c30..ebb5b783 100644 --- a/serde_amqp/Cargo.toml +++ b/serde_amqp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "serde_amqp" -version = "0.13.1" +version = "0.13.2" edition = "2021" description = "A serde implementation of AMQP1.0 protocol." license = "MIT/Apache-2.0" diff --git a/serde_amqp/Changelog.md b/serde_amqp/Changelog.md index cdbbcb4c..099bf74f 100644 --- a/serde_amqp/Changelog.md +++ b/serde_amqp/Changelog.md @@ -1,5 +1,11 @@ # Change Log +## 0.13.2 + +1. Improve serializer performance in serializing list and map types by + pre-allocating the buffer based on the suggested length and the serialized + size of the first element/entry. + ## 0.13.1 1. Added `to_lazy_value` diff --git a/serde_amqp/benches/serialize.rs b/serde_amqp/benches/serialize.rs index 2df4538a..f3e62134 100644 --- a/serde_amqp/benches/serialize.rs +++ b/serde_amqp/benches/serialize.rs @@ -283,42 +283,6 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| serde_amqp::to_vec(black_box(&value)).unwrap()) }); - // 10 Random strings of size between 16B and 1kB - let value = (0..10) - .map(|_| { - let size = rand::thread_rng().gen_range(16..1024); - Alphanumeric.sample_string(&mut rand::thread_rng(), size) - }) - .map(String::from) - .collect::>(); - c.bench_function("serialize List 10x16B-10kB", |b| { - b.iter(|| serde_amqp::to_vec(black_box(&value)).unwrap()) - }); - - // 100 Random strings of size between 16B and 1kB - let value = (0..100) - .map(|_| { - let size = rand::thread_rng().gen_range(16..1024); - Alphanumeric.sample_string(&mut rand::thread_rng(), size) - }) - .map(String::from) - .collect::>(); - c.bench_function("serialize List 100x16B-10kB", |b| { - b.iter(|| serde_amqp::to_vec(black_box(&value)).unwrap()) - }); - - // 1000 Random strings of size between 16B and 1kB - let value = (0..1000) - .map(|_| { - let size = rand::thread_rng().gen_range(16..1024); - Alphanumeric.sample_string(&mut rand::thread_rng(), size) - }) - .map(String::from) - .collect::>(); - c.bench_function("serialize List 1000x16B-10kB", |b| { - b.iter(|| serde_amqp::to_vec(black_box(&value)).unwrap()) - }); - // Map of 10 u64 -> u64 let value = (0..10) .map(|_| { @@ -396,54 +360,6 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("serialize Map 1000x16B", |b| { b.iter(|| serde_amqp::to_vec(black_box(&value)).unwrap()) }); - - // Map of 10 random String (16B-1kB) -> random String (16B-1kB) - let value = (0..10) - .map(|_| { - let key_size = rand::thread_rng().gen_range(16..1024); - let key = Alphanumeric.sample_string(&mut rand::thread_rng(), key_size); - let key = String::from(key); - let value_size = rand::thread_rng().gen_range(16..1024); - let value = Alphanumeric.sample_string(&mut rand::thread_rng(), value_size); - let value = String::from(value); - (key, value) - }) - .collect::>(); - c.bench_function("serialize Map 10x16B-1kB", |b| { - b.iter(|| serde_amqp::to_vec(black_box(&value)).unwrap()) - }); - - // Map of 100 random String (16B-1kB) -> random String (16B-1kB) - let value = (0..100) - .map(|_| { - let key_size = rand::thread_rng().gen_range(16..1024); - let key = Alphanumeric.sample_string(&mut rand::thread_rng(), key_size); - let key = String::from(key); - let value_size = rand::thread_rng().gen_range(16..1024); - let value = Alphanumeric.sample_string(&mut rand::thread_rng(), value_size); - let value = String::from(value); - (key, value) - }) - .collect::>(); - c.bench_function("serialize Map 100x16B-1kB", |b| { - b.iter(|| serde_amqp::to_vec(black_box(&value)).unwrap()) - }); - - // Map of 1000 random String (16B-1kB) -> random String (16B-1kB) - let value = (0..1000) - .map(|_| { - let key_size = rand::thread_rng().gen_range(16..1024); - let key = Alphanumeric.sample_string(&mut rand::thread_rng(), key_size); - let key = String::from(key); - let value_size = rand::thread_rng().gen_range(16..1024); - let value = Alphanumeric.sample_string(&mut rand::thread_rng(), value_size); - let value = String::from(value); - (key, value) - }) - .collect::>(); - c.bench_function("serialize Map 1000x16B-1kB", |b| { - b.iter(|| serde_amqp::to_vec(black_box(&value)).unwrap()) - }); } criterion_group!(benches, criterion_benchmark); diff --git a/serde_amqp/src/ser.rs b/serde_amqp/src/ser.rs index cf22426d..63e8bea9 100644 --- a/serde_amqp/src/ser.rs +++ b/serde_amqp/src/ser.rs @@ -16,6 +16,7 @@ use crate::{ error::Error, format::{OFFSET_LIST32, OFFSET_LIST8, OFFSET_MAP32, OFFSET_MAP8}, format_code::EncodingCodes, + serialized_size, util::{FieldRole, IsArrayElement, NonNativeType, SequenceType, StructEncoding}, }; @@ -38,6 +39,9 @@ where Ok(writer) } +// TODO: place an optional arena (bumpalo) into the serializer, and whenever a seq serializer is +// created, the internal buffer (or all the nested serializers) will be created in the arena + /// A struct for serializing Rust structs/values into AMQP1.0 wire format #[derive(Debug)] pub struct Serializer { @@ -696,9 +700,9 @@ impl<'a, W: Write + 'a> ser::Serializer for &'a mut Serializer { // // This will be encoded as primitive type `Array` #[inline] - fn serialize_seq(self, _len: Option) -> Result { + fn serialize_seq(self, len: Option) -> Result { // The most external array should be treated as IsArrayElement::False - Ok(SeqSerializer::new(self)) + Ok(SeqSerializer::new(self, len)) } // A statically sized heterogeneous sequence of values @@ -740,8 +744,8 @@ impl<'a, W: Write + 'a> ser::Serializer for &'a mut Serializer { } #[inline] - fn serialize_map(self, _len: Option) -> Result { - Ok(MapSerializer::new(self)) + fn serialize_map(self, len: Option) -> Result { + Ok(MapSerializer::new(self, len)) } // The serde data model treats struct as "A statically sized heterogeneous key-value pairing" @@ -813,20 +817,49 @@ impl<'a, W: Write + 'a> ser::Serializer for &'a mut Serializer { } } +/// Internal state of the sequence serializer +#[derive(Debug, Clone)] +enum SeqSerializerState { + /// Initialized with the length hint + Init(Option), + + /// Buffering the serialized bytes + Buffer(Vec), +} + /// Serializer for sequence types #[derive(Debug)] pub struct SeqSerializer<'a, W: 'a> { se: &'a mut Serializer, num: usize, - buf: Vec, + state: SeqSerializerState, } impl<'a, W: 'a> SeqSerializer<'a, W> { - fn new(se: &'a mut Serializer) -> Self { + fn new(se: &'a mut Serializer, len: Option) -> Self { Self { se, num: 0, - buf: Vec::new(), + state: SeqSerializerState::Init(len), + } + } + + fn get_buffer_mut_or_alloc(&mut self, element: &T) -> Result<&mut Vec, Error> + where + T: Serialize + ?Sized, + { + if let SeqSerializerState::Init(len) = self.state { + let total_len = match len { + Some(len) => len * serialized_size(element)?, + None => 0, + }; + let buf = Vec::with_capacity(total_len); + self.state = SeqSerializerState::Buffer(buf); + } + + match &mut self.state { + SeqSerializerState::Buffer(buf) => Ok(buf), + _ => unreachable!("SeqSerializerState::Init should have been handled"), } } } @@ -848,20 +881,23 @@ impl<'a, W: Write + 'a> ser::SerializeSeq for SeqSerializer<'a, W> { match self.se.seq_type { None | Some(SequenceType::List) => { // Element in the list always has it own constructor - let mut se = Serializer::new(&mut self.buf); + let buf = self.get_buffer_mut_or_alloc(value)?; + let mut se = Serializer::new(buf); value.serialize(&mut se)?; } Some(SequenceType::Array) => { let mut se = match self.num { // The first element should include the contructor code 0 => { - let mut serializer = Serializer::new(&mut self.buf); + let buf = self.get_buffer_mut_or_alloc(value)?; + let mut serializer = Serializer::new(buf); serializer.is_array_elem = IsArrayElement::FirstElement; serializer } // The remaining element should only write the value bytes _ => { - let mut serializer = Serializer::new(&mut self.buf); + let buf = self.get_buffer_mut_or_alloc(value)?; + let mut serializer = Serializer::new(buf); serializer.is_array_elem = IsArrayElement::OtherElement; serializer } @@ -870,7 +906,8 @@ impl<'a, W: Write + 'a> ser::SerializeSeq for SeqSerializer<'a, W> { } Some(SequenceType::TransparentVec) => { // FIXME: Directly write to the writer - let mut se = Serializer::new(&mut self.buf); + let buf = self.get_buffer_mut_or_alloc(value)?; + let mut se = Serializer::new(buf); value.serialize(&mut se)?; } } @@ -881,7 +918,14 @@ impl<'a, W: Write + 'a> ser::SerializeSeq for SeqSerializer<'a, W> { #[inline] fn end(self) -> Result { - let Self { se, num, buf } = self; + // let Self { se, num, buf } = self; + + let Self { se, num, state } = self; + let buf = match state { + SeqSerializerState::Init(_) => Vec::new(), + SeqSerializerState::Buffer(buf) => buf, + }; + match se.seq_type { None | Some(SequenceType::List) => { write_list(&mut se.writer, num, &buf, &se.is_array_elem) @@ -1018,20 +1062,141 @@ fn write_list<'a, W: Write + 'a>( Ok(()) } +#[derive(Debug)] +enum MapSerializerState { + /// Initialized with the length hint + /// + /// `Init(None)` is also used as the temporary state when the buffer is being + /// taken out + Init(Option), + + /// Initialized with the length hint and the first key + /// + /// The buffer should only contain the serialized bytes of the first key + KeyInit { len: Option, buf: Vec }, + + /// Buffering the serialized bytes + Buffer(Vec), +} + +impl MapSerializerState { + fn take(&mut self) -> Self { + std::mem::replace(self, MapSerializerState::Init(None)) + } +} + /// Serializer for map types #[derive(Debug)] pub struct MapSerializer<'a, W: 'a> { se: &'a mut Serializer, + + /// Actual number of elements in the map (a pair of key and value will be counted as 2) num: usize, - buf: Vec, + + state: MapSerializerState, } impl<'a, W: 'a> MapSerializer<'a, W> { - fn new(se: &'a mut Serializer) -> Self { + fn new(se: &'a mut Serializer, len: Option) -> Self { Self { se, num: 0, - buf: Vec::new(), + state: MapSerializerState::Init(len), + } + } + + fn get_buffer_or_alloc_for_entry( + &mut self, + key: &K, + value: &V, + ) -> Result<&mut Vec, Error> + where + K: Serialize + ?Sized, + V: Serialize + ?Sized, + { + match self.state.take() { + MapSerializerState::Init(len) => { + let cap = match len { + Some(len) => len * (serialized_size(key)? + serialized_size(value)?), + None => 0, + }; + let buf = Vec::with_capacity(cap); + self.state = MapSerializerState::Buffer(buf); + } + MapSerializerState::KeyInit { len, mut buf } => { + let reserve = match len { + Some(len) => { + len * (serialized_size(key)? + serialized_size(value)?) - buf.len() + } + None => 0, + }; + buf.reserve(reserve); + self.state = MapSerializerState::Buffer(buf); + } + // Make sure to put the buffer back because take will put a `Init(None)` as temporary state + state => self.state = state, + } + + match &mut self.state { + MapSerializerState::Buffer(buf) => Ok(buf), + _ => unreachable!("MapSerializerState::Init should have been handled"), + } + } + + fn get_buffer_or_alloc_for_key(&mut self, key: &K) -> Result<&mut Vec, Error> + where + K: Serialize + ?Sized, + { + match self.state.take() { + MapSerializerState::Init(len) => { + let cap = serialized_size(key)?; + let buf = Vec::with_capacity(cap); + self.state = MapSerializerState::KeyInit { len, buf }; + } + // Make sure to put the state back because take will put a `Init(None)` as temporary state + state => self.state = state, + } + + match &mut self.state { + MapSerializerState::KeyInit { len: _, buf } => Ok(buf), + MapSerializerState::Buffer(buf) => Ok(buf), + MapSerializerState::Init(_) => { + unreachable!("MapSerializerState::Init should have been handled") + } + } + } + + fn get_buffer_or_alloc_for_value(&mut self, value: &V) -> Result<&mut Vec, Error> + where + V: Serialize + ?Sized, + { + match self.state.take() { + // Though this should not appear, we should handle it + // just to be safe + MapSerializerState::Init(len) => { + let cap = match len { + Some(len) => len * serialized_size(value)?, + None => 0, + }; + let buf = Vec::with_capacity(cap); + self.state = MapSerializerState::Buffer(buf); + } + MapSerializerState::KeyInit { len, mut buf } => { + let key_len = buf.len(); + let val_len = serialized_size(value)?; + let reserve = match len { + Some(len) => len * (key_len + val_len) - key_len, + None => 0, + }; + buf.reserve(reserve); + self.state = MapSerializerState::Buffer(buf); + } + state => self.state = state, + } + + match &mut self.state { + MapSerializerState::Buffer(buf) => Ok(buf), + _ => unreachable!("MapSerializerState::KeyInit should have been handled"), } } } @@ -1048,7 +1213,8 @@ impl<'a, W: Write + 'a> ser::SerializeMap for MapSerializer<'a, W> { K: Serialize + ?Sized, V: Serialize + ?Sized, { - let mut serializer = Serializer::new(&mut self.buf); + let buf = self.get_buffer_or_alloc_for_entry(key, value)?; + let mut serializer = Serializer::new(buf); key.serialize(&mut serializer)?; value.serialize(&mut serializer)?; self.num += 2; @@ -1060,7 +1226,8 @@ impl<'a, W: Write + 'a> ser::SerializeMap for MapSerializer<'a, W> { where T: Serialize + ?Sized, { - let mut serializer = Serializer::new(&mut self.buf); + let buf = self.get_buffer_or_alloc_for_key(key)?; + let mut serializer = Serializer::new(buf); key.serialize(&mut serializer)?; self.num += 1; Ok(()) @@ -1071,7 +1238,8 @@ impl<'a, W: Write + 'a> ser::SerializeMap for MapSerializer<'a, W> { where T: Serialize + ?Sized, { - let mut serializer = Serializer::new(&mut self.buf); + let buf = self.get_buffer_or_alloc_for_value(value)?; + let mut serializer = Serializer::new(buf); value.serialize(&mut serializer)?; self.num += 1; Ok(()) @@ -1079,7 +1247,14 @@ impl<'a, W: Write + 'a> ser::SerializeMap for MapSerializer<'a, W> { #[inline] fn end(self) -> Result { - let Self { se, num, buf } = self; + let Self { se, num, state } = self; + let buf = match state { + MapSerializerState::Init(_) => Vec::new(), + // This should not happen but we should handle it just to be safe + MapSerializerState::KeyInit { len: _, buf } => buf, + MapSerializerState::Buffer(buf) => buf, + }; + write_map(&mut se.writer, num, &buf, &se.is_array_elem) } }