From 3f4d0eae23214e51534e341739bf147907965ff7 Mon Sep 17 00:00:00 2001 From: Raul E Rangel Date: Mon, 14 Oct 2024 13:44:29 -0600 Subject: [PATCH] Add SeekFrom::Start, SeekFrom::Current, and nested support to take_seek This change adds support for the missing SeekFrom operations. It refactors the internals a bit. We now keep a range of start..end values that are accessible to the TakeSeek. I dropped the manually computed stream position and instead call the inner's `stream_position()`. This does mean that `limit` needs to take in a `&mut`. I was also going to change the return type to `Result`, but that would force all callers to handle the `Result`, which I felt was a bigger API change. I fixed the test added in 38b35989ffa27e8bb18cd91f69c6bcc66d2ca03e because `SeekFrom::Start(0)` used to reset the cursor to the start of the underlying buffer. Fixes jam1garner/binrw#291. --- binrw/src/io/take_seek.rs | 95 ++++++++---- binrw/tests/io/take_seek.rs | 280 +++++++++++++++++++++++++++++++++++- 2 files changed, 340 insertions(+), 35 deletions(-) diff --git a/binrw/src/io/take_seek.rs b/binrw/src/io/take_seek.rs index bcae41a0..672c95c7 100644 --- a/binrw/src/io/take_seek.rs +++ b/binrw/src/io/take_seek.rs @@ -2,6 +2,7 @@ //! the underlying reader. use super::{Read, Result, Seek, SeekFrom}; +use core::ops::Range; /// Read adapter which limits the bytes read from an underlying reader, with /// seek support. @@ -13,8 +14,9 @@ use super::{Read, Result, Seek, SeekFrom}; #[derive(Debug)] pub struct TakeSeek { inner: T, - pos: u64, - end: u64, + + /// The range that is allowed to read from inner. + inner_range: Range, } impl TakeSeek { @@ -36,7 +38,9 @@ impl TakeSeek { pub fn into_inner(self) -> T { self.inner } +} +impl TakeSeek { /// Returns the number of bytes that can be read before this instance will /// return EOF. 
/// @@ -44,12 +48,31 @@ impl TakeSeek { /// /// This instance may reach EOF after reading fewer bytes than indicated by /// this method if the underlying [`Read`] instance reaches EOF. - pub fn limit(&self) -> u64 { - self.end.saturating_sub(self.pos) + /// + /// # Panics + /// + /// Panics if the inner stream returns an error from `stream_position`. + pub fn limit(&mut self) -> u64 { + let pos = self + .stream_position() + .expect("cannot get position for `limit`"); + + let inner_pos = self + .inner_range + .start + .checked_add(pos) + .expect("start + pos to not overflow"); + + if self.inner_range.end <= inner_pos { + 0 + } else { + self.inner_range + .end + .checked_sub(inner_pos) + .expect("end - pos to not overflow") + } } -} -impl TakeSeek { /// Sets the number of bytes that can be read before this instance will /// return EOF. This is the same as constructing a new `TakeSeek` instance, /// so the amount of bytes read and the previous limit value don’t matter @@ -59,16 +82,15 @@ impl TakeSeek { /// /// Panics if the inner stream returns an error from `stream_position`. 
pub fn set_limit(&mut self, limit: u64) { - let pos = self + let inner_pos = self .inner .stream_position() .expect("cannot get position for `set_limit`"); - self.pos = pos; - self.end = pos + limit; + self.inner_range = inner_pos..inner_pos + limit; } } -impl Read for TakeSeek { +impl Read for TakeSeek { fn read(&mut self, buf: &mut [u8]) -> Result { let limit = self.limit(); @@ -83,31 +105,49 @@ impl Read for TakeSeek { #[allow(clippy::cast_possible_truncation)] let max = (buf.len() as u64).min(limit) as usize; let n = self.inner.read(&mut buf[0..max])?; - self.pos += n as u64; Ok(n) } } impl Seek for TakeSeek { fn seek(&mut self, pos: SeekFrom) -> Result { - let pos = match pos { - SeekFrom::End(end) => match self.end.checked_add_signed(end) { - Some(pos) => SeekFrom::Start(pos), - None => { - return Err(super::Error::new( - super::ErrorKind::InvalidInput, - "invalid seek to a negative or overflowing position", - )) - } - }, - pos => pos, + let inner_pos = match pos { + SeekFrom::Start(offset) => self.inner_range.start.checked_add(offset), + SeekFrom::End(offset) => self.inner_range.end.checked_add_signed(offset), + SeekFrom::Current(offset) => self.inner.stream_position()?.checked_add_signed(offset), }; - self.pos = self.inner.seek(pos)?; - Ok(self.pos) + + let Some(inner_pos) = inner_pos else { + return Err(super::Error::new( + super::ErrorKind::InvalidData, + "invalid seek to a negative or overflowing position", + )); + }; + + if inner_pos < self.inner_range.start { + return Err(super::Error::new( + super::ErrorKind::InvalidData, + "invalid seek to a negative position", + )); + } + + let inner_pos = self.inner.seek(SeekFrom::Start(inner_pos))?; + + Ok(inner_pos + .checked_sub(self.inner_range.start) + .expect("Can't happen")) } fn stream_position(&mut self) -> Result { - Ok(self.pos) + let inner_pos = self.inner.stream_position()?; + + match inner_pos.checked_sub(self.inner_range.start) { + Some(pos) => Ok(pos), + None => Err(super::Error::new( + 
super::ErrorKind::InvalidData, + "cursor is out of bounds", + )), + } } } @@ -125,14 +165,13 @@ impl TakeSeekExt for T { where Self: Sized, { - let pos = self + let start = self .stream_position() .expect("cannot get position for `take_seek`"); TakeSeek { inner: self, - pos, - end: pos + limit, + inner_range: start..start + limit, } } } diff --git a/binrw/tests/io/take_seek.rs b/binrw/tests/io/take_seek.rs index e5a2e5d8..0270bdee 100644 --- a/binrw/tests/io/take_seek.rs +++ b/binrw/tests/io/take_seek.rs @@ -74,27 +74,28 @@ fn take_seek() { ); assert_eq!( take.read(&mut buf).unwrap(), - 5, - "`read` did not read enough after `SeemFrom::Start`" - ); - assert_eq!( - take.read(&mut buf).unwrap(), - 1, + 0, "`read` read incorrect amount at end of stream" ); + assert_eq!(&buf, b"worlb", "`read` read wrong data"); + assert_eq!( take.seek(SeekFrom::Start(0)).unwrap(), 0, "`SeekFrom::Start` returned wrong position at end of stream" ); + // Rewind the underlying cursor so we can start at the beginning of the + // buffer when `set_limit` is called. 
+ take.get_mut().rewind().unwrap(); + take.set_limit(3); assert_eq!( take.read(&mut buf).unwrap(), 3, "`set_limit` caused too-large partial-limit read" ); - assert_eq!(&buf, b"helrl", "`read` read wrong data"); + assert_eq!(&buf, b"hellb", "`read` read wrong data"); take.seek(SeekFrom::End(-5)) .expect_err("out-of-range `SeekFrom::End` backward seek should fail"); @@ -131,3 +132,268 @@ fn take_seek_ref() { assert_eq!(data.take_seek(5).read(&mut buf).unwrap(), 1); assert_eq!(&buf, b"dworl"); } + +#[test] +fn test_seek_start() { + let mut buf = [0; 8]; + + let mut data = Cursor::new("\x00\x01\x02\x03\x04\x05\x06\x07\x08"); + data.seek(SeekFrom::Start(1)).unwrap(); + + let mut section = data.take_seek(6); + + assert_eq!(section.get_mut().stream_position().unwrap(), 1); + assert_eq!(section.stream_position().unwrap(), 0); + assert_eq!(section.limit(), 6); + assert_eq!(section.read(&mut buf).unwrap(), 6); + assert_eq!(&buf, b"\x01\x02\x03\x04\x05\x06\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 7); + assert_eq!(section.stream_position().unwrap(), 6); + + let mut buf = [0; 8]; // clear buff to ensure read works. 
+ + section.rewind().unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 1); + assert_eq!(section.stream_position().unwrap(), 0); + assert_eq!(section.limit(), 6); + assert_eq!(section.read(&mut buf).unwrap(), 6); + assert_eq!(&buf, b"\x01\x02\x03\x04\x05\x06\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 7); + assert_eq!(section.stream_position().unwrap(), 6); +} + +#[test] +fn test_seek_relative() { + let mut buf = [0; 8]; + + let mut data = Cursor::new("\x00\x01\x02\x03\x04\x05\x06\x07\x08"); + data.seek(SeekFrom::Start(1)).unwrap(); + + let mut section = data.take_seek(6); + + section + .seek_relative(-1) + .expect_err("out-of-range `SeekFrom::Current` backward seek should fail"); + assert_eq!(section.get_mut().stream_position().unwrap(), 1); + assert_eq!(section.stream_position().unwrap(), 0); + assert_eq!(section.limit(), 6); + + section.seek_relative(2).unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 3); + assert_eq!(section.stream_position().unwrap(), 2); + assert_eq!(section.limit(), 4); + assert_eq!(section.read(&mut buf).unwrap(), 4); + assert_eq!(&buf, b"\x03\x04\x05\x06\x00\x00\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 7); + assert_eq!(section.stream_position().unwrap(), 6); + + section.seek_relative(-2).unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 5); + assert_eq!(section.stream_position().unwrap(), 4); + assert_eq!(section.limit(), 2); + assert_eq!(section.read(&mut buf).unwrap(), 2); + assert_eq!(&buf, b"\x05\x06\x05\x06\x00\x00\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 7); + assert_eq!(section.stream_position().unwrap(), 6); + + // According to `std::io::Seek.seek`, seeking past the stream is valid, + // but behavior is defined by the implementation. In our case we don't + // allow reading any additional data. 
+ section.seek_relative(2).unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 9); + assert_eq!(section.stream_position().unwrap(), 8); + assert_eq!(section.limit(), 0); + assert_eq!(section.read(&mut buf).unwrap(), 0); + assert_eq!(&buf, b"\x05\x06\x05\x06\x00\x00\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 9); + assert_eq!(section.stream_position().unwrap(), 8); +} + +#[test] +fn test_seek_end() { + let mut buf = [0; 8]; + + let mut data = Cursor::new("\x00\x01\x02\x03\x04\x05\x06\x07\x08"); + data.seek(SeekFrom::Start(1)).unwrap(); + + let mut section = data.take_seek(6); + + section.seek(SeekFrom::End(0)).unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 7); + assert_eq!(section.stream_position().unwrap(), 6); + assert_eq!(section.limit(), 0); + assert_eq!(section.read(&mut buf).unwrap(), 0); + assert_eq!(&buf, b"\x00\x00\x00\x00\x00\x00\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 7); + assert_eq!(section.stream_position().unwrap(), 6); + + section.seek(SeekFrom::End(-2)).unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 5); + assert_eq!(section.stream_position().unwrap(), 4); + assert_eq!(section.limit(), 2); + assert_eq!(section.read(&mut buf).unwrap(), 2); + assert_eq!(&buf, b"\x05\x06\x00\x00\x00\x00\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 7); + assert_eq!(section.stream_position().unwrap(), 6); + + // According to `std::io::Seek.seek`, seeking past the stream is valid, + // but behavior is defined by the implementation. In our case we don't + // allow reading any additional data. 
+ section.seek(SeekFrom::End(2)).unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 9); + assert_eq!(section.stream_position().unwrap(), 8); + assert_eq!(section.limit(), 0); + assert_eq!(section.read(&mut buf).unwrap(), 0); + assert_eq!(&buf, b"\x05\x06\x00\x00\x00\x00\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 9); + assert_eq!(section.stream_position().unwrap(), 8); + + section + .seek(SeekFrom::End(-10)) + .expect_err("out-of-range `SeekFrom::End` backward seek should fail"); + assert_eq!(section.get_mut().stream_position().unwrap(), 9); + assert_eq!(section.stream_position().unwrap(), 8); + assert_eq!(section.limit(), 0); +} + +#[test] +fn test_seek_nested() { + let mut buf = [0; 8]; + + let mut data = Cursor::new("\x00\x01\x02\x03\x04\x05\x06\x07\x08"); + data.seek(SeekFrom::Start(1)).unwrap(); + + let mut outer_section = data.take_seek(6); + outer_section.seek_relative(2).unwrap(); + assert_eq!(outer_section.get_mut().stream_position().unwrap(), 3); + assert_eq!(outer_section.stream_position().unwrap(), 2); + assert_eq!(outer_section.limit(), 4); + + // Will only allow reading data[3..5]. 
+ let mut inner_section = outer_section.take_seek(2); + assert_eq!( + inner_section.get_mut().get_mut().stream_position().unwrap(), + 3 + ); + assert_eq!(inner_section.get_mut().stream_position().unwrap(), 2); + assert_eq!(inner_section.stream_position().unwrap(), 0); + assert_eq!(inner_section.limit(), 2); + assert_eq!(inner_section.read(&mut buf).unwrap(), 2); + assert_eq!(&buf, b"\x03\x04\x00\x00\x00\x00\x00\x00"); + assert_eq!( + inner_section.get_mut().get_mut().stream_position().unwrap(), + 5 + ); + assert_eq!(inner_section.get_mut().stream_position().unwrap(), 4); + assert_eq!(inner_section.stream_position().unwrap(), 2); + + inner_section.rewind().unwrap(); + assert_eq!( + inner_section.get_mut().get_mut().stream_position().unwrap(), + 3 + ); + assert_eq!(inner_section.get_mut().stream_position().unwrap(), 2); + assert_eq!(inner_section.stream_position().unwrap(), 0); + assert_eq!(inner_section.get_mut().limit(), 4); + assert_eq!(inner_section.limit(), 2); + + inner_section.seek_relative(1).unwrap(); + assert_eq!( + inner_section.get_mut().get_mut().stream_position().unwrap(), + 4 + ); + assert_eq!(inner_section.get_mut().stream_position().unwrap(), 3); + assert_eq!(inner_section.stream_position().unwrap(), 1); + assert_eq!(inner_section.get_mut().limit(), 3); + assert_eq!(inner_section.limit(), 1); + + inner_section.seek(SeekFrom::End(0)).unwrap(); + assert_eq!( + inner_section.get_mut().get_mut().stream_position().unwrap(), + 5 + ); + assert_eq!(inner_section.get_mut().stream_position().unwrap(), 4); + assert_eq!(inner_section.stream_position().unwrap(), 2); + assert_eq!(inner_section.get_mut().limit(), 2); + assert_eq!(inner_section.limit(), 0); + + // Seek past the end of the stream, it should seek `outer_section` to the end. 
+ inner_section.seek(SeekFrom::End(2)).unwrap(); + assert_eq!( + inner_section.get_mut().get_mut().stream_position().unwrap(), + 7 + ); + assert_eq!(inner_section.get_mut().stream_position().unwrap(), 6); + assert_eq!(inner_section.stream_position().unwrap(), 4); + assert_eq!(inner_section.get_mut().limit(), 0); + assert_eq!(inner_section.limit(), 0); +} + +#[test] +fn test_empty() { + let mut data = Cursor::new("\x00\x01\x02\x03\x04\x05\x06\x07\x08"); + data.seek(SeekFrom::Start(1)).unwrap(); + + let mut section = data.take_seek(0); + assert_eq!(section.get_mut().stream_position().unwrap(), 1); + assert_eq!(section.stream_position().unwrap(), 0); + assert_eq!(section.limit(), 0); +} + +#[test] +fn test_set_limit() { + let mut data = Cursor::new("\x00\x01\x02\x03\x04\x05\x06\x07\x08"); + data.seek(SeekFrom::Start(1)).unwrap(); + + let mut buf = [0; 8]; + + let mut section = data.take_seek(6); + section.seek(SeekFrom::End(-2)).unwrap(); + section.set_limit(4); + + assert_eq!(section.limit(), 4); + assert_eq!(section.read(&mut buf).unwrap(), 4); + assert_eq!(&buf, b"\x05\x06\x07\x08\x00\x00\x00\x00"); +} + +#[test] +fn test_early_eof() { + let mut data = Cursor::new("\x00\x01\x02\x03\x04\x05\x06\x07\x08"); + data.seek(SeekFrom::Start(6)).unwrap(); + + let mut buf = [0; 8]; + + let mut section = data.take_seek(10); + + assert_eq!(section.limit(), 10); + assert_eq!(section.read(&mut buf).unwrap(), 3); + assert_eq!(&buf, b"\x06\x07\x08\x00\x00\x00\x00\x00"); + assert_eq!(section.get_mut().stream_position().unwrap(), 9); + assert_eq!(section.stream_position().unwrap(), 3); + assert_eq!(section.limit(), 7); +} + +#[test] +fn test_corrupt_position() { + let mut data = Cursor::new("\x00\x01\x02\x03\x04\x05\x06\x07\x08"); + data.seek(SeekFrom::Start(1)).unwrap(); + + let mut section = data.take_seek(2); + assert_eq!(section.get_mut().stream_position().unwrap(), 1); + assert_eq!(section.stream_position().unwrap(), 0); + assert_eq!(section.limit(), 2); + + // Move the 
underlying cursor before the start of the section. This + // is an invalid state. + section.get_mut().rewind().unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 0); + section + .stream_position() + .expect_err("invalid stream position"); + + // Fix the cursor by resetting the cursor position. + section.rewind().unwrap(); + assert_eq!(section.get_mut().stream_position().unwrap(), 1); + assert_eq!(section.stream_position().unwrap(), 0); + assert_eq!(section.limit(), 2); +}