diff --git a/src/p2p.rs b/src/p2p.rs index 9312243a6..f175786d0 100644 --- a/src/p2p.rs +++ b/src/p2p.rs @@ -18,8 +18,6 @@ use bitcoin::{ }; use bitcoin_slices::{bsl, Parse}; use crossbeam_channel::{bounded, select, Receiver, Sender}; -use rayon::iter::ParallelIterator; -use rayon::prelude::IntoParallelIterator; use std::io::{self, ErrorKind, Write}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; @@ -102,8 +100,8 @@ impl Connection { R: Send + Sync, { self.blocks_duration.observe_duration("total", || { - let mut result = vec![]; let blockhashes: Vec = blockhashes.into_iter().collect(); + let blockhashes_len = blockhashes.len(); if blockhashes.is_empty() { return Ok(vec![]); } @@ -112,31 +110,46 @@ impl Connection { self.req_send.send(Request::get_blocks(&blockhashes)) })?; - for hash in blockhashes { - let block = self.blocks_duration.observe_duration("response", || { - let block = self - .blocks_recv - .recv() - .with_context(|| format!("failed to get block {}", hash))?; - let header = bsl::BlockHeader::parse(&block[..]) - .expect("core returned invalid blockheader") - .parsed_owned(); - ensure!( - &header.block_hash_sha2()[..] == hash.as_byte_array(), - "got unexpected block" - ); - Ok(block) - })?; - result.push((hash, block)); - } - - Ok(result - .into_par_iter() - .map(|(hash, block)| { - self.blocks_duration - .observe_duration("process", || func(hash, block)) - }) - .collect()) + rayon::scope(|s| { + let (send, receive) = std::sync::mpsc::channel(); + + for hash in blockhashes { + let block = self.blocks_duration.observe_duration("response", || { + let block = self + .blocks_recv + .recv() + .with_context(|| format!("failed to get block {}", hash))?; + let header = bsl::BlockHeader::parse(&block[..]) + .expect("core returned invalid blockheader") + .parsed_owned(); + ensure!( + &header.block_hash_sha2()[..] == hash.as_byte_array(), + "got unexpected block" + ); + Ok(block) + })?; + let func = &func; + let blocks_duration = &self.blocks_duration; + let send = send.clone(); + s.spawn(move |_| { + let r = blocks_duration.observe_duration("process", || func(hash, block)); + let _ = send.send(r); + }); + } + let result: Result, std::sync::mpsc::RecvError> = + (0..blockhashes_len).map(|_| receive.recv()).collect(); + + match result { + Ok(result) => { + if result.len() == blockhashes_len { + Ok(result) + } else { + bail!("asked {blockhashes_len} blocks, returned {}", result.len(),); + } + } + Err(e) => bail!("recv error {e:?}"), + } + }) }) }