diff --git a/.github/workflows/rust-compile.yml b/.github/workflows/rust-compile.yml
new file mode 100644
index 0000000..0d009f9
--- /dev/null
+++ b/.github/workflows/rust-compile.yml
@@ -0,0 +1,55 @@
+on:
+  push:
+    branches:
+      - "main"
+  pull_request:
+
+name: Rust
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+env:
+  RUST_LOG: info
+  RUST_BACKTRACE: 1
+  RUSTFLAGS: "-D warnings"
+  CARGO_TERM_COLOR: always
+
+jobs:
+  check-rustdoc-links:
+    name: Check intra-doc links
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rust-lang/setup-rust-toolchain@v1
+      - run: cargo rustdoc --all-features -- -D warnings -W unreachable-pub
+
+  format_and_lint:
+    name: Format and Lint
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          components: clippy, rustfmt
+      - name: Run rustfmt
+        uses: actions-rust-lang/rustfmt@v1
+      - name: Run clippy
+        run: cargo clippy
+
+  build:
+    name: ubuntu-latest
+    runs-on: ubuntu-latest
+    needs: [ format_and_lint ]
+    steps:
+      - name: Checkout source code
+        uses: actions/checkout@v4
+      - name: Install Rust toolchain
+        uses: actions-rust-lang/setup-rust-toolchain@v1
+        with:
+          components: rustfmt
+      - name: Build
+        run: cargo build
+      - name: Run tests
+        run: cargo test -- --nocapture
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6456235
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+# Generated by Cargo
+# will have compiled files and executables
+/target/
+
+# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
+# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
+/Cargo.lock
+
+# These are backup files generated by rustfmt
+**/*.rs.bk
+
+# These are files generated by IntelliJ IDEA
+.idea/
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..99a64db
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,47 @@
+[package]
+name = "async_http_range_reader"
+authors = ["Bas Zalmstra "]
+version = "0.1.0"
+edition = "2021"
+description = "A library for streaming reading of files over HTTP using range requests"
+license = "MIT"
+repository = "https://github.com/baszalmstra/async_http_range_reader"
+exclude = ["test-data/*"]
+
+[dependencies]
+futures = "0.3.28"
+http-content-range = "0.1.2"
+itertools = "0.11.0"
+bisection = "0.1.0"
+memmap2 = "0.9.0"
+reqwest = { version = "0.11.22", default-features = false, features = ["stream"] }
+tokio = { version = "1.33.0", default-features = false }
+tokio-stream = { version = "0.1.14", features = ["sync"] }
+tokio-util = "0.7.9"
+thiserror = "1.0.50"
+tracing = "0.1.40"
+
+[dev-dependencies]
+axum = { version = "0.6.20", default-features = false, features = ["tokio"] }
+tokio = { version = "1.33.0", default-features = false, features = ["macros", "test-util"] }
+tower-http = { version = "0.4.4", default-features = false, features = ["fs"] }
+async_zip = { version = "0.0.15", default-features = false, features = ["tokio"] }
+assert_matches = "1.5.0"
+
+# The profile that 'cargo dist' will build with
+[profile.dist]
+inherits = "release"
+lto = "thin"
+
+# Config for 'cargo dist'
+[workspace.metadata.dist]
+# The preferred cargo-dist version to use in CI (Cargo.toml SemVer syntax)
+cargo-dist-version = "0.3.1"
+# CI backends to support
+ci = ["github"]
+# The installers to generate for each app
+installers = []
+# Target platforms to build apps for (Rust target-triple syntax)
+targets = ["x86_64-unknown-linux-gnu", "aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-pc-windows-msvc"]
["x86_64-unknown-linux-gnu", "aarch64-apple-darwin", "x86_64-apple-darwin", "x86_64-pc-windows-msvc"] +# Publish jobs to run in CI +pr-run-mode = "plan" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0946dfa --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Bas Zalmstra + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..a2718ed --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +# Async HTTP Range Reader + +[![Crates.io](https://img.shields.io/crates/v/async_http_range_reader?style=flat-square)](https://crates.io/crates/async_http_range_reader) +[![docs.rs](https://img.shields.io/docsrs/async_http_range_reader?style=flat-square)](https://docs.rs/async_http_range_reader/) +[![GitHub Workflow Status (branch)](https://img.shields.io/github/actions/workflow/status/baszalmstra/async_http_range_reader/rust-compile.yml?branch=main&style=flat-square)](https://github.com/baszalmstra/async_http_range_reader/actions?query=branch%3Amain) +[![GitHub](https://img.shields.io/github/license/baszalmstra/async_http_range_reader?style=flat-square)](https://github.com/baszalmstra/async_http_range_reader/blob/main/LICENSE) + +A crate that provides the `AsyncHttpRangeReader` struct, which allows streaming files over HTTP using range requests. diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..01968a1 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,555 @@ +mod sparse_range; + +#[cfg(test)] +mod static_directory_server; + +use futures::{Stream, StreamExt}; +use http_content_range::{ContentRange, ContentRangeBytes}; +use memmap2::MmapMut; +use reqwest::{Client, Response, Url}; +use sparse_range::SparseRange; +use std::{ + io::{self, ErrorKind, SeekFrom}, + ops::Range, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, +}; +use tokio::{ + io::{AsyncRead, AsyncSeek, ReadBuf}, + sync::watch::Sender, + sync::{watch, Mutex}, +}; +use tokio_stream::wrappers::WatchStream; +use tokio_util::sync::PollSender; +use tracing::{info_span, Instrument}; + +/// An `AsyncRangeReader` enables reading from a file over HTTP using range requests. 
+#[derive(Debug)]
+pub struct AsyncHttpRangeReader {
+    inner: Mutex<Inner>,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum AsyncHttpRangeReaderError {
+    #[error("range requests are not supported")]
+    HttpRangeRequestUnsupported,
+
+    #[error(transparent)]
+    HttpError(#[from] Arc<reqwest::Error>),
+
+    #[error("an error occurred during transport: {0}")]
+    TransportError(#[source] Arc<reqwest::Error>),
+
+    #[error("io error occurred: {0}")]
+    IoError(#[source] Arc<std::io::Error>),
+
+    #[error("content-range header is missing from response")]
+    ContentRangeMissing,
+
+    #[error("memory mapping the file failed")]
+    MemoryMapError(#[source] Arc<std::io::Error>),
+}
+
+impl From<std::io::Error> for AsyncHttpRangeReaderError {
+    fn from(err: std::io::Error) -> Self {
+        AsyncHttpRangeReaderError::IoError(Arc::new(err))
+    }
+}
+
+impl From<reqwest::Error> for AsyncHttpRangeReaderError {
+    fn from(err: reqwest::Error) -> Self {
+        AsyncHttpRangeReaderError::TransportError(Arc::new(err))
+    }
+}
+
+#[derive(Debug)]
+struct Inner {
+    /// A read-only view on the memory mapped data. The `downloaded_range` indicates the regions of
+    /// memory that contain bytes that have been downloaded.
+    data: &'static [u8],
+
+    /// The current read position in the stream
+    pos: u64,
+
+    /// The range of bytes that have been requested for download
+    requested_range: SparseRange,
+
+    /// The range of bytes that have actually been downloaded to `data`.
+    downloaded_range: Result<SparseRange, AsyncHttpRangeReaderError>,
+
+    /// A channel receiver that holds the last downloaded range (or an error) from the background
+    /// task.
+    state_rx: WatchStream<Result<SparseRange, AsyncHttpRangeReaderError>>,
+
+    /// A channel sender to send range requests to the background task
+    request_tx: tokio::sync::mpsc::Sender<Range<u64>>,
+
+    /// An optional object to reserve a slot in the `request_tx` sender. When in the process of
+    /// sending a request this contains an actual value.
+    poll_request_tx: Option<PollSender<Range<u64>>>,
+}
+
+impl AsyncHttpRangeReader {
+    /// Construct a new `AsyncHttpRangeReader`.
+    pub async fn new(
+        client: reqwest::Client,
+        url: reqwest::Url,
+    ) -> Result<Self, AsyncHttpRangeReaderError> {
+        // Perform an initial range request to get the size of the file
+        const INITIAL_CHUNK_SIZE: usize = 16384;
+        let tail_request_response = client
+            .get(url.clone())
+            .header(
+                reqwest::header::RANGE,
+                format!("bytes=-{INITIAL_CHUNK_SIZE}"),
+            )
+            .header(reqwest::header::CACHE_CONTROL, "no-cache")
+            .send()
+            .await
+            .and_then(Response::error_for_status)
+            .map_err(Arc::new)
+            .map_err(AsyncHttpRangeReaderError::HttpError)?;
+
+        // Any status other than 206 Partial Content means the server ignored the range request.
+        if tail_request_response.status() != reqwest::StatusCode::PARTIAL_CONTENT {
+            return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported);
+        }
+
+        // Get the size of the file from this initial request
+        let content_range = ContentRange::parse(
+            tail_request_response
+                .headers()
+                .get(reqwest::header::CONTENT_RANGE)
+                .ok_or(AsyncHttpRangeReaderError::ContentRangeMissing)?
+                .to_str()
+                .map_err(|_| AsyncHttpRangeReaderError::ContentRangeMissing)?,
+        );
+        let (start, finish, complete_length) = match content_range {
+            ContentRange::Bytes(ContentRangeBytes {
+                first_byte,
+                last_byte,
+                complete_length,
+            }) => (first_byte, last_byte, complete_length),
+            _ => return Err(AsyncHttpRangeReaderError::HttpRangeRequestUnsupported),
+        };
+
+        // Allocate a memory map to hold the data
+        let memory_map = memmap2::MmapOptions::new()
+            .len(complete_length as usize)
+            .map_anon()
+            .map_err(Arc::new)
+            .map_err(AsyncHttpRangeReaderError::MemoryMapError)?;
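+
+        // NB: the anonymous map reserves `complete_length` bytes of virtual address space up
+        // front; physical pages are typically only committed once ranges are actually written.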
+
+        // SAFETY: Get a read-only slice to the memory. This is safe because the memory map is
+        // never reallocated and we keep track of the initialized part.
+        let memory_map_slice =
+            unsafe { std::slice::from_raw_parts(memory_map.as_ptr(), memory_map.len()) };
+
+        let requested_range =
+            SparseRange::from_range(complete_length - (finish - start)..complete_length);
+
+        // Adding more entries to the channel than it can hold would block the sender. Two slots
+        // should suffice in practice: one for an explicit prefetch and one for a read through the
+        // `AsyncRead` implementation; any extra request simply waits for a free slot. We use 10
+        // to leave some slack.
+        let (request_tx, request_rx) = tokio::sync::mpsc::channel(10);
+        let (state_tx, state_rx) = watch::channel(Ok(SparseRange::new()));
+        tokio::spawn(run_streamer(
+            client,
+            url,
+            tail_request_response,
+            start,
+            memory_map,
+            state_tx,
+            request_rx,
+        ));
+
+        Ok(Self {
+            inner: Mutex::new(Inner {
+                data: memory_map_slice,
+                pos: 0,
+                requested_range,
+                downloaded_range: Ok(SparseRange::new()),
+                state_rx: WatchStream::new(state_rx),
+                request_tx,
+                poll_request_tx: None,
+            }),
+        })
+    }
+
+    /// Prefetches a range of bytes from the remote. When specifying a large range this can
+    /// drastically reduce the number of requests made to the server.
+    pub async fn prefetch(&mut self, bytes: Range<u64>) {
+        let inner = self.inner.get_mut();
+
+        // Ensure the range is within the file size and non-zero in length.
+        let range = bytes.start..(bytes.end.min(inner.data.len() as u64));
+        if range.start >= range.end {
+            return;
+        }
+
+        // Check if the range has already been requested.
+        if let Some((new_range, _)) = inner.requested_range.cover(range.clone()) {
+            let _ = inner.request_tx.send(range).await;
+            inner.requested_range = new_range;
+        }
+    }
+}
+
+/// A task that will download parts from the remote archive and "send" them to the frontend as they
+/// become available.
+#[tracing::instrument(name = "fetch_ranges", skip_all, fields(url))]
+async fn run_streamer(
+    client: Client,
+    url: Url,
+    response: Response,
+    response_start: u64,
+    mut memory_map: MmapMut,
+    mut state_tx: Sender<Result<SparseRange, AsyncHttpRangeReaderError>>,
+    mut request_rx: tokio::sync::mpsc::Receiver<Range<u64>>,
+) {
+    let mut downloaded_range = SparseRange::new();
+
+    // Stream the initial data in memory
+    if !stream_response(
+        response,
+        response_start,
+        &mut memory_map,
+        &mut state_tx,
+        &mut downloaded_range,
+    )
+    .await
+    {
+        return;
+    }
+
+    // Listen for any new incoming requests
+    'outer: loop {
+        let range = match request_rx.recv().await {
+            Some(range) => range,
+            None => {
+                break 'outer;
+            }
+        };
+
+        // Determine the ranges we still need to fetch
+        let uncovered_ranges = match downloaded_range.cover(range) {
+            None => continue,
+            Some((_, uncovered_ranges)) => uncovered_ranges,
+        };
+
+        // Download and stream each missing range.
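+        // As a worked example (hypothetical numbers): if bytes 0..=99 were already downloaded
+        // and a request for 50..200 comes in, `cover` reports the single uncovered range
+        // 100..=199, so only that part is fetched below.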
+        for range in uncovered_ranges {
+            let range_string = format!("bytes={}-{}", range.start(), range.end());
+            let span = info_span!("fetch_range", range = range_string.as_str());
+            let response = match client
+                .get(url.clone())
+                .header(reqwest::header::RANGE, range_string)
+                .header(reqwest::header::CACHE_CONTROL, "no-cache")
+                .send()
+                .instrument(span)
+                .await
+                .and_then(Response::error_for_status)
+                .map_err(|e| std::io::Error::new(ErrorKind::Other, e))
+            {
+                Err(e) => {
+                    let _ = state_tx.send(Err(e.into()));
+                    break 'outer;
+                }
+                Ok(response) => response,
+            };
+
+            if !stream_response(
+                response,
+                *range.start(),
+                &mut memory_map,
+                &mut state_tx,
+                &mut downloaded_range,
+            )
+            .await
+            {
+                break 'outer;
+            }
+        }
+    }
+}
+
+/// Streams the data from the specified response to the memory map, updating progress along the
+/// way. Returns `true` if everything went fine, `false` if anything went wrong. The error state,
+/// if any, is stored in `state_tx` so the "frontend" can consume it.
+async fn stream_response(
+    response: Response,
+    mut offset: u64,
+    memory_map: &mut MmapMut,
+    state_tx: &mut Sender<Result<SparseRange, AsyncHttpRangeReaderError>>,
+    downloaded_range: &mut SparseRange,
+) -> bool {
+    let mut byte_stream = response.bytes_stream();
+    while let Some(bytes) = byte_stream.next().await {
+        let bytes = match bytes {
+            Err(e) => {
+                let _ = state_tx.send(Err(e.into()));
+                return false;
+            }
+            Ok(bytes) => bytes,
+        };
+
+        // Determine the range of these bytes in the complete file
+        let byte_range = offset..offset + bytes.len() as u64;
+
+        // Update the offset
+        offset = byte_range.end;
+
+        // Copy the data from the stream to memory
+        memory_map[byte_range.start as usize..byte_range.end as usize]
+            .copy_from_slice(bytes.as_ref());
+
+        // Update the range of bytes that have been downloaded
+        downloaded_range.update(byte_range);
+
+        // Notify anyone that's listening that we have downloaded some extra data
+        if state_tx.send(Ok(downloaded_range.clone())).is_err() {
+            // If we failed to send the state it means there is no receiver. In that case we
+            // should just exit.
+            return false;
+        }
+    }
+
+    true
+}
+
+impl AsyncSeek for AsyncHttpRangeReader {
+    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
+        let me = self.get_mut();
+        let inner = me.inner.get_mut();
+
+        inner.pos = match position {
+            SeekFrom::Start(pos) => pos,
+            SeekFrom::End(relative) => (inner.data.len() as i64).saturating_add(relative) as u64,
+            SeekFrom::Current(relative) => (inner.pos as i64).saturating_add(relative) as u64,
+        };
+
+        Ok(())
+    }
+
+    fn poll_complete(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
+        let inner = self.inner.get_mut();
+        Poll::Ready(Ok(inner.pos))
+    }
+}
+
+impl AsyncRead for AsyncHttpRangeReader {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let me = self.get_mut();
+        let inner = me.inner.get_mut();
+
+        // If a previous error occurred we return that.
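+        // The error is sticky: once the background task has failed, every subsequent read will
+        // keep returning the same error.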
+        if let Err(e) = &inner.downloaded_range {
+            return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e.clone())));
+        }
+
+        // Determine the range to be fetched
+        let range = inner.pos..(inner.pos + buf.remaining() as u64).min(inner.data.len() as u64);
+        if range.start >= range.end {
+            return Poll::Ready(Ok(()));
+        }
+
+        // Ensure we requested the required bytes
+        while !inner.requested_range.is_covered(range.clone()) {
+            // If there is an active range request, wait for it to complete
+            if let Some(mut poll) = inner.poll_request_tx.take() {
+                match poll.poll_reserve(cx) {
+                    Poll::Ready(_) => {
+                        let _ = poll.send_item(range.clone());
+                        inner.requested_range.update(range.clone());
+                        break;
+                    }
+                    Poll::Pending => {
+                        inner.poll_request_tx = Some(poll);
+                        return Poll::Pending;
+                    }
+                }
+            }
+
+            // Request the range
+            inner.poll_request_tx = Some(PollSender::new(inner.request_tx.clone()));
+        }
+
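+        // NB: `PollSender` (used above) splits sending into `poll_reserve` (wait for capacity
+        // on the bounded channel) and `send_item` (hand over the value), which lets this
+        // poll-based function apply backpressure without blocking.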
+ ); + + // Construct an AsyncRangeReader + let range = AsyncHttpRangeReader::new( + Client::new(), + server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), + ) + .await + .expect("Could not download range - did you run `git lfs pull`?"); + + let mut reader = ZipFileReader::new(range.compat()).await.unwrap(); + + assert_eq!( + reader + .file() + .entries() + .iter() + .map(|e| e.entry().filename().as_str().unwrap_or("")) + .collect::>(), + vec![ + "metadata.json", + "info-andes-1.8.3-pyhd8ed1ab_0.tar.zst", + "pkg-andes-1.8.3-pyhd8ed1ab_0.tar.zst" + ] + ); + + // Prefetch the data for the metadata.json file + let entry = reader.file().entries().get(0).unwrap(); + let offset = entry.header_offset(); + // Get the size of the entry plus the header + size of the filename. We should also actually + // include bytes for the extra fields but we don't have that information. + let size = + entry.entry().compressed_size() + 30 + entry.entry().filename().as_bytes().len() as u64; + reader + .inner_mut() + .get_mut() + .prefetch(offset..offset + size as u64) + .await; + + // Read the contents of the metadata.json file + let mut contents = String::new(); + reader + .reader_with_entry(0) + .await + .unwrap() + .read_to_string(&mut contents) + .await + .unwrap(); + + assert_eq!(contents, r#"{"conda_pkg_format_version": 2}"#); + } + + #[tokio::test] + async fn async_range_reader() { + // Spawn a static file server + let path = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("test-data"); + let server = StaticDirectoryServer::new(&path); + + // Construct an AsyncRangeReader + let mut range = AsyncHttpRangeReader::new( + Client::new(), + server.url().join("andes-1.8.3-pyhd8ed1ab_0.conda").unwrap(), + ) + .await + .expect("bla"); + + // Also open a simple file reader + let mut file = tokio::fs::File::open(path.join("andes-1.8.3-pyhd8ed1ab_0.conda")) + .await + .unwrap(); + + // Read until the end and make sure that the contents matches + let mut range_read = vec![0; 64 * 1024]; + let mut file_read = vec![0; 64 * 1024]; + loop { + // Read with the async reader + let range_read_bytes = range.read(&mut range_read).await.unwrap(); + + // Read directly from the file + let file_read_bytes = file + .read_exact(&mut file_read[0..range_read_bytes]) + .await + .unwrap(); + + assert_eq!(range_read_bytes, file_read_bytes); + assert_eq!( + range_read[0..range_read_bytes], + file_read[0..file_read_bytes] + ); + + if file_read_bytes == 0 && range_read_bytes == 0 { + break; + } + } + } + + #[tokio::test] + async fn test_not_found() { + let server = StaticDirectoryServer::new(Path::new(env!("CARGO_MANIFEST_DIR"))); + let err = AsyncHttpRangeReader::new(Client::new(), server.url().join("not-found").unwrap()) + .await + .expect_err("expected an error"); + + assert_matches!( + err, AsyncHttpRangeReaderError::HttpError(err) if err.status() == Some(StatusCode::NOT_FOUND) + ); + } +} diff --git a/src/sparse_range.rs b/src/sparse_range.rs new file mode 100644 index 0000000..7111c8d --- /dev/null +++ b/src/sparse_range.rs @@ -0,0 +1,195 @@ +use bisection::{bisect_left, bisect_right}; +use itertools::Itertools; +use std::{ + fmt::{Debug, Display, Formatter}, + ops::{Range, RangeInclusive}, +}; + +// A data structure that keeps track of a range of values with potential holes in them. 
+#[derive(Default, Clone, Eq, PartialEq)]
+pub struct SparseRange {
+    left: Vec<u64>,
+    right: Vec<u64>,
+}
+
+impl Display for SparseRange {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "{}",
+            self.covered_ranges()
+                .format_with(", ", |elt, f| f(&format_args!(
+                    "{}..={}",
+                    elt.start(),
+                    elt.end()
+                )))
+        )
+    }
+}
+
+impl Debug for SparseRange {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{self}")
+    }
+}
+
+impl SparseRange {
+    /// Construct a new sparse range
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Construct a new SparseRange from an initial covered range.
+    pub fn from_range(range: Range<u64>) -> Self {
+        Self {
+            left: vec![range.start],
+            right: vec![range.end - 1], // -1 because the stored ranges are inclusive
+        }
+    }
+
+    /// Returns the covered ranges
+    pub fn covered_ranges(&self) -> impl Iterator<Item = RangeInclusive<u64>> + '_ {
+        self.left
+            .iter()
+            .zip(self.right.iter())
+            .map(|(&left, &right)| RangeInclusive::new(left, right))
+    }
+
+    /// Returns true if the specified range is completely covered.
+    pub fn is_covered(&self, range: Range<u64>) -> bool {
+        let range_start = range.start;
+        let range_end = range.end - 1;
+
+        // Compute the indices of the ranges that are covered by the request
+        let left_index = bisect_left(&self.right, &range_start);
+        let right_index = bisect_right(&self.left, &(range_end + 1));
+
+        // Get all the range bounds that are covered
+        let left_slice = &self.left[left_index..right_index];
+        let right_slice = &self.right[left_index..right_index];
+
+        // Compute the start of the covered range, taking existing covered ranges into account.
+        let start = left_slice
+            .first()
+            .map(|&left_bound| left_bound.min(range_start))
+            .unwrap_or(range_start);
+
+        // Check for holes in the covered ranges
+        let mut bound = start;
+        for (&left_bound, &right_bound) in left_slice.iter().zip(right_slice.iter()) {
+            if left_bound > bound {
+                return false;
+            }
+            bound = right_bound + 1;
+        }
+
+        let end = right_slice
+            .last()
+            .map(|&right_bound| right_bound.max(range_end))
+            .unwrap_or(range_end);
+
+        bound > end
+    }
+
+    /// Updates the current range to also cover the specified range.
+    pub fn update(&mut self, range: Range<u64>) {
+        if let Some((new_range, _)) = self.cover(range) {
+            *self = new_range;
+        }
+    }
+
+    /// Find the ranges that are uncovered for the specified range, together with what the
+    /// SparseRange would look like if we covered that range.
+    pub fn cover(&self, range: Range<u64>) -> Option<(SparseRange, Vec<RangeInclusive<u64>>)> {
+        let range_start = range.start;
+        let range_end = range.end - 1;
+
+        // Compute the indices of the ranges that are covered by the request
+        let left_index = bisect_left(&self.right, &range_start);
+        let right_index = bisect_right(&self.left, &(range_end + 1));
+
+        // Get all the range bounds that are covered
+        let left_slice = &self.left[left_index..right_index];
+        let right_slice = &self.right[left_index..right_index];
+
+        // Compute the bounds of the covered range, taking existing covered ranges into account.
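+        // Worked example (illustrative): with covered ranges 0..=9 and 20..=29, covering 5..25
+        // yields left_index = 0 and right_index = 2, so both existing ranges participate and
+        // the merged bounds become start = 0 and end = 29, with 10..=19 reported as missing.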
+        let start = left_slice
+            .first()
+            .map(|&left_bound| left_bound.min(range_start))
+            .unwrap_or(range_start);
+        let end = right_slice
+            .last()
+            .map(|&right_bound| right_bound.max(range_end))
+            .unwrap_or(range_end);
+
+        // Get the ranges that are missing
+        let mut ranges = Vec::new();
+        let mut bound = start;
+        for (&left_bound, &right_bound) in left_slice.iter().zip(right_slice.iter()) {
+            if left_bound > bound {
+                ranges.push(bound..=(left_bound - 1));
+            }
+            bound = right_bound + 1;
+        }
+        if bound <= end {
+            ranges.push(bound..=end);
+        }
+
+        if !ranges.is_empty() {
+            let mut new_left = self.left.clone();
+            new_left.splice(left_index..right_index, [start]);
+            let mut new_right = self.right.clone();
+            new_right.splice(left_index..right_index, [end]);
+            Some((
+                Self {
+                    left: new_left,
+                    right: new_right,
+                },
+                ranges,
+            ))
+        } else {
+            None
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::SparseRange;
+
+    #[test]
+    fn test_sparse_range() {
+        let range = SparseRange::new();
+        assert!(range.covered_ranges().next().is_none());
+        assert_eq!(
+            range.cover(5..10).unwrap().0,
+            SparseRange::from_range(5..10)
+        );
+
+        let range = SparseRange::from_range(5..10);
+        assert_eq!(range.covered_ranges().collect::<Vec<_>>(), vec![5..=9]);
+        assert!(range.is_covered(5..10));
+        assert!(range.is_covered(6..9));
+        assert!(!range.is_covered(5..11));
+        assert!(!range.is_covered(3..8));
+
+        assert_eq!(
+            range.cover(3..5),
+            Some((SparseRange::from_range(3..10), vec![3..=4]))
+        );
+
+        let (range, missing) = range.cover(12..15).unwrap();
+        assert_eq!(
+            range.covered_ranges().collect::<Vec<_>>(),
+            vec![5..=9, 12..=14]
+        );
+        assert_eq!(missing, vec![12..=14]);
+        assert!(range.is_covered(5..10));
+        assert!(range.is_covered(12..15));
+        assert!(!range.is_covered(5..15));
+        assert!(!range.is_covered(11..12));
+
+        let (range, missing) = range.cover(8..14).unwrap();
+        assert_eq!(range.covered_ranges().collect::<Vec<_>>(), vec![5..=14]);
+        assert_eq!(missing, vec![10..=11]);
+    }
+}
diff --git a/src/static_directory_server.rs b/src/static_directory_server.rs
new file mode 100644
index 0000000..7d082f7
--- /dev/null
+++ b/src/static_directory_server.rs
@@ -0,0 +1,62 @@
+use axum::routing::get_service;
+use reqwest::Url;
+use std::net::SocketAddr;
+use std::path::Path;
+use tokio::sync::oneshot;
+use tower_http::services::ServeDir;
+
+/// A convenient async HTTP server that serves the content of a folder. The server only listens on
+/// `127.0.0.1` and uses a random port, which makes it safe to run multiple instances
+/// simultaneously. It's perfect for testing HTTP file requests.
+pub struct StaticDirectoryServer {
+    local_addr: SocketAddr,
+    shutdown_sender: Option<oneshot::Sender<()>>,
+}
+
+impl StaticDirectoryServer {
+    /// Returns the root `Url` to the server.
+    pub fn url(&self) -> Url {
+        Url::parse(&format!("http://localhost:{}", self.local_addr.port())).unwrap()
+    }
+}
+
+impl StaticDirectoryServer {
+    pub fn new(path: impl AsRef<Path>) -> Self {
+        let service = get_service(ServeDir::new(path));
+
+        // Create a router that will serve the static files
+        let app = axum::Router::new().nest_service("/", service);
+
+        // Construct the server that will listen on localhost but with a *random port*. The random
+        // port is very important because it enables creating multiple instances at the same time.
+        // We need this to be able to run tests in parallel.
+        let addr = SocketAddr::new([127, 0, 0, 1].into(), 0);
+        let server = axum::Server::bind(&addr).serve(app.into_make_service());
+
+        // Get the address of the server so we can connect to it later.
+        let addr = server.local_addr();
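+        // NB: `axum::Server::bind` binds the socket eagerly, so the OS-assigned port is
+        // already known here, before the server future is spawned.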
+
+        // Set up a graceful shutdown trigger which is fired when this instance is dropped.
+        let (tx, rx) = oneshot::channel();
+        let server = server.with_graceful_shutdown(async {
+            rx.await.ok();
+        });
+
+        // Spawn the server. We can let go of the JoinHandle because the graceful shutdown
+        // trigger is enough to stop the server.
+        tokio::spawn(server);
+
+        Self {
+            local_addr: addr,
+            shutdown_sender: Some(tx),
+        }
+    }
+}
+
+impl Drop for StaticDirectoryServer {
+    fn drop(&mut self) {
+        if let Some(tx) = self.shutdown_sender.take() {
+            let _ = tx.send(());
+        }
+    }
+}
diff --git a/test-data/andes-1.8.3-pyhd8ed1ab_0.conda b/test-data/andes-1.8.3-pyhd8ed1ab_0.conda
new file mode 100644
index 0000000..7a9aa3f
Binary files /dev/null and b/test-data/andes-1.8.3-pyhd8ed1ab_0.conda differ